diff --git a/src/builtins/cells.c b/src/builtins/cells.c index 8cac5b71..5817f967 100644 --- a/src/builtins/cells.c +++ b/src/builtins/cells.c @@ -711,7 +711,11 @@ NOINLINE B for_cells_SA(B f, B w, B x, ur xcr, ur xr, u32 chr) { // w⊸F⎉xcr case n_rtack: dec(w); return x; case n_ltack: return const_cells(x, xk, xsh, w, chr); case n_select: - if (isArr(w) && RNK(w)==1 && xcr==1 && TI(w,arrD1)) { // TODO handle RNK(w)!=1 + if (isArr(w) && RNK(w)==1 && xcr==1) { // TODO handle RNK(w)!=1 + if (!TI(w,arrD1)) { + w = num_squeezeChk(w); + if (!TI(w,arrD1)) break; + } assert(xr > 1); ux wia = IA(w); ShArr* rsh = m_shArr(xr); diff --git a/src/builtins/select.c b/src/builtins/select.c index 7de512a7..8fc908da 100644 --- a/src/builtins/select.c +++ b/src/builtins/select.c @@ -559,16 +559,19 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz cam, usz csz) { // consu #undef FREE_CHECK } -static void* m_arrv_same(B* r, usz ia, B src) { // makes a new array with same element type as src, but new ia - u8 se = TI(src,elType); assert(se!=el_bit); +static void* m_arrv_same_t(B* r, usz ia, u8 ty) { + u8 se = TIi(ty,elType); if (se==el_B) { HArr_p p = m_harr0v(ia); *r = p.b; return p.a; } else { - return m_tyarrlv(r, arrTypeWidthLog(TY(src)), ia, arrNewType(TY(src))); + return m_tyarrlbv(r, arrTypeBitsLog(ty), ia, arrNewType(ty)); } } +static void* m_arrv_same(B* r, usz ia, B src) { // makes a new array with same element type as src, but new ia + return m_arrv_same_t(r, ia, TY(src)); +} B slash_c2(B, B, B); Arr* customizeShape(B x); // from cells.c @@ -585,7 +588,7 @@ B select_cells_base(B inds, B x0, ux csz, ux cam); #endif #define INDS_BUF_MAX 64 // only need 32 bytes for AVX2 & 16 for NEON, but have more for past-the-end pointers and writes -B select_rows_direct(B x, ux csz, ux cam, void* inds, ux indn, u8 ie) { // ⥊ (indn↑inds As ie)⊸⊏˘ cam‿csz⥊x +B select_rows_direct(B x, ux csz, ux cam, void* inds, ux indn, u8 ie) { // ⥊ (indn↑inds As ie)⊸⊏˘ cam‿csz⥊x; if inds are valid and csz<=128, ie must be <=el_i8 assert(csz!=0 && cam!=0 && indn!=0); assert(csz*cam == IA(x)); assert(ie<=el_i32); @@ -597,30 +600,38 @@ B select_rows_direct(B x, ux csz, ux cam, void* inds, ux indn, u8 ie) { // ⥊ ( if (!getRange_fns[ie](inds, bounds, indn) || bounds[0]<-1 || bounds[1]>0) goto generic_any; return C2(slash, m_f64(indn), taga(arr_shVec(customizeShape(x)))); } - assert(csz>=2); + ux ria = indn * cam; + B r; + u8* xp; u8 xe = TI(x,elType); u8 lb = arrTypeWidthLog(TY(x)); - u8* xp; + if (xe==el_B) { if (sizeof(B) != 8) goto generic_any; xp = (u8*) arr_bptr(x); if (xp == NULL) goto generic_any; } else { - if (xe == el_bit) goto generic_any; xp = tyany_ptr(x); + if (xe == el_bit) { + #if SINGELI_AVX2 || SINGELI_NEON + if (indn<=8 && csz<=8) goto bit_ok; + #endif + goto generic_any; + goto bit_ok; bit_ok:; + } } - B r; - ux ria = indn * cam; bool fast; (void) fast; ux xbump = csz<32 || indn>32 || indn>INDS_BUF_MAX) { // TODO properly tune + assert(xe!=el_bit && (csz>8 || indn>8)); u8* rp = m_arrv_same(&r, ria, x); for (ux i = 0; i < cam; i++) { bitselFns[lb](rp, inds, loadu_u64(xp), loadu_u64(xp + (1<128); #if SINGELI_AVX2 || SINGELI_NEON if (fast) { @@ -680,11 +692,68 @@ B select_rows_direct(B x, ux csz, ux cam, void* inds, ux indn, u8 ie) { // ⥊ ( } #endif + #if SINGELI_AVX2 || SINGELI_NEON + if (xe==el_bit) { + assert(ie==el_i8 && csz<=8 && indn<=8 && csz>=2 && indn>=1); + // TODO si_select_cells_bit_lt64 for indn==1 + static const u8 rep_lut[9] = {0,3,2,1,1,0,0,0,0}; + u8 exp = rep_lut[csz>indn? csz : indn]; + ux rindn = indn<>exp; + + if (rcsz!=8) { + Arr* xa = customizeShape(x); + usz* xsh = arr_shAlloc(xa, 2); + xsh[0] = rcam; + xsh[1] = rcsz; + // leave ia unchanged, desynchronizing from product of shape; TODO really really shouldn't do that, and instead pass cam & csz directly to bit widener + x = widenBitArr(taga(xa), 1); + xp = tyany_ptr(x); + SELECT_ROWS_PRINTF("8bit: widen %zu‿%zu → ⟨%zu,%zu→8⟩\n", cam, csz, rcam, rcsz); + } + + if (exp!=0) { + simd_repeat_inds(inds, inds_buf, indn, csz); + inds = inds_buf; + } + + u64* rp; + ux ria0 = rindn!=8? 8*rcam : ria; + r = m_bitarrv(&rp, ria0); + SELECT_ROWS_PRINTF("8bit: indn=%zu rindn=%zu csz=%zu rcsz=%zu cam=%zu ria0=%zu rcam=%zu\n", indn, rindn, csz, rcsz, cam, ria0, rcam); + si_select_rows_8bit(inds, rindn, xp, rp, (ria0+7)/8); + + if (rindn!=8) { + SELECT_ROWS_PRINTF("8bit: narrow %zu → %zu\n", rcsz, csz); + usz* rsh = arr_shAlloc(a(r), 2); + rsh[0] = rcam; + rsh[1] = 8; + + usz tgt = rindn; + r = narrowWidenedBitArr(r, 1, 1, &tgt); // TODO this assumes trailing zeroes + + Arr* ra = arr_shVec(customizeShape(r)); + r = taga(ra); + + ux ria1 = ra->ia; + assert(ria <= ria1); + FINISH_OVERALLOC(ra, offsetof(TyArr,a) + (ria+7)/8, offsetof(TyArr,a) + (ria1+7)/8); + ra->ia = ria; + } + + goto decG_ret; + } + #endif + u8* rp = m_arrv_same(&r, ria, x); ux slow_cam = cam; #if SINGELI_AVX2 || SINGELI_NEON ux lnt = CLZC(csz-1); // ceil-log2 of number of elements in table + if (fast && lnt < select_rows_tab_h) { u8 max_indn = select_rows_max_indn[lb]; if (indn > max_indn) goto no_fast; @@ -788,7 +857,7 @@ B select_rows_B(B x, ux csz, ux cam, B inds) { // consumes inds,x; ⥊ inds⊸ ux in = IA(inds); if (in == 0) return taga(emptyArr(x, 1)); u8 ie = TI(inds,elType); - if (csz<=2? ie!=el_bit : csz<128? ie>el_i8 : !elInt(ie)) { + if (csz<=2? ie!=el_bit : csz<=128? ie>el_i8 : !elInt(ie)) { inds = num_squeeze(inds); ie = TI(inds,elType); if (!elInt(ie)) goto generic; diff --git a/src/singeli/src/lut.singeli b/src/singeli/src/lut.singeli index b198699e..c76aece4 100644 --- a/src/singeli/src/lut.singeli +++ b/src/singeli/src/lut.singeli @@ -1,7 +1,10 @@ def __shl{(u16)}{a:T, b} = T~~(re_el{u16,a}<>b) def broadcast{[(n*2)]E, x:[n]E} = pair{x, x} def pow2_up{v, least} = max{least, 1< tup{sel{lut, is}} }}}} +# TODO E==u8 nt==128 via tbx4(tbl4,d,i-64) def lut_gen{mode, E, nt, ni if mode=='i' and hasarch{'AVX2'} and (E==u16 or E==u64)} = zip_halves{mode, E, nt, ni} # def lut_gen{mode, E, nt, ni if mode=='c' and hasarch{'AVX2'} and E==u16} = widen_inds{mode, E, nt, max{ni,16}, 2} @@ -187,3 +191,30 @@ def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and E==u64} = widen_inds{mode, def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and mode=='c' and E>=u16} = 0 def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and mode=='i' and E==u64 and nt>16} = 0 + + +def lut16{tab:([16]u8), idxs:([16]u8)} = sel{[16]u8, tab, idxs} +def lut16{tab:([16]u8), idxs:([32]u8) if hasarch{'X86_64'}} = sel{[16]u8, pair{tab, tab}, idxs} + +def shuf_u8bits{inds:*u8, ni} = 0 +def shuf_u8bits{inds:*u8, ni if has_sel} = { + def I = [16]u8 + v0:= I**0 + v1:= I**0 + + iv:= I**1 + @for (ind in inds over ni) { + c:= (iota{I} & I**(1<<(ind&3))) != I**0 + c&= iv + iv = iv+iv + if (ind>=4) v1|= c + else v0|= c + } + + {x:V=[_](u8)} => { + def __shr{a:(V), 4 if hasarch{'X86_64'}} = (a >>{u16} 4) & V**0x0f + def lo = lut16{v0, x & V**0x0f} + def hi = lut16{v1, x >> 4} + lo | hi + } +} \ No newline at end of file diff --git a/src/singeli/src/select.singeli b/src/singeli/src/select.singeli index 3811c903..a1dce94c 100644 --- a/src/singeli/src/select.singeli +++ b/src/singeli/src/select.singeli @@ -9,8 +9,6 @@ def arch_minvw = if (hasarch{'AARCH64'}) 64 else 128 def arch_minv{T=[_]E if width{T}< arch_minvw} = [arch_minvw / width{E}]E def arch_minv{T if width{T}>=arch_minvw} = T -def has_sel = hasarch{'AVX2'} or hasarch{'AARCH64'} - def gather if_inline (hasarch{'AVX2'}) { # def:T - masked original content @@ -306,6 +304,16 @@ exportT{'si_select_tab', join{table{select_fn, 1 } export{'simd_select_bool128', simd_select_bool128} + + fn si_select_rows_8bit(inds:*u8, indn:ux, src:*void, dst:*void, rows:ux) : void = { # leaves zeroes in result cells above indn + def bulk = arch_defvw / 8 + def V = [bulk]u8 + def lut = shuf_u8bits{inds, indn} + @maskedLoop{bulk}(src in tup{V,*u8~~src}, dst in tup{V,*u8~~dst} over rows) { + dst = lut{src} + } + } + export{'si_select_rows_8bit', si_select_rows_8bit} }) diff --git a/test/cases/cells.bqn b/test/cases/cells.bqn index 668bbe1d..597aec30 100644 --- a/test/cases/cells.bqn +++ b/test/cases/cells.bqn @@ -94,6 +94,14 @@ !"⊏: Indexing out-of-bounds (¯129∊𝕨, 128≡≠𝕩)" % %USE evar ⋄ 10‿¯129 {𝕨⊸⊏˘𝕩}_evar 10‿128⥊1 !"⊏: Indexing out-of-bounds (128∊𝕨, 128≡≠𝕩)" % %USE evar ⋄ 10‿128 {𝕨⊸⊏˘𝕩}_evar 10‿128⥊1 !"⊏: Indexing out-of-bounds (1∊𝕨, 1≡≠𝕩)" % %USE evar ⋄ 1‿0‿0‿1 {𝕨⊸⊏˘𝕩}_evar 10‿1⥊1 +!"⊏: Indexing out-of-bounds (1000∊𝕨, 3≡≠𝕩)" % %USE evar ⋄ (3⥊1000) {𝕨⊸⊏˘𝕩}_evar 100‿3⥊1 +!"⊏: Indexing out-of-bounds (1000∊𝕨, 4≡≠𝕩)" % %USE evar ⋄ (4⥊1000) {𝕨⊸⊏˘𝕩}_evar 100‿4⥊1 +!"⊏: Indexing out-of-bounds (1000∊𝕨, 8≡≠𝕩)" % %USE evar ⋄ (8⥊1000) {𝕨⊸⊏˘𝕩}_evar 100‿8⥊1 +⟨1‿2⟩⊸⊏˘ 10‿8⥊↕100 %% (8×↕10) +⌜ 1‿2 +⟨1‿2⟩⊸⊏˘ 10‿4‿2⥊↕100 %% (8×↕10) +⌜ [2‿3,4‿5] +⟨1‿2,⟨1,0⟩⟩⊸⊏˘ 10‿4‿2⥊↕100 %% (8×↕10) +⌜ [3‿2,5‿4] +[2‿1,4‿5]⊸⊏˘ 10‿8⥊↕100 %% (8×↕10) +⌜ [2‿1,4‿5] +[1‿1,0‿1]⊸⊏˘ 10‿2‿2⥊↕100 %% (4×↕10) +⌜ 2‿2‿2⥊2‿3‿2‿3‿0‿1‿2‿3 ( %USE IS_HEAPVERIFY diff --git a/test/cases/fuzz/select-cells.bqn b/test/cases/fuzz/select-cells.bqn index d65db956..adca1093 100644 --- a/test/cases/fuzz/select-cells.bqn +++ b/test/cases/fuzz/select-cells.bqn @@ -34,3 +34,42 @@ Test {{@ + 𝕩 •rand.Range 1114111}} Test {{⊑⟜"foo"‿@¨ 𝕩 •rand.Range 2}} ) + + +( + %USE var + F ← {𝕨⊸⊏˘ 𝕩} + { + 𝕩+↩ 1 + is ← 𝕨 •rand.Range 𝕩 + n ← •rand.Range 200 + d ← n‿𝕩•rand.Range 2 + is‿n‿𝕩 ! 1=≠⍷ ⟨is F d, is F "Ai8"V d⟩ ∾ {𝕩=2? ⟨("Ai8"V is) F d⟩; ⟨⟩}𝕩 + }⌜˜ ↕10 +) + +( + %USE var + F ← {𝕨⊸⊏˘ 𝕩}⎊'e' + { + 𝕩+↩ 1 + 𝕨+↩ 1 + is ← 𝕨⥊𝕩 + n ← 1+•rand.Range 200 + d ← n‿𝕩•rand.Range 2 + + 𝕨‿𝕩‿n! ∧´ 'e'⊸≡¨ ⟨is F d, is F "Ai8"V d⟩ ∾ {𝕩=2? ⟨("Ai8"V is) F d⟩; ⟨⟩}𝕩 + }⌜˜ ↕10 +) + +( + %USE var + F ← {𝕨⊸⊏˘ 𝕩}⎊'e' + + {𝕊 inds: + csz ← ≠⊑inds + inds { + ⟨csz, 𝕨, ≢𝕩⟩ ! 1=≠⍷ ⥊ (V⟜𝕨¨ (∧´𝕨≤1)◶⟨⋈, "Ab"⊸⋈⟩ "Ai8") F⌜ V⟜𝕩¨ "Ab"‿"Ai8"‿"Ai32" + }⌜ {𝕩‿csz •rand.Range 2}¨ (↕8)∾⥊0‿1+⌜3↓2⋆↕9 + }¨ ⟨⥊↕2‿2⟩ ∾ {⟨↕𝕩, •rand.Range˜𝕩⟩}¨ 1‿3‿4‿5‿8‿9‿16 +) \ No newline at end of file