diff --git a/src/builtins/select.c b/src/builtins/select.c index 723e669f..0ff9f5e9 100644 --- a/src/builtins/select.c +++ b/src/builtins/select.c @@ -546,7 +546,18 @@ Arr* customizeShape(B x); // from cells.c B select_cells_base(B inds, B x0, ux csz, ux cam); +#define CLZC(X) (64-(CLZ((u64)(X)))) + +#ifdef SELECT_ROWS_PRINTF + #undef SELECT_ROWS_PRINTF + #define SELECT_ROWS_PRINTF(...) printf(__VA_ARGS__) +#else + #define SELECT_ROWS_PRINTF(...) +#endif + +#define INDS_BUF_MAX 64 // only need 32 bytes for AVX2 & 16 for NEON, but have more for past-the-end pointers and writes B select_rows_typed(B x, ux csz, ux cam, void* inds, ux indn, u8 ie, bool shouldBoundsCheck) { // ⥊ (indn↑inds As ie)⊸⊏˘ cam‿csz⥊z; xe cannot be el_bit or el_B, unless csz==1; ie must be ≤el_i8 if csz≤128 + u8 inds_buf[INDS_BUF_MAX]; (void)inds_buf; assert(csz!=0 && cam!=0); assert(csz*cam == IA(x)); assert(ie<=el_i32); @@ -582,9 +593,85 @@ B select_rows_typed(B x, ux csz, ux cam, void* inds, ux indn, u8 ie, bool should #if SINGELI { + bool fast = ie==el_i8; (void) fast; + + // TODO under shouldBoundsCheck (and probably rename that) + i64 bounds[2]; + if (!getRange_fns[ie](inds, bounds, indn)) goto generic; + if (bounds[1] >= (i64)csz) goto generic; + if (bounds[0] < 0) { + if (bounds[0] < -(i64)csz) goto generic; + if (csz < 128 && indn < INDS_BUF_MAX) { + assert(ie == el_i8); + si_wrap_inds[ie-el_i8](inds, inds_buf, indn, csz); + inds = inds_buf; + } else { + fast = false; + } + } + u8* rp = m_tyarrv_same(&r, indn * cam, x); ux slow_cam = cam; + #if SINGELI_AVX2 || SINGELI_NEON + ux lnt = CLZC(csz-1); // ceil-log2 of number of elements in table + if (fast && lnt < select_rows_tab_h) { + u8 max_indn = select_rows_max_indn[lb]; + if (indn > max_indn) goto no_fast; + u8 min_lnt = select_rows_min_logcsz[lb]; + ux used_lnt; + SELECT_ROWS_PRINTF("csz: %zu/%d; inds: %d/%d\n", csz, 1< max_indn) rep--; // simd_repeat_inds over-estimates + SELECT_ROWS_PRINTF("rep: %zu; inds: %zu→%zu; csz: %zu→%zu - valid inds\n", rep, indn, indn*rep, csz, csz*rep); + + used_lnt = min_lnt; + ux fine_csz = 1ULL<<(min_lnt+1); // TODO have a proper per-element-type LUT of "target LUT size" + if (csz < fine_csz) { + ux cap = fine_csz / csz; + if (rep > cap) rep = cap; + } else rep = 1; + + ux new_lnt = CLZC(csz*rep-1); + if (new_lnt > used_lnt) used_lnt = new_lnt; + + SELECT_ROWS_PRINTF("rep: %zu; inds: %zu→%zu; csz: %zu→%zu - valid table\n", rep, indn, indn*rep, csz, csz*rep); + inds = inds_buf; + + } else { + rep = 1; + used_lnt = lnt { + def r = tup{nt, ni, select_rows_fn{TD, nt, ni, {inds, x0, xbump, r0, rbump, r1} => { + # lprintf{'inds=', [ni]u8, ' tab=', [nt]TD, ' xbump=', xbump, ' rbump=', rbump, ' allowed=', (*u8~~r1) - *u8~~r0} + def iv = load{*[ni]u8~~inds, 0} + x:*TD = x0 + r:*TD = r0 + iters:ux = 0 + while (*void~~(r + ni) <= r1) { + def rs = G{x}{iv} + masked_multistore{r, rs, maskNone, '!'} + x+= xbump + r+= rbump + ++iters + } + iters + }}} + merge{tup{r}, try_nt{nt+1}} + } + {..._} => tup{} + } + + try_nt{2} + } + def select_rows_parts = each{select_rows, tup{u8, u16, u32, u64}} + def max_nt = fold{max, each{select{.,0}, join{select_rows_parts}}} + def max_ni = fold{max, each{select{.,1}, join{select_rows_parts}}} + def select_rows_tab_h = lb{max_nt}+1 + + # >{t ← (¯1+⌊16÷𝕩)⌾(¯1⊸⊑) ⌊𝕩÷˜↕16 ⋄ ⟨𝕩, ⌊16÷𝕩, 1+¯1⊑t, t⟩}¨ 2↓↕8 + # >{t←(¯1+𝕩-𝕩|16)⌾(¯1⊸⊑) 𝕩|↕16 ⋄ ⟨𝕩, 𝕩 - 1+¯1⊑t, t⟩}¨ 2↓↕8 + repeat_tab:*u8 = join{each{{k} => merge{ + merge{range{15}%k, k - 16%k - 1}, # TODO top not used, can clean up + merge{(range{15}/k)>>0, ((16/k)>>0) - 1} + }, 1+range{8}}} + fn simd_repeat_inds(src:*u8, dst:*u8, start:u8, csz:u8) : ux = { # src and dst may be the same + assert{(start>=1) & (start<=16)} + assert{start < max_ni} + def V = [16]u8 + def VU16 = re_el{u16, V} + def px{x} = promote{ux, x} + + shufr:u8 = 1 # number of repeats produced by shuffle path + v:= load{*V~~src, 0} + if (start <= 8) { + tab:= *V~~(repeat_tab + 32 * (start-1)) + def l0 = load{tab, 0} + def l1 = load{tab, 1} + v = sel{V, v, l0} + v+= V~~(VU16~~l1 * broadcast{VU16, csz}) + shufr = extract{l1, 15}+1 + } + + def shufe = px{shufr} * px{start} # number of elements produced by shuffle path + + cdst:= dst + r:ux = 0 + do { + r+= px{shufr} + store{*V~~cdst, 0, v} + v+= broadcast{V, csz * shufr} + cdst+= shufe + } while (max_ni>16 and cdst+start <= dst + max_ni) + # lprintf{shufr, r, ux~~(cdst - dst) / px{start}} + # lprintf{load{*[max_ni]u8~~dst}, cdst - dst} + r + } + export{'simd_repeat_inds', simd_repeat_inds} + + def null_fn = select_rows_fn{void, 'BAD', 0, {..._} => { emit{void,'fatal','"bad select_rows"'}; 0 }} + export{'null_fn', null_fn} + export{'select_rows_tab_h', ux~~select_rows_tab_h} + exportT{'select_rows_tab', join{flip{each{{row} => { + def a = each{select{.,0}, row} + 0 # +0 to work around findmatches bug + def b = 1<select{select{row,i},2}; {{}}=>null_fn }, findmatches{a, b}} + }, select_rows_parts}}}} + + def exportP{T, n, vs} = { a:*T = vs; export{n, a} } + exportP{u8, 'select_rows_max_indn', each{{row} => oneVal{ each{select{.,1}, row}}, select_rows_parts}} + exportP{u8, 'select_rows_min_logcsz', each{{row} => lb{fold{min, each{select{.,0}, row}}}, select_rows_parts}} +}) + + + fn select_fn{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = { # TODO don't require SIMD? w:= *TI ~~ w0 x:= *TD ~~ x0