Add eqFns to calls.h and use for Find

This commit is contained in:
Marshall Lochbaum 2023-02-20 08:44:12 -05:00
parent b56e547e34
commit 89e6be10e5
3 changed files with 16 additions and 14 deletions

View File

@ -206,15 +206,14 @@ B fne_c2(B t, B w, B x) {
extern B eq_c2(B, B, B);
extern B drop_c2(B, B, B);
extern B slash_c1(B, B);
extern B rt_find;
B find_c2(B t, B w, B x) {
ur wr = isAtm(w) ? 0 : RNK(w);
ur xr = isAtm(x) ? 0 : RNK(x);
if (wr > xr) thrF("⍷: Rank of 𝕨 must be at most rank of 𝕩 (%i≡=𝕨, %i≡=𝕩)", wr, xr);
u8 xe;
if (xr==1 && (xe=TI(x,elType))!=el_B && xe!=el_bit && (isAtm(w) || TI(w,elType)!=el_B)) {
u8 xe, we;
if (xr==1 && (xe=TI(x,elType))!=el_B && xe!=el_bit && (isAtm(w) || (we=TI(w,elType))!=el_B)) {
if (wr == 0) return C2(eq, w, x);
usz wl = IA(w);
usz xl = IA(x);
@ -237,20 +236,19 @@ B find_c2(B t, B w, B x) {
usz s = bit_sum(rp, rl);
if (s == 0) break;
// Switch to verifying matches individually
if (s < rl/32 && rl <= I32_MAX) {
if (s < rl/16 && rl <= I32_MAX && we != el_bit) {
B ind = C1(slash, incG(r));
if (TI(ind,elType)!=el_i32) ind = taga(cpyI32Arr(ind));
usz ni = IA(ind);
i32* ip = i32any_ptr(ind);
B ws = C2(drop, m_f64(i), incG(w));
BSS2A slice = TI(x,slice);
u8* wp = (u8*)tyany_ptr(w) + i*elWidth(we);
usz eq_idx = EQFN_INDEX(we, xe);
EqFn equalp = eqFns[eq_idx]; u8 ed = eqFnData[eq_idx];
for (usz ii = 0; ii < ni; ii++) {
usz j = ip[ii];
B sl = taga(arr_shVec(slice(incG(x), i+j, wl-i)));
if (!equal(ws, sl)) bitp_set(rp, j, 0);
decG(sl);
if (!equalp(wp, xp + (i+j)*xw, wl-i, ed)) bitp_set(rp, j, 0);
}
decG(ind); decG(ws);
decG(ind);
break;
}
}

View File

@ -404,9 +404,6 @@ NOINLINE bool atomEqualF(B w, B x) {
// Functions in eqFns compare segments for matching
// data argument comes from eqFnData
typedef bool (*EqFn)(void* a, void* b, u64 l, u64 data);
bool notEq(void* a, void* b, u64 l, u64 data) { return false; }
static const u8 n = 99;
u8 eqFnData[] = { // for the main diagonal, amount to shift length by; otherwise, whether to swap arguments
0,0,0,0,0,n,n,n,
@ -462,6 +459,7 @@ u8 eqFnData[] = { // for the main diagonal, amount to shift length by; otherwise
#undef DEF_EQ_I
#undef DEF_EQ
#endif
bool notEq(void* a, void* b, u64 l, u64 data) { return false; }
EqFn eqFns[] = {
F(1_1), F(1_8), F(1_16), F(1_32), F(1_f64), notEq, notEq, notEq,
F(1_8), F(8_8), F(s8_16), F(s8_32), F(s8_f64), notEq, notEq, notEq,
@ -496,7 +494,7 @@ NOINLINE bool equal(B w, B x) { // doesn't consume
u8 xe = TI(x,elType);
if (we<=el_c32 && xe<=el_c32) { // remove & pass a(w) and a(x) to fn so it can do basic loop
u64 idx = we*8 + xe;
usz idx = EQFN_INDEX(we, xe);
return eqFns[idx](tyany_ptr(w), tyany_ptr(x), ia, eqFnData[idx]);
}
return equalSlow(w, x, ia);

View File

@ -35,4 +35,10 @@ CMP_DEF(le, AS);
#define CMP_AA_IMM(FN, ELT, WHERE, WP, XP, LEN) CMP_AA_CALL(CMP_AA_FN(FN, ELT), WHERE, WP, XP, LEN)
#define CMP_AS_IMM(FN, ELT, WHERE, WP, X, LEN) CMP_AS_CALL(CMP_AS_FN(FN, ELT), WHERE, WP, X, LEN)
// Check if the l elements starting at a and b match
typedef bool (*EqFn)(void* a, void* b, u64 l, u64 data);
extern EqFn eqFns[];
extern u8 eqFnData[];
#define EQFN_INDEX(W_ELT, X_ELT) ((W_ELT)*8 + (X_ELT))
void bit_negatePtr(u64* rp, u64* xp, usz count); // count is number of u64-s