Use a C comparison function instead of a BQN one

This commit is contained in:
Marshall Lochbaum 2023-02-19 21:08:21 -05:00
parent d2950a8df6
commit ea6c7d9a7a

View File

@ -208,14 +208,14 @@ B fne_c2(B t, B w, B x) {
extern B eq_c2(B, B, B);
extern B take_c2(B, B, B);
extern B drop_c2(B, B, B);
extern B and_c2(B, B, B);
extern B slash_c1(B, B);
extern B rt_find;
B find_c2(B t, B w, B x) {
ur wr = isAtm(w) ? 0 : RNK(w);
ur xr = isAtm(x) ? 0 : RNK(x);
if (wr > xr) thrF("⍷: Rank of 𝕨 must be at most rank of 𝕩 (%i≡=𝕨, %i≡=𝕩)", wr, xr);
if (xr==1 && TI(x,elType)!=el_B && (isAtm(w) || TI(w,elType)!=el_B)) {
u8 xe;
if (xr==1 && (xe=TI(x,elType))!=el_B && xe!=el_bit && (isAtm(w) || TI(w,elType)!=el_B)) {
if (wr == 0) return C2(eq, w, x);
usz wl = IA(w);
usz xl = IA(x);
@ -223,18 +223,23 @@ B find_c2(B t, B w, B x) {
if (wl == 0) { decG(w); decG(x); return taga(arr_shVec(allOnes(xl+1))); }
// Compare elements of w to slices of x
SGetU(w)
usz rl = xl - wl + 1; B rt = m_f64(rl); // Result length
B e = C2(eq, GetU(w,0), C2(take, rt, incG(x)));
usz rl = xl - wl + 1; // Result length
u8* xp = tyany_ptr(x);
u64* rp; B r = m_bitarrv(&rp, rl);
CmpASFn eq = CMP_AS_FN(eq, xe);
CMP_AS_CALL(eq, rp, xp, GetU(w,0), rl);
if (wl == 1) goto dec_ret;
usz xw = elWidth(xe);
usz rb = BIT_N(rl);
TALLOC(u64, eq_res, rb);
for (usz i = 1; i < wl; i++) {
B slice = C2(take, rt, C2(drop, m_f64(i), incG(x)));
e = C2(and, e, C2(eq, GetU(w,i), slice));
assert(TI(e,elType) == el_bit);
ux* ep = bitarr_ptr(e);
usz s = bit_sum(ep, rl);
CMP_AS_CALL(eq, eq_res, xp + i*xw, GetU(w,i), rl);
for (usz b = 0; b < rb; b++) rp[b] &= eq_res[b];
usz s = bit_sum(rp, rl);
if (s == 0) break;
// Switch to verifying matches individually
if (s < rl/256 && rl <= I32_MAX) {
B ind = C1(slash, incG(e));
B ind = C1(slash, incG(r));
if (TI(ind,elType)!=el_i32) ind = taga(cpyI32Arr(ind));
usz ni = IA(ind);
i32* ip = i32any_ptr(ind);
@ -242,15 +247,16 @@ B find_c2(B t, B w, B x) {
for (usz ii = 0; ii < ni; ii++) {
usz j = ip[ii];
B slice = C2(take, m_f64(wl-i), C2(drop, m_f64(i+j), incG(x)));
if (!equal(ws, slice)) bitp_set(ep, j, 0);
if (!equal(ws, slice)) bitp_set(rp, j, 0);
decG(slice);
}
decG(ind); decG(ws);
break;
}
}
decG(x); decG(w);
return e;
TFREE(eq_res);
dec_ret:;
decG(x); decG(w); return r;
}
return c2rt(find, w, x);
}