widen inds outside select-cells loops

This commit is contained in:
dzaima 2024-07-26 22:16:52 +03:00
parent fce7567349
commit 69ca524251
3 changed files with 71 additions and 12 deletions

View File

@ -557,10 +557,12 @@ B select_cells_base(B inds, B x0, ux csz, ux cam);
#define INDS_BUF_MAX 64 // only need 32 bytes for AVX2 & 16 for NEON, but have more for past-the-end pointers and writes
B select_rows_typed(B x, ux csz, ux cam, void* inds, ux indn, u8 ie, bool shouldBoundsCheck) { // ⥊ (indn↑inds As ie)⊸⊏˘ cam‿csz⥊z; xe cannot be el_bit or el_B, unless csz==1; ie must be ≤el_i8 if csz≤128
u8 inds_buf[INDS_BUF_MAX]; (void)inds_buf;
assert(csz!=0 && cam!=0);
assert(csz*cam == IA(x));
assert(ie<=el_i32);
u8 inds_buf[INDS_BUF_MAX]; (void)inds_buf;
bool generic_allowed = true; // whether required interpretation of x hasn't changed from its real one
if (csz==1) { // TODO maybe move to select_rows_B and require csz>=2 here?
i64 bounds[2];
if (!getRange_fns[ie](inds, bounds, indn) || bounds[0]<-1 || bounds[1]>0) goto generic; // could put under shouldBoundsCheck but ideally things setting that to false should handle size-1 cells themselves
@ -575,13 +577,14 @@ B select_rows_typed(B x, ux csz, ux cam, void* inds, ux indn, u8 ie, bool should
u8 lb = arrTypeWidthLog(TY(x));
ux xbump = csz<<lb;
ux rbump = indn<<lb;
ux ria = indn * cam;
u8* xp = tyany_ptr(x);
bool fast; (void) fast;
i64 bounds[2];
if (ie==el_bit) {
if (csz>32 || indn>32 || indn>INDS_BUF_MAX) { // TODO properly tune
u8* rp = m_tyarrv_same(&r, indn * cam, x);
u8* rp = m_tyarrv_same(&r, ria, x);
for (ux i = 0; i < cam; i++) {
bitselFns[lb](rp, inds, loadu_u64(xp), loadu_u64(xp + (1<<lb)), indn);
xp+= xbump;
@ -603,6 +606,7 @@ B select_rows_typed(B x, ux csz, ux cam, void* inds, ux indn, u8 ie, bool should
}
#if SINGELI
assert(INDS_BUF_MAX_COPY == INDS_BUF_MAX);
{
fast = ie==el_i8;
@ -621,7 +625,23 @@ B select_rows_typed(B x, ux csz, ux cam, void* inds, ux indn, u8 ie, bool should
}
skip_bounds_check:;
u8* rp = m_tyarrv_same(&r, indn * cam, x);
#if SINGELI_AVX2 || SINGELI_NEON
if (fast) {
generic_allowed = false;
ux sh = select_rows_widen[lb](inds, inds_buf, indn); // TODO null element in table for guaranteed-zero
if (sh!=0) {
SELECT_ROWS_PRINTF("widening indices by factor of %d:\n", 1<<sh);
SELECT_ROWS_PRINTF(" src: lb=%d, ie=%d, csz=%zu, indn=%zu\n", lb, ie, csz, indn);
inds = inds_buf;
lb-= sh;
csz<<= sh;
indn<<= sh;
SELECT_ROWS_PRINTF(" dst: lb=%d, ie=%d, csz=%zu, indn=%zu\n", lb, ie, csz, indn);
}
}
#endif
u8* rp = m_tyarrv_same(&r, ria, x);
ux slow_cam = cam;
#if SINGELI_AVX2 || SINGELI_NEON
@ -695,6 +715,7 @@ B select_rows_typed(B x, ux csz, ux cam, void* inds, ux indn, u8 ie, bool should
#endif
generic:;
assert(generic_allowed);
B indo = taga(arr_shVec(m_tyslice(inds, a(emptyIVec()), ie, indn)));
r = select_cells_base(indo, x, csz, cam);
return r;

View File

@ -1,4 +1,3 @@
def __shl{(u16)}{a:T, b} = T~~(re_el{u16,a}<<b) # for x86's lack of u8 shift
def broadcast{[(n*2)]E, x:[n]E} = pair{x, x}
def pow2_up{v, least} = max{least, 1<<ceil_log2{v}} # least ⌈ ⌈⌾(2⊸⋆⁼) v
@ -80,14 +79,25 @@ def blend_halves{mode, E, nt, ni} = tup{nt, ni, loader{{TG} => {
# }
}}}
def raw_widen_inds{[k]D, x:[k0]S if k0>=k} = { # : [k*sc]S
def sc = width{D} / width{S}
def add = make{[k*sc]S, range{k*sc} % sc}
if (hasarch{'AVX2'} and [k]D == [4]u64 and S==u32) {
(sel{[8]u32, undefPromote{[8]u32, x}, make{[8]u32, range{8}>>1}}<<sc) + add
} else {
def wd = widen{[k]D, x}
re_el{S, wd * [k]D**base{1<<width{S}, sc**sc}} + add
}
}
def raw_widen_inds{k, sc, x:[_](u8)} = raw_widen_inds{[k]primtype{'u', 8<<sc}, x}
def widen_inds{mode, E, nt0, ni0, sc} = match(lut_gen{mode, primtype{'u',width{E}/sc}, nt0*sc, ni0*sc}) { # e.g. sc==2: {a,b,c,d}[w,x,y,z] → {a0,a1, b0,b1, c0,c1, d0,d1}[w*2,w*2+1, x*2,x*2+1, y*2,y*2+1, z*2zw*2+1]
{{nt1, ni1, G}} => tup{nt1/sc, ni1/sc, loader{{TG} => {
def prev = G{TG}
def ni = ni1/sc
def WV = [ni]primtype{'u', 8*sc}
{is:([ni]u8)} => {
def isw = widen{WV, is} * WV**base{256, sc**sc} + WV**base{256, range{sc}}
each{re_el{E,.}, prev{re_el{u8, isw}}}
each{re_el{E,.}, prev{raw_widen_inds{WV, is}}}
}
}}}
{x} => x
@ -168,10 +178,13 @@ def lut_gen{mode, E==u8, nt, ni if hasarch{'AARCH64'} and nt<=16*4 and ni<=16} =
{is:([16]u8)} => tup{sel{lut, is}}
}}}}
def lut_gen{mode, E, nt, ni if (E==u16 or E==u64) and mode=='i'} = zip_halves{mode, E, nt, ni}
def lut_gen{mode, E, nt, ni if (E==u16 or E==u64) and mode=='c'} = widen_inds{mode, E, nt, max{ni,16}, 2}
def lut_gen{mode, E, nt, ni if E==u32 and hasarch{'AARCH64'}} = zip_halves{mode, E, nt, ni}
# def lut_gen{mode, E, nt, ni if E==u32 and hasarch{'AARCH64'}} = widen_inds{mode, E, nt, ni, 2}
def lut_gen{mode, E, nt, ni if E==u64 and hasarch{'AARCH64'}} = widen_inds{mode, E, nt, ni, 2}
def lut_gen{mode, E, nt, ni if mode=='i' and hasarch{'AVX2'} and (E==u16 or E==u64)} = zip_halves{mode, E, nt, ni}
# def lut_gen{mode, E, nt, ni if mode=='c' and hasarch{'AVX2'} and E==u16} = widen_inds{mode, E, nt, max{ni,16}, 2}
# def lut_gen{mode, E, nt, ni if mode=='c' and hasarch{'AVX2'} and E==u64} = widen_inds{mode, E, nt, ni, 2}
# def lut_gen{mode, E, nt, ni if mode=='c' and E==u64} = zip_halves{mode, E, nt, ni} # widen_inds{mode, E, nt, max{ni,16}, 2}
def lut_gen{mode, E==u64, nt, ni if nt>16 and hasarch{'AVX2'}} = 0
def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and (E==u16 or E==u32)} = zip_halves{mode, E, nt, ni}
def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and E==u64} = widen_inds{mode, E, nt, ni, 2}

View File

@ -49,12 +49,14 @@ def masked_multistore{r0, vs, M, end} = { # returns bumped-forwards r
def vptr{VT=[_]E, ptr} = tup{VT, *E~~ptr}
fn wrap_inds{TI if issigned{TI}}(src:*void, dst:*void, n:u64, cyc0:u64) : void = {
def cyc = cast_i{TI,cyc0}
if (has_simd) {
def bulk = arch_defvw / width{TI}
def VT = [bulk]TI
@maskedLoop{bulk}(src in tup{VT, *TI~~src}, dst in tup{VT, *TI~~dst} over n) {
@maskedLoop{bulk}(src in vptr{VT, src}, dst in vptr{VT, dst} over n) {
dst = homBlend{src, src + VT**cyc, src < VT**0}
}
} else {
@ -65,6 +67,8 @@ exportT{'si_wrap_inds', each{wrap_inds, tup{i8}}}
def inds_buf_max = 64
export{'INDS_BUF_MAX_COPY', ux~~inds_buf_max}
(if (has_sel) {
fn select_rows_fn{TD, nt, ni, G}(inds:*u8, x0:*void, xbump:u64, r0:*void, rbump:u64, r1:*void) : ux = G{inds, *TD~~x0, xbump, *TD~~r0, rbump, r1} # TG,nt,ni args just for prettier names for debugging
@ -150,6 +154,27 @@ exportT{'si_wrap_inds', each{wrap_inds, tup{i8}}}
def exportP{T, n, vs} = { a:*T = vs; export{n, a} }
exportP{u8, 'select_rows_max_indn', each{{row} => if (length{row}==0) 0 else oneVal{ each{select{.,1}, row}}, select_rows_parts}}
exportP{u8, 'select_rows_min_logcsz', each{{row} => if (length{row}==0) 0 else lb{fold{min, each{select{.,0}, row}}}, select_rows_parts}}
def select_rows_better = scan{{p,{v,i}} => if (length{v}==0) p else i, 0, each{tup, select_rows_parts, range{4}}}
exportP{u8, 'select_rows_better', select_rows_better+1}
fn select_rows_widen{sh}(src:*void, dst:*void, n:ux) : ux = {
if (sh != 0) {
def bulk = (arch_defvw/8) >> sh
def WV = [bulk<<sh]u8
if ((n<<sh) > inds_buf_max) return{0}
@for_backwards (i to inds_buf_max/(width{WV}/8)) {
def s = load{*[bulk]u8~~src, i}
def v = raw_widen_inds{bulk, sh, s}
store{*WV~~dst, i, v}
}
}
sh
}
exportT{'select_rows_widen', each{{t0, t1} => {
def {S, D} = each{select{tup{u8,u16,u32,u64},.}, tup{t0, t1}}
select_rows_widen{t1-t0}
}, select_rows_better, range{4}}}
})