diff --git a/src/builtins/fns.c b/src/builtins/fns.c index 3268c6de..78230342 100644 --- a/src/builtins/fns.c +++ b/src/builtins/fns.c @@ -205,8 +205,56 @@ B fne_c2(B t, B w, B x) { } +extern B eq_c2(B, B, B); +extern B slash_c1(B, B); extern B rt_find; B find_c2(B t, B w, B x) { + ur wr = isAtm(w) ? 0 : RNK(w); + ur xr = isAtm(x) ? 0 : RNK(x); + if (wr > xr) thrF("ā·: Rank of š•Ø must be at most rank of š•© (%i≔=š•Ø, %i≔=š•©)", wr, xr); + u8 xe, we; + if (xr==1 && (xe=TI(x,elType))!=el_B && xe!=el_bit && (isAtm(w) || (we=TI(w,elType))!=el_B)) { + if (wr == 0) return C2(eq, w, x); + usz wl = IA(w); + usz xl = IA(x); + B r; + if (wl > xl) { r = emptyIVec(); goto dec_ret; } + if (wl == 0) { r = taga(arr_shVec(allOnes(xl+1))); goto dec_ret; } + // Compare elements of w to slices of x + usz rl = xl - wl + 1; // Result length + u8* xp = tyany_ptr(x); + u64* rp; r = m_bitarrv(&rp, rl); + CmpASFn eq = CMP_AS_FN(eq, xe); + SGetU(w) + CMP_AS_CALL(eq, rp, xp, GetU(w,0), rl); + if (wl == 1) goto dec_ret; + usz xw = elWidth(xe); + usz rb = BIT_N(rl); + TALLOC(u64, eq_res, rb); + for (usz i = 1; i < wl; i++) { + CMP_AS_CALL(eq, eq_res, xp + i*xw, GetU(w,i), rl); + for (usz b = 0; b < rb; b++) rp[b] &= eq_res[b]; + usz s = bit_sum(rp, rl); + if (s == 0) break; + // Switch to verifying matches individually + if (s < rl/16 && rl <= I32_MAX && we != el_bit) { + B ind = toI32Any(C1(slash, incG(r))); + usz ni = IA(ind); + i32* ip = i32any_ptr(ind); + u8* wp = (u8*)tyany_ptr(w) + i*elWidth(we); + EqFnObj eqfn = EQFN_GET(we, xe); + for (usz ii = 0; ii < ni; ii++) { + usz j = ip[ii]; + if (!EQFN_CALL(eqfn, wp, xp + (i+j)*xw, wl-i)) bitp_set(rp, j, 0); + } + decG(ind); + break; + } + } + TFREE(eq_res); + dec_ret:; + decG(x); decG(w); return r; + } return c2rt(find, w, x); } diff --git a/src/builtins/group.c b/src/builtins/group.c index fe0aced4..78e20c01 100644 --- a/src/builtins/group.c +++ b/src/builtins/group.c @@ -134,9 +134,9 @@ static B group_simple(B w, B x, ur xr, usz wia, usz xn, usz* xsh, u8 we) { bitp_set(mp, 0, -1!=o2fG(IGetU(w,0))); B ind = C1(slash, m); - w = C2(select, inc(ind), w); - if (TI(ind,elType)!=el_i32) ind = taga(cpyI32Arr(ind)); - if (TI(w ,elType)!=el_i32) w = taga(cpyI32Arr(w )); + w = C2(select, incG(ind), w); + ind = toI32Any(ind); + w = toI32Any(w); wia = IA(ind); i32* ip = i32any_ptr(ind); @@ -179,7 +179,7 @@ static B group_simple(B w, B x, ur xr, usz wia, usz xn, usz* xsh, u8 we) { x = C2(slash, m, x); xn = *SH(x); neg = 0; } - if (TI(w,elType)!=el_i32) w = taga(cpyI32Arr(w)); + w = toI32Any(w); i32* wp = i32any_ptr(w); for (usz i = 0; i < ria; i++) len[i] = pos[i] = 0; for (usz i = 0; i < xn; i++) len[wp[i]]++; // overallocation makes this safe after n<-1 check diff --git a/src/builtins/md1.c b/src/builtins/md1.c index f75a90ff..bf43fd5f 100644 --- a/src/builtins/md1.c +++ b/src/builtins/md1.c @@ -284,6 +284,36 @@ static NOINLINE B shift_cells(B f, B x, u8 e, u8 rtid) { return mut_fcd(r, x); } +static B allBit(bool b, usz n) { + return taga(arr_shVec(b ? allOnes(n) : allZeroes(n))); +} +static NOINLINE B match_cells(bool ne, B w, B x, ur wr, ur xr, usz len) { + usz* wsh = SH(w); + if (wr != xr || (wr>1 && !eqShPart(wsh+1, SH(x)+1, wr-1))) { + return allBit(ne, len); + } + usz csz = shProd(wsh, 1, wr); + if (csz == 0) return allBit(!ne, len); + u8 we = TI(w,elType); + u8 xe = TI(x,elType); + if (we>el_c32 || xe>el_c32) return bi_N; + usz ww = csz * elWidth(we); u8* wp = tyany_ptr(w); + usz xw = csz * elWidth(xe); u8* xp = tyany_ptr(x); + u64* rp; B r = m_bitarrv(&rp, len); + if (csz == 1 && we == xe) { + CmpAAFn cmp = ne ? CMP_AA_FN(ne,we) : CMP_AA_FN(eq,we); + CMP_AA_CALL(cmp, rp, wp, xp, len); + } else { + if (we==el_bit || xe==el_bit) return bi_N; + EqFnObj eqfn = EQFN_GET(we, xe); + for (usz i = 0; i < len; i++) { + bitp_set(rp, i, ne^EQFN_CALL(eqfn, wp, xp, csz)); + wp += ww; xp += xw; + } + } + return r; +} + B shape_c1(B, B); B cell_c1(Md1D* d, B x) { B f = d->f; if (isAtm(x) || RNK(x)==0) { @@ -397,6 +427,13 @@ B cell_c2(Md1D* d, B w, B x) { B f = d->f; usz cam = SH(w)[0]; if (cam==0) return cell2_empty(f, w, x, wr, xr); if (cam != SH(x)[0]) thrF("˘: Leading axis of arguments not equal (%H ≔ ā‰¢š•Ø, %H ≔ ā‰¢š•©)", w, x); + if (isFun(f)) { + u8 rtid = v(f)->flags-1; + if (rtid==n_feq || rtid==n_fne) { + B r = match_cells(rtid!=n_feq, w, x, wr, xr, cam); + if (!q_N(r)) { decG(w); decG(x); return r; } + } + } S_SLICES(w) S_SLICES(x) M_HARR(r, cam); for (usz i=0,wp=0,xp=0; i