fast inds⊸⊏˘ mat
This commit is contained in:
parent
0d7bf86182
commit
478c389c4b
@ -546,7 +546,18 @@ Arr* customizeShape(B x); // from cells.c
|
||||
|
||||
B select_cells_base(B inds, B x0, ux csz, ux cam);
|
||||
|
||||
#define CLZC(X) (64-(CLZ((u64)(X))))
|
||||
|
||||
#ifdef SELECT_ROWS_PRINTF
|
||||
#undef SELECT_ROWS_PRINTF
|
||||
#define SELECT_ROWS_PRINTF(...) printf(__VA_ARGS__)
|
||||
#else
|
||||
#define SELECT_ROWS_PRINTF(...)
|
||||
#endif
|
||||
|
||||
#define INDS_BUF_MAX 64 // only need 32 bytes for AVX2 & 16 for NEON, but have more for past-the-end pointers and writes
|
||||
B select_rows_typed(B x, ux csz, ux cam, void* inds, ux indn, u8 ie, bool shouldBoundsCheck) { // ⥊ (indn↑inds As ie)⊸⊏˘ cam‿csz⥊z; xe cannot be el_bit or el_B, unless csz==1; ie must be ≤el_i8 if csz≤128
|
||||
u8 inds_buf[INDS_BUF_MAX]; (void)inds_buf;
|
||||
assert(csz!=0 && cam!=0);
|
||||
assert(csz*cam == IA(x));
|
||||
assert(ie<=el_i32);
|
||||
@ -582,9 +593,85 @@ B select_rows_typed(B x, ux csz, ux cam, void* inds, ux indn, u8 ie, bool should
|
||||
|
||||
#if SINGELI
|
||||
{
|
||||
bool fast = ie==el_i8; (void) fast;
|
||||
|
||||
// TODO under shouldBoundsCheck (and probably rename that)
|
||||
i64 bounds[2];
|
||||
if (!getRange_fns[ie](inds, bounds, indn)) goto generic;
|
||||
if (bounds[1] >= (i64)csz) goto generic;
|
||||
if (bounds[0] < 0) {
|
||||
if (bounds[0] < -(i64)csz) goto generic;
|
||||
if (csz < 128 && indn < INDS_BUF_MAX) {
|
||||
assert(ie == el_i8);
|
||||
si_wrap_inds[ie-el_i8](inds, inds_buf, indn, csz);
|
||||
inds = inds_buf;
|
||||
} else {
|
||||
fast = false;
|
||||
}
|
||||
}
|
||||
|
||||
u8* rp = m_tyarrv_same(&r, indn * cam, x);
|
||||
|
||||
ux slow_cam = cam;
|
||||
#if SINGELI_AVX2 || SINGELI_NEON
|
||||
ux lnt = CLZC(csz-1); // ceil-log2 of number of elements in table
|
||||
if (fast && lnt < select_rows_tab_h) {
|
||||
u8 max_indn = select_rows_max_indn[lb];
|
||||
if (indn > max_indn) goto no_fast;
|
||||
u8 min_lnt = select_rows_min_logcsz[lb];
|
||||
ux used_lnt;
|
||||
SELECT_ROWS_PRINTF("csz: %zu/%d; inds: %d/%d\n", csz, 1<<min_lnt, (int)indn, max_indn);
|
||||
|
||||
ux indn_real = indn;
|
||||
ux rep;
|
||||
if (indn*2 <= max_indn) {
|
||||
assert(max_indn<=32); // otherwise inds_buf hard-coded size may need to change
|
||||
rep = simd_repeat_inds(inds, inds_buf, indn, csz);
|
||||
indn_real = rep*indn;
|
||||
SELECT_ROWS_PRINTF("rep: %zu; inds: %zu→%zu; csz: %zu→%zu - raw repeat\n", rep, indn, indn*rep, csz, csz*rep);
|
||||
PLAINLOOP while (rep*indn > max_indn) rep--; // simd_repeat_inds over-estimates
|
||||
SELECT_ROWS_PRINTF("rep: %zu; inds: %zu→%zu; csz: %zu→%zu - valid inds\n", rep, indn, indn*rep, csz, csz*rep);
|
||||
|
||||
used_lnt = min_lnt;
|
||||
ux fine_csz = 1ULL<<(min_lnt+1); // TODO have a proper per-element-type LUT of "target LUT size"
|
||||
if (csz < fine_csz) {
|
||||
ux cap = fine_csz / csz;
|
||||
if (rep > cap) rep = cap;
|
||||
} else rep = 1;
|
||||
|
||||
ux new_lnt = CLZC(csz*rep-1);
|
||||
if (new_lnt > used_lnt) used_lnt = new_lnt;
|
||||
|
||||
SELECT_ROWS_PRINTF("rep: %zu; inds: %zu→%zu; csz: %zu→%zu - valid table\n", rep, indn, indn*rep, csz, csz*rep);
|
||||
inds = inds_buf;
|
||||
|
||||
} else {
|
||||
rep = 1;
|
||||
used_lnt = lnt<min_lnt? min_lnt : lnt;
|
||||
}
|
||||
|
||||
assert(indn*rep <= max_indn);
|
||||
AUTO fn = select_rows_tab[used_lnt*4 + lb];
|
||||
if (fn == null_fn) goto no_fast;
|
||||
ux done = fn(inds, xp, csz*rep, rp, indn*rep, rp + cam*rbump) * rep;
|
||||
ux left = cam - done;
|
||||
SELECT_ROWS_PRINTF("done_rows: %zu; left_rows: %zu; left_els: %zu; left_max: %zu\n", done, left, indn*left, indn_real);
|
||||
if (left) {
|
||||
xp+= done * xbump;
|
||||
rp+= done * rbump;
|
||||
if (left*csz <= 127) {
|
||||
assert(indn*left <= indn_real);
|
||||
bool ok = SIMD_SELECT(ie, lb+3)(inds, xp, rp, indn*left, I64_MAX); assert(ok);
|
||||
} else {
|
||||
slow_cam = left;
|
||||
goto no_fast;
|
||||
}
|
||||
}
|
||||
|
||||
goto decG_ret;
|
||||
}
|
||||
no_fast:;
|
||||
#endif
|
||||
|
||||
SimdSelectFn fn = SIMD_SELECT(ie, lb+3);
|
||||
for (ux i = 0; i < slow_cam; i++) {
|
||||
|
||||
@ -49,6 +49,111 @@ def masked_multistore{r0, vs, M, end} = { # returns bumped-forwards r
|
||||
|
||||
|
||||
|
||||
fn wrap_inds{TI if issigned{TI}}(src:*void, dst:*void, n:u64, cyc0:u64) : void = {
|
||||
def cyc = cast_i{TI,cyc0}
|
||||
if (has_simd) {
|
||||
def bulk = arch_defvw / width{TI}
|
||||
def VT = [bulk]TI
|
||||
@maskedLoop{bulk}(src in tup{VT, *TI~~src}, dst in tup{VT, *TI~~dst} over n) {
|
||||
dst = homBlend{src, src + VT**cyc, src < VT**0}
|
||||
}
|
||||
} else {
|
||||
@for (src in *TI~~src, dst in *TI~~dst over n) dst = tern{src<0, src+cyc, src}
|
||||
}
|
||||
}
|
||||
exportT{'si_wrap_inds', each{wrap_inds, tup{i8}}}
|
||||
|
||||
|
||||
|
||||
(if (has_sel) {
|
||||
fn select_rows_fn{TD, nt, ni, G}(inds:*u8, x0:*void, xbump:u64, r0:*void, rbump:u64, r1:*void) : ux = G{inds, *TD~~x0, xbump, *TD~~r0, rbump, r1} # TG,nt,ni args just for prettier names for debugging
|
||||
|
||||
def select_rows{TD} = {
|
||||
def try_nt{nt} = match(lut_gen{'c', TD, nt, 2}) {
|
||||
{{nt, ni, G}} => {
|
||||
def r = tup{nt, ni, select_rows_fn{TD, nt, ni, {inds, x0, xbump, r0, rbump, r1} => {
|
||||
# lprintf{'inds=', [ni]u8, ' tab=', [nt]TD, ' xbump=', xbump, ' rbump=', rbump, ' allowed=', (*u8~~r1) - *u8~~r0}
|
||||
def iv = load{*[ni]u8~~inds, 0}
|
||||
x:*TD = x0
|
||||
r:*TD = r0
|
||||
iters:ux = 0
|
||||
while (*void~~(r + ni) <= r1) {
|
||||
def rs = G{x}{iv}
|
||||
masked_multistore{r, rs, maskNone, '!'}
|
||||
x+= xbump
|
||||
r+= rbump
|
||||
++iters
|
||||
}
|
||||
iters
|
||||
}}}
|
||||
merge{tup{r}, try_nt{nt+1}}
|
||||
}
|
||||
{..._} => tup{}
|
||||
}
|
||||
|
||||
try_nt{2}
|
||||
}
|
||||
def select_rows_parts = each{select_rows, tup{u8, u16, u32, u64}}
|
||||
def max_nt = fold{max, each{select{.,0}, join{select_rows_parts}}}
|
||||
def max_ni = fold{max, each{select{.,1}, join{select_rows_parts}}}
|
||||
def select_rows_tab_h = lb{max_nt}+1
|
||||
|
||||
# >{t ← (¯1+⌊16÷𝕩)⌾(¯1⊸⊑) ⌊𝕩÷˜↕16 ⋄ ⟨𝕩, ⌊16÷𝕩, 1+¯1⊑t, t⟩}¨ 2↓↕8
|
||||
# >{t←(¯1+𝕩-𝕩|16)⌾(¯1⊸⊑) 𝕩|↕16 ⋄ ⟨𝕩, 𝕩 - 1+¯1⊑t, t⟩}¨ 2↓↕8
|
||||
repeat_tab:*u8 = join{each{{k} => merge{
|
||||
merge{range{15}%k, k - 16%k - 1}, # TODO top not used, can clean up
|
||||
merge{(range{15}/k)>>0, ((16/k)>>0) - 1}
|
||||
}, 1+range{8}}}
|
||||
fn simd_repeat_inds(src:*u8, dst:*u8, start:u8, csz:u8) : ux = { # src and dst may be the same
|
||||
assert{(start>=1) & (start<=16)}
|
||||
assert{start < max_ni}
|
||||
def V = [16]u8
|
||||
def VU16 = re_el{u16, V}
|
||||
def px{x} = promote{ux, x}
|
||||
|
||||
shufr:u8 = 1 # number of repeats produced by shuffle path
|
||||
v:= load{*V~~src, 0}
|
||||
if (start <= 8) {
|
||||
tab:= *V~~(repeat_tab + 32 * (start-1))
|
||||
def l0 = load{tab, 0}
|
||||
def l1 = load{tab, 1}
|
||||
v = sel{V, v, l0}
|
||||
v+= V~~(VU16~~l1 * broadcast{VU16, csz})
|
||||
shufr = extract{l1, 15}+1
|
||||
}
|
||||
|
||||
def shufe = px{shufr} * px{start} # number of elements produced by shuffle path
|
||||
|
||||
cdst:= dst
|
||||
r:ux = 0
|
||||
do {
|
||||
r+= px{shufr}
|
||||
store{*V~~cdst, 0, v}
|
||||
v+= broadcast{V, csz * shufr}
|
||||
cdst+= shufe
|
||||
} while (max_ni>16 and cdst+start <= dst + max_ni)
|
||||
# lprintf{shufr, r, ux~~(cdst - dst) / px{start}}
|
||||
# lprintf{load{*[max_ni]u8~~dst}, cdst - dst}
|
||||
r
|
||||
}
|
||||
export{'simd_repeat_inds', simd_repeat_inds}
|
||||
|
||||
def null_fn = select_rows_fn{void, 'BAD', 0, {..._} => { emit{void,'fatal','"bad select_rows"'}; 0 }}
|
||||
export{'null_fn', null_fn}
|
||||
export{'select_rows_tab_h', ux~~select_rows_tab_h}
|
||||
exportT{'select_rows_tab', join{flip{each{{row} => {
|
||||
def a = each{select{.,0}, row} + 0 # +0 to work around findmatches bug
|
||||
def b = 1<<range{select_rows_tab_h}
|
||||
each{match { {{i}}=>select{select{row,i},2}; {{}}=>null_fn }, findmatches{a, b}}
|
||||
}, select_rows_parts}}}}
|
||||
|
||||
def exportP{T, n, vs} = { a:*T = vs; export{n, a} }
|
||||
exportP{u8, 'select_rows_max_indn', each{{row} => oneVal{ each{select{.,1}, row}}, select_rows_parts}}
|
||||
exportP{u8, 'select_rows_min_logcsz', each{{row} => lb{fold{min, each{select{.,0}, row}}}, select_rows_parts}}
|
||||
})
|
||||
|
||||
|
||||
|
||||
fn select_fn{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = { # TODO don't require SIMD?
|
||||
w:= *TI ~~ w0
|
||||
x:= *TD ~~ x0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user