fast inds⊸⊏˘ mat

This commit is contained in:
dzaima 2024-07-25 17:34:15 +03:00
parent 0d7bf86182
commit 478c389c4b
2 changed files with 192 additions and 0 deletions

View File

@ -546,7 +546,18 @@ Arr* customizeShape(B x); // from cells.c
B select_cells_base(B inds, B x0, ux csz, ux cam);
#define CLZC(X) (64-(CLZ((u64)(X))))
#ifdef SELECT_ROWS_PRINTF
#undef SELECT_ROWS_PRINTF
#define SELECT_ROWS_PRINTF(...) printf(__VA_ARGS__)
#else
#define SELECT_ROWS_PRINTF(...)
#endif
#define INDS_BUF_MAX 64 // only need 32 bytes for AVX2 & 16 for NEON, but have more for past-the-end pointers and writes
B select_rows_typed(B x, ux csz, ux cam, void* inds, ux indn, u8 ie, bool shouldBoundsCheck) { // ⥊ (indn↑inds As ie)⊸⊏˘ cam‿csz⥊z; xe cannot be el_bit or el_B, unless csz==1; ie must be ≤el_i8 if csz≤128
u8 inds_buf[INDS_BUF_MAX]; (void)inds_buf;
assert(csz!=0 && cam!=0);
assert(csz*cam == IA(x));
assert(ie<=el_i32);
@ -582,9 +593,85 @@ B select_rows_typed(B x, ux csz, ux cam, void* inds, ux indn, u8 ie, bool should
#if SINGELI
{
bool fast = ie==el_i8; (void) fast;
// TODO under shouldBoundsCheck (and probably rename that)
i64 bounds[2];
if (!getRange_fns[ie](inds, bounds, indn)) goto generic;
if (bounds[1] >= (i64)csz) goto generic;
if (bounds[0] < 0) {
if (bounds[0] < -(i64)csz) goto generic;
if (csz < 128 && indn < INDS_BUF_MAX) {
assert(ie == el_i8);
si_wrap_inds[ie-el_i8](inds, inds_buf, indn, csz);
inds = inds_buf;
} else {
fast = false;
}
}
u8* rp = m_tyarrv_same(&r, indn * cam, x);
ux slow_cam = cam;
#if SINGELI_AVX2 || SINGELI_NEON
ux lnt = CLZC(csz-1); // ceil-log2 of number of elements in table
if (fast && lnt < select_rows_tab_h) {
u8 max_indn = select_rows_max_indn[lb];
if (indn > max_indn) goto no_fast;
u8 min_lnt = select_rows_min_logcsz[lb];
ux used_lnt;
SELECT_ROWS_PRINTF("csz: %zu/%d; inds: %d/%d\n", csz, 1<<min_lnt, (int)indn, max_indn);
ux indn_real = indn;
ux rep;
if (indn*2 <= max_indn) {
assert(max_indn<=32); // otherwise inds_buf hard-coded size may need to change
rep = simd_repeat_inds(inds, inds_buf, indn, csz);
indn_real = rep*indn;
SELECT_ROWS_PRINTF("rep: %zu; inds: %zu→%zu; csz: %zu→%zu - raw repeat\n", rep, indn, indn*rep, csz, csz*rep);
PLAINLOOP while (rep*indn > max_indn) rep--; // simd_repeat_inds over-estimates
SELECT_ROWS_PRINTF("rep: %zu; inds: %zu→%zu; csz: %zu→%zu - valid inds\n", rep, indn, indn*rep, csz, csz*rep);
used_lnt = min_lnt;
ux fine_csz = 1ULL<<(min_lnt+1); // TODO have a proper per-element-type LUT of "target LUT size"
if (csz < fine_csz) {
ux cap = fine_csz / csz;
if (rep > cap) rep = cap;
} else rep = 1;
ux new_lnt = CLZC(csz*rep-1);
if (new_lnt > used_lnt) used_lnt = new_lnt;
SELECT_ROWS_PRINTF("rep: %zu; inds: %zu→%zu; csz: %zu→%zu - valid table\n", rep, indn, indn*rep, csz, csz*rep);
inds = inds_buf;
} else {
rep = 1;
used_lnt = lnt<min_lnt? min_lnt : lnt;
}
assert(indn*rep <= max_indn);
AUTO fn = select_rows_tab[used_lnt*4 + lb];
if (fn == null_fn) goto no_fast;
ux done = fn(inds, xp, csz*rep, rp, indn*rep, rp + cam*rbump) * rep;
ux left = cam - done;
SELECT_ROWS_PRINTF("done_rows: %zu; left_rows: %zu; left_els: %zu; left_max: %zu\n", done, left, indn*left, indn_real);
if (left) {
xp+= done * xbump;
rp+= done * rbump;
if (left*csz <= 127) {
assert(indn*left <= indn_real);
bool ok = SIMD_SELECT(ie, lb+3)(inds, xp, rp, indn*left, I64_MAX); assert(ok);
} else {
slow_cam = left;
goto no_fast;
}
}
goto decG_ret;
}
no_fast:;
#endif
SimdSelectFn fn = SIMD_SELECT(ie, lb+3);
for (ux i = 0; i < slow_cam; i++) {

View File

@ -49,6 +49,111 @@ def masked_multistore{r0, vs, M, end} = { # returns bumped-forwards r
fn wrap_inds{TI if issigned{TI}}(src:*void, dst:*void, n:u64, cyc0:u64) : void = {
def cyc = cast_i{TI,cyc0}
if (has_simd) {
def bulk = arch_defvw / width{TI}
def VT = [bulk]TI
@maskedLoop{bulk}(src in tup{VT, *TI~~src}, dst in tup{VT, *TI~~dst} over n) {
dst = homBlend{src, src + VT**cyc, src < VT**0}
}
} else {
@for (src in *TI~~src, dst in *TI~~dst over n) dst = tern{src<0, src+cyc, src}
}
}
exportT{'si_wrap_inds', each{wrap_inds, tup{i8}}}
(if (has_sel) {
fn select_rows_fn{TD, nt, ni, G}(inds:*u8, x0:*void, xbump:u64, r0:*void, rbump:u64, r1:*void) : ux = G{inds, *TD~~x0, xbump, *TD~~r0, rbump, r1} # TG,nt,ni args just for prettier names for debugging
def select_rows{TD} = {
def try_nt{nt} = match(lut_gen{'c', TD, nt, 2}) {
{{nt, ni, G}} => {
def r = tup{nt, ni, select_rows_fn{TD, nt, ni, {inds, x0, xbump, r0, rbump, r1} => {
# lprintf{'inds=', [ni]u8, ' tab=', [nt]TD, ' xbump=', xbump, ' rbump=', rbump, ' allowed=', (*u8~~r1) - *u8~~r0}
def iv = load{*[ni]u8~~inds, 0}
x:*TD = x0
r:*TD = r0
iters:ux = 0
while (*void~~(r + ni) <= r1) {
def rs = G{x}{iv}
masked_multistore{r, rs, maskNone, '!'}
x+= xbump
r+= rbump
++iters
}
iters
}}}
merge{tup{r}, try_nt{nt+1}}
}
{..._} => tup{}
}
try_nt{2}
}
def select_rows_parts = each{select_rows, tup{u8, u16, u32, u64}}
def max_nt = fold{max, each{select{.,0}, join{select_rows_parts}}}
def max_ni = fold{max, each{select{.,1}, join{select_rows_parts}}}
def select_rows_tab_h = lb{max_nt}+1
# >{t ← (¯1+⌊16÷𝕩)⌾(¯1⊸⊑) ⌊𝕩÷˜↕16 ⋄ ⟨𝕩, ⌊16÷𝕩, 1+¯1⊑t, t⟩}¨ 2↓↕8
# >{t←(¯1+𝕩-𝕩|16)⌾(¯1⊸⊑) 𝕩|↕16 ⋄ ⟨𝕩, 𝕩 - 1+¯1⊑t, t⟩}¨ 2↓↕8
repeat_tab:*u8 = join{each{{k} => merge{
merge{range{15}%k, k - 16%k - 1}, # TODO top not used, can clean up
merge{(range{15}/k)>>0, ((16/k)>>0) - 1}
}, 1+range{8}}}
fn simd_repeat_inds(src:*u8, dst:*u8, start:u8, csz:u8) : ux = { # src and dst may be the same
assert{(start>=1) & (start<=16)}
assert{start < max_ni}
def V = [16]u8
def VU16 = re_el{u16, V}
def px{x} = promote{ux, x}
shufr:u8 = 1 # number of repeats produced by shuffle path
v:= load{*V~~src, 0}
if (start <= 8) {
tab:= *V~~(repeat_tab + 32 * (start-1))
def l0 = load{tab, 0}
def l1 = load{tab, 1}
v = sel{V, v, l0}
v+= V~~(VU16~~l1 * broadcast{VU16, csz})
shufr = extract{l1, 15}+1
}
def shufe = px{shufr} * px{start} # number of elements produced by shuffle path
cdst:= dst
r:ux = 0
do {
r+= px{shufr}
store{*V~~cdst, 0, v}
v+= broadcast{V, csz * shufr}
cdst+= shufe
} while (max_ni>16 and cdst+start <= dst + max_ni)
# lprintf{shufr, r, ux~~(cdst - dst) / px{start}}
# lprintf{load{*[max_ni]u8~~dst}, cdst - dst}
r
}
export{'simd_repeat_inds', simd_repeat_inds}
def null_fn = select_rows_fn{void, 'BAD', 0, {..._} => { emit{void,'fatal','"bad select_rows"'}; 0 }}
export{'null_fn', null_fn}
export{'select_rows_tab_h', ux~~select_rows_tab_h}
exportT{'select_rows_tab', join{flip{each{{row} => {
def a = each{select{.,0}, row} + 0 # +0 to work around findmatches bug
def b = 1<<range{select_rows_tab_h}
each{match { {{i}}=>select{select{row,i},2}; {{}}=>null_fn }, findmatches{a, b}}
}, select_rows_parts}}}}
def exportP{T, n, vs} = { a:*T = vs; export{n, a} }
exportP{u8, 'select_rows_max_indn', each{{row} => oneVal{ each{select{.,1}, row}}, select_rows_parts}}
exportP{u8, 'select_rows_min_logcsz', each{{row} => lb{fold{min, each{select{.,0}, row}}}, select_rows_parts}}
})
fn select_fn{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = { # TODO don't require SIMD?
w:= *TI ~~ w0
x:= *TD ~~ x0