diff --git a/src/builtins/slash.c b/src/builtins/slash.c index 6d369960..365d39cd 100644 --- a/src/builtins/slash.c +++ b/src/builtins/slash.c @@ -1135,68 +1135,109 @@ B slash_ucw(B t, B o, B w, B x) { B arg = C2(slash, incG(w), incG(x)); usz argIA = IA(arg); - B rep = c1(o, arg); + B rep = c1(o, arg); // TODO special-case non-callable rep if (isAtm(rep) || RNK(rep)!=1 || IA(rep) != argIA) thrF("𝔽⌾(a⊸/)𝕩: 𝔽 must return an array with the same shape as its input (expected ⟨%s⟩, got %H)", argIA, rep); u8 re = el_or(TI(x,elType), TI(rep,elType)); - MAKE_MUT_INIT(r, ia, re? re : 1); - ux repI = 0; - if (we==el_bit && re!=el_B) { + + if (MAY_F(we==el_bit)) { + bit_w:; + ConvArr r = toEltypeArrX(x, re); + ux ni = 0; u64* d = bitany_ptr(w); - void* rp = r->a; switch (re) { default: UD; - case el_bit: - case el_i8: x = toI8Any(x); rep = toI8Any(rep); goto bit_u8; - case el_c8: x = toC8Any(x); rep = toC8Any(rep); goto bit_u8; - case el_i16: x = toI16Any(x); rep = toI16Any(rep); goto bit_u16; - case el_c16: x = toC16Any(x); rep = toC16Any(rep); goto bit_u16; - case el_i32: x = toI32Any(x); rep = toI32Any(rep); goto bit_u32; - case el_c32: x = toC32Any(x); rep = toC32Any(rep); goto bit_u32; - case el_f64: x = toF64Any(x); rep = toF64Any(rep); goto bit_u64; + case el_i8: rep = toI8Any(rep); goto bit_u8; + case el_c8: rep = toC8Any(rep); goto bit_u8; + case el_i16: rep = toI16Any(rep); goto bit_u16; + case el_c16: rep = toC16Any(rep); goto bit_u16; + case el_i32: rep = toI32Any(rep); goto bit_u32; + case el_c32: rep = toC32Any(rep); goto bit_u32; + case el_f64: rep = toF64Any(rep); goto bit_u64; + case el_bit: { + // TODO BMI2 pext + void* np = bitany_ptr(rep); + NOUNROLL for (usz i = 0; i < ia; i++) { + bool v = bitp_get(d, i); + bitp_set(r.rp, i, bitp_get(v? np : r.xp, v? ni : i)); + ni+= v; + } + goto done_v2; + } + case el_B: { + B* np = arr_bptr(rep); + if (np != NULL) { + if (r.refState == 1) { + NOUNROLL for (usz i = 0; i < ia; i++) { + bool v = bitp_get(d, i); + B xv = ((B*)r.xp)[i]; + if (v) dec(xv); + ((B*)r.rp)[i] = v? inc(np[ni]) : xv; + ni+= v; + } + } else { + NOUNROLL for (usz i = 0; i < ia; i++) { + bool v = bitp_get(d, i); + ((B*)r.rp)[i] = inc(*(v? np+ni : ((B*)r.xp)+i)); + ni+= v; + } + } + } else { + if (r.refState == 1) decByMask(bitany_ptr(w), r.xp, ia, 0); + else incByMask(bitany_ptr(w), r.xp, ia, 1); + SGet(rep) + NOUNROLL for (usz i = 0; i < ia; i++) { + bool v = bitp_get(d, i); + ((B*)r.rp)[i] = v? Get(rep,ni) : ((B*)r.xp)[i]; + ni+= v; + } + } + NOGC_E; + goto done_v2; + } } - #define IMPL(T) do { \ - T* xp = tyany_ptr(x); \ - T* np = tyany_ptr(rep); \ + #define IMPL(T) do { \ + T* np = tyany_ptr(rep); \ NOUNROLL for (usz i = 0; i < ia; i++) { \ - bool v = bitp_get(d, i); \ - ((T*)rp)[i] = *(v? np+repI : xp+i); \ - repI+= v; \ - } \ - goto mut_dec_ret; \ + bool v = bitp_get(d, i); \ + ((T*)r.rp)[i] = *(v? np+ni : ((T*)r.xp)+i); \ + ni+= v; \ + } \ + goto done_v2; \ } while(0) - bit_u8: IMPL(u8); bit_u16: IMPL(u16); bit_u32: IMPL(u32); bit_u64: IMPL(u64); #undef IMPL + done_v2:; + decG(rep); decG(w); decG(x); + return r.res; } else { - SGet(x) SGet(rep) MUTG_INIT(r); - if (we == el_bit) { - u64* d = bitany_ptr(w); - for (usz i = 0; i < ia; i++) mut_setG(r, i, bitp_get(d, i)? Get(rep,repI++) : Get(x,i)); - } else { - if (re == el_B) { - mut_fillG(r, 0, m_f64(0), ia); - NOGC_E; - } - SGetU(w) SGetU(rep) - for (usz i = 0; i < ia; i++) { - ux cw = o2u64G(GetU(w, i)); - if (cw) { - B cr = Get(rep,repI); - if (CHECK_VALID) for (ux j = 1; j < cw; j++) if (!compatible(GetU(rep,repI+j), cr)) { mut_pfree(r,i); thrM("𝔽⌾(a⊸/): Incompatible result elements"); } - mut_setG(r, i, cr); - repI+= cw; - } else mut_setG(r, i, Get(x,i)); + ux wia = IA(w); + SGetU(w) SGetU(rep) + B w2 = C2(ne, incG(w), m_f64(0)); + ux nz = bit_sum(bitany_ptr(w2), wia); + + ux ni = 0; + M_HARR(rp, nz) + for (ux i = 0; i < wia; i++) { + ux c = o2sG(GetU(w, i)); + if (c != 0) { + B cr = GetU(rep, ni); + for (ux j = 1; j < c; j++) if (!compatible(GetU(rep,ni+j), cr)) thrM("𝔽⌾(a⊸/): Incompatible result elements"); + HARR_ADDA(rp, inc(cr)); } + ni+= c; } + assert(ni == IA(rep)); + + decG(w); + w = w2; + decG(rep); + rep = toEltypeArr(HARR_FV(rp), re).obj; + goto bit_w; } - mut_dec_ret:; - decG(w); decG(rep); - B rb = mut_fcd(r, x); - return re==0? taga(cpyBitArr(rb)) : rb; } void slash_init(void) { diff --git a/test/cases/under.bqn b/test/cases/under.bqn index 43c13020..108b9718 100644 --- a/test/cases/under.bqn +++ b/test/cases/under.bqn @@ -202,9 +202,9 @@ n←500 ⋄ a←↕n ⋄ i←(-n)+↕2×n ⋄ r←⌽(2×n)⥊a ⋄ ! (⌽a) ≡ %USE eqvar ⋄ 1‿0‿0‿0‿1‿1 { 3¨⌾(𝕨⊸/)𝕩}_eqvar "hellow" %% 3‿'e'‿'l'‿'l'‿3‿3 %USE eqvar ⋄ 1‿0‿0‿0‿1‿1 { +˙¨⌾(𝕨⊸/)𝕩}_eqvar -‿÷‿×‿=‿<‿> %% +‿÷‿×‿=‿+‿+ !"𝕨/𝕩: Lengths of components of 𝕨 must match 𝕩 (6 ≠ 7)" % 0¨⌾(1‿0‿0‿0‿1‿1⊸/) 0‿1‿0‿1‿0‿1‿0 -4↑ (⋈⋈3)⌾(0‿1‿0⊸/) ↕⋈3 %!PROPER_FILLS %% ⟨⋈0,⋈3,⋈2,0⟩ -4↑ ⊢⌾(0‿1‿0⊸/) ↕⋈3 %!PROPER_FILLS %% ⟨⋈0,⋈1,⋈2,0⟩ -4↑ ⊢⌾(0‿0‿0⊸/) ↕⋈3 %!PROPER_FILLS %% ⟨⋈0,⋈1,⋈2,0⟩ +4↑ (⋈⋈3)⌾(0‿1‿0⊸/) ↕⋈3 %% ⟨⋈0,⋈3,⋈2,⋈0⟩ +4↑ ⊢⌾(0‿1‿0⊸/) ↕⋈3 %% ⟨⋈0,⋈1,⋈2,⋈0⟩ +4↑ ⊢⌾(0‿0‿0⊸/) ↕⋈3 %% ⟨⋈0,⋈1,⋈2,⋈0⟩ !"No fill found" % 4↑ (⋈⋈3)⌾(0‿1‿0⊸/) ↕⋈3 %PROPER_FILLS !"No fill found" % 4↑ ⊢⌾(0‿1‿0⊸/) ↕⋈3 %PROPER_FILLS !"No fill found" % 4↑ ⊢⌾(0‿0‿0⊸/) ↕⋈3 %PROPER_FILLS