commit
0312c05851
@ -205,8 +205,56 @@ B fne_c2(B t, B w, B x) {
|
||||
}
|
||||
|
||||
|
||||
extern B eq_c2(B, B, B);
|
||||
extern B slash_c1(B, B);
|
||||
extern B rt_find;
|
||||
B find_c2(B t, B w, B x) {
|
||||
ur wr = isAtm(w) ? 0 : RNK(w);
|
||||
ur xr = isAtm(x) ? 0 : RNK(x);
|
||||
if (wr > xr) thrF("⍷: Rank of 𝕨 must be at most rank of 𝕩 (%i≡=𝕨, %i≡=𝕩)", wr, xr);
|
||||
u8 xe, we;
|
||||
if (xr==1 && (xe=TI(x,elType))!=el_B && xe!=el_bit && (isAtm(w) || (we=TI(w,elType))!=el_B)) {
|
||||
if (wr == 0) return C2(eq, w, x);
|
||||
usz wl = IA(w);
|
||||
usz xl = IA(x);
|
||||
B r;
|
||||
if (wl > xl) { r = emptyIVec(); goto dec_ret; }
|
||||
if (wl == 0) { r = taga(arr_shVec(allOnes(xl+1))); goto dec_ret; }
|
||||
// Compare elements of w to slices of x
|
||||
usz rl = xl - wl + 1; // Result length
|
||||
u8* xp = tyany_ptr(x);
|
||||
u64* rp; r = m_bitarrv(&rp, rl);
|
||||
CmpASFn eq = CMP_AS_FN(eq, xe);
|
||||
SGetU(w)
|
||||
CMP_AS_CALL(eq, rp, xp, GetU(w,0), rl);
|
||||
if (wl == 1) goto dec_ret;
|
||||
usz xw = elWidth(xe);
|
||||
usz rb = BIT_N(rl);
|
||||
TALLOC(u64, eq_res, rb);
|
||||
for (usz i = 1; i < wl; i++) {
|
||||
CMP_AS_CALL(eq, eq_res, xp + i*xw, GetU(w,i), rl);
|
||||
for (usz b = 0; b < rb; b++) rp[b] &= eq_res[b];
|
||||
usz s = bit_sum(rp, rl);
|
||||
if (s == 0) break;
|
||||
// Switch to verifying matches individually
|
||||
if (s < rl/16 && rl <= I32_MAX && we != el_bit) {
|
||||
B ind = toI32Any(C1(slash, incG(r)));
|
||||
usz ni = IA(ind);
|
||||
i32* ip = i32any_ptr(ind);
|
||||
u8* wp = (u8*)tyany_ptr(w) + i*elWidth(we);
|
||||
EqFnObj eqfn = EQFN_GET(we, xe);
|
||||
for (usz ii = 0; ii < ni; ii++) {
|
||||
usz j = ip[ii];
|
||||
if (!EQFN_CALL(eqfn, wp, xp + (i+j)*xw, wl-i)) bitp_set(rp, j, 0);
|
||||
}
|
||||
decG(ind);
|
||||
break;
|
||||
}
|
||||
}
|
||||
TFREE(eq_res);
|
||||
dec_ret:;
|
||||
decG(x); decG(w); return r;
|
||||
}
|
||||
return c2rt(find, w, x);
|
||||
}
|
||||
|
||||
|
||||
@ -134,9 +134,9 @@ static B group_simple(B w, B x, ur xr, usz wia, usz xn, usz* xsh, u8 we) {
|
||||
bitp_set(mp, 0, -1!=o2fG(IGetU(w,0)));
|
||||
|
||||
B ind = C1(slash, m);
|
||||
w = C2(select, inc(ind), w);
|
||||
if (TI(ind,elType)!=el_i32) ind = taga(cpyI32Arr(ind));
|
||||
if (TI(w ,elType)!=el_i32) w = taga(cpyI32Arr(w ));
|
||||
w = C2(select, incG(ind), w);
|
||||
ind = toI32Any(ind);
|
||||
w = toI32Any(w);
|
||||
wia = IA(ind);
|
||||
|
||||
i32* ip = i32any_ptr(ind);
|
||||
@ -179,7 +179,7 @@ static B group_simple(B w, B x, ur xr, usz wia, usz xn, usz* xsh, u8 we) {
|
||||
x = C2(slash, m, x); xn = *SH(x);
|
||||
neg = 0;
|
||||
}
|
||||
if (TI(w,elType)!=el_i32) w = taga(cpyI32Arr(w));
|
||||
w = toI32Any(w);
|
||||
i32* wp = i32any_ptr(w);
|
||||
for (usz i = 0; i < ria; i++) len[i] = pos[i] = 0;
|
||||
for (usz i = 0; i < xn; i++) len[wp[i]]++; // overallocation makes this safe after n<-1 check
|
||||
|
||||
@ -284,6 +284,36 @@ static NOINLINE B shift_cells(B f, B x, u8 e, u8 rtid) {
|
||||
return mut_fcd(r, x);
|
||||
}
|
||||
|
||||
static B allBit(bool b, usz n) {
|
||||
return taga(arr_shVec(b ? allOnes(n) : allZeroes(n)));
|
||||
}
|
||||
static NOINLINE B match_cells(bool ne, B w, B x, ur wr, ur xr, usz len) {
|
||||
usz* wsh = SH(w);
|
||||
if (wr != xr || (wr>1 && !eqShPart(wsh+1, SH(x)+1, wr-1))) {
|
||||
return allBit(ne, len);
|
||||
}
|
||||
usz csz = shProd(wsh, 1, wr);
|
||||
if (csz == 0) return allBit(!ne, len);
|
||||
u8 we = TI(w,elType);
|
||||
u8 xe = TI(x,elType);
|
||||
if (we>el_c32 || xe>el_c32) return bi_N;
|
||||
usz ww = csz * elWidth(we); u8* wp = tyany_ptr(w);
|
||||
usz xw = csz * elWidth(xe); u8* xp = tyany_ptr(x);
|
||||
u64* rp; B r = m_bitarrv(&rp, len);
|
||||
if (csz == 1 && we == xe) {
|
||||
CmpAAFn cmp = ne ? CMP_AA_FN(ne,we) : CMP_AA_FN(eq,we);
|
||||
CMP_AA_CALL(cmp, rp, wp, xp, len);
|
||||
} else {
|
||||
if (we==el_bit || xe==el_bit) return bi_N;
|
||||
EqFnObj eqfn = EQFN_GET(we, xe);
|
||||
for (usz i = 0; i < len; i++) {
|
||||
bitp_set(rp, i, ne^EQFN_CALL(eqfn, wp, xp, csz));
|
||||
wp += ww; xp += xw;
|
||||
}
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
B shape_c1(B, B);
|
||||
B cell_c1(Md1D* d, B x) { B f = d->f;
|
||||
if (isAtm(x) || RNK(x)==0) {
|
||||
@ -397,6 +427,13 @@ B cell_c2(Md1D* d, B w, B x) { B f = d->f;
|
||||
usz cam = SH(w)[0];
|
||||
if (cam==0) return cell2_empty(f, w, x, wr, xr);
|
||||
if (cam != SH(x)[0]) thrF("˘: Leading axis of arguments not equal (%H ≡ ≢𝕨, %H ≡ ≢𝕩)", w, x);
|
||||
if (isFun(f)) {
|
||||
u8 rtid = v(f)->flags-1;
|
||||
if (rtid==n_feq || rtid==n_fne) {
|
||||
B r = match_cells(rtid!=n_feq, w, x, wr, xr, cam);
|
||||
if (!q_N(r)) { decG(w); decG(x); return r; }
|
||||
}
|
||||
}
|
||||
S_SLICES(w) S_SLICES(x)
|
||||
M_HARR(r, cam);
|
||||
for (usz i=0,wp=0,xp=0; i<cam; i++,wp+=w_csz,xp+=x_csz) HARR_ADD(r, i, c2(f, SLICE(w, wp), SLICE(x, xp)));
|
||||
|
||||
119
src/core/stuff.c
119
src/core/stuff.c
@ -402,37 +402,75 @@ NOINLINE bool atomEqualF(B w, B x) {
|
||||
decG(wd);decG(xd); return true;
|
||||
}
|
||||
|
||||
// Functions in eqFns compare segments for matching
|
||||
// data argument comes from eqFnData
|
||||
static const u8 n = 99;
|
||||
u8 eqFnData[] = { // for the main diagonal, amount to shift length by; otherwise, whether to swap arguments
|
||||
0,0,0,0,0,n,n,n,
|
||||
1,0,0,0,0,n,n,n,
|
||||
1,1,1,0,0,n,n,n,
|
||||
1,1,1,2,0,n,n,n,
|
||||
1,1,1,1,0,n,n,n,
|
||||
n,n,n,n,n,0,0,0,
|
||||
n,n,n,n,n,1,1,0,
|
||||
n,n,n,n,n,1,1,2,
|
||||
};
|
||||
|
||||
#if SINGELI
|
||||
#define F(X) avx2_equal_##X
|
||||
#define SINGELI_FILE equal
|
||||
#include "../utils/includeSingeli.h"
|
||||
|
||||
typedef bool (*EqFn)(void* a, void* b, u64 l, u64 data);
|
||||
bool notEq(void* a, void* b, u64 l, u64 data) { return false; }
|
||||
|
||||
#define F(X) avx2_equal_##X
|
||||
EqFn eqFns[] = {
|
||||
F(1_1), F(1_8), F(1_16), F(1_32), F(1_f64), notEq, notEq, notEq,
|
||||
F(1_8), F(8_8), F(s8_16), F(s8_32), F(s8_f64), notEq, notEq, notEq,
|
||||
F(1_16), F(s8_16), F(8_8), F(s16_32), F(s16_f64), notEq, notEq, notEq,
|
||||
F(1_32), F(s8_32), F(s16_32), F(8_8), F(s32_f64), notEq, notEq, notEq,
|
||||
F(1_f64), F(s8_f64), F(s16_f64), F(s32_f64), F(f64_f64), notEq, notEq, notEq,
|
||||
notEq, notEq, notEq, notEq, notEq, F(8_8), F(u8_16), F(u8_32),
|
||||
notEq, notEq, notEq, notEq, notEq, F(u8_16), F(8_8), F(u16_32),
|
||||
notEq, notEq, notEq, notEq, notEq, F(u8_32), F(u16_32), F(8_8),
|
||||
};
|
||||
#undef F
|
||||
static const u8 n = 99;
|
||||
u8 eqFnData[] = { // for the main diagonal, amount to shift length by; otherwise, whether to swap arguments
|
||||
0,0,0,0,0,n,n,n,
|
||||
1,0,0,0,0,n,n,n,
|
||||
1,1,1,0,0,n,n,n,
|
||||
1,1,1,2,0,n,n,n,
|
||||
1,1,1,1,0,n,n,n,
|
||||
n,n,n,n,n,0,0,0,
|
||||
n,n,n,n,n,1,1,0,
|
||||
n,n,n,n,n,1,1,2,
|
||||
};
|
||||
#else
|
||||
#define F(X) equal_##X
|
||||
bool F(1_1)(void* w, void* x, u64 l, u64 d) {
|
||||
u64* wp = w; u64* xp = x;
|
||||
usz q = l/64;
|
||||
for (usz i=0; i<q; i++) if (wp[i] != xp[i]) return false;
|
||||
usz r = (-l)%64; return r==0 || (wp[q]^xp[q])<<r == 0;
|
||||
}
|
||||
#define DEF_EQ_U1(N, T) \
|
||||
bool F(1_##N)(void* w, void* x, u64 l, u64 d) { \
|
||||
if (d!=0) { void* t=w; w=x; x=t; } \
|
||||
u64* wp = w; T* xp = x; \
|
||||
for (usz i=0; i<l; i++) if (bitp_get(wp,i)!=xp[i]) return false; \
|
||||
return true; \
|
||||
}
|
||||
DEF_EQ_U1(8, i8)
|
||||
DEF_EQ_U1(16, i16)
|
||||
DEF_EQ_U1(32, i32)
|
||||
DEF_EQ_U1(f64, f64)
|
||||
#undef DEF_EQ_U1
|
||||
|
||||
#define DEF_EQ_I(NAME, S, T, INIT) \
|
||||
bool F(NAME)(void* w, void* x, u64 l, u64 d) { \
|
||||
INIT \
|
||||
S* wp = w; T* xp = x; \
|
||||
for (usz i=0; i<l; i++) if (wp[i]!=xp[i]) return false; \
|
||||
return true; \
|
||||
}
|
||||
#define DEF_EQ(N,S,T) DEF_EQ_I(N,S,T, if (d!=0) { void* t=w; w=x; x=t; })
|
||||
DEF_EQ_I(8_8, u8, u8, l<<=d;)
|
||||
DEF_EQ_I(f64_f64, f64, f64, )
|
||||
DEF_EQ(u8_16, u8, u16)
|
||||
DEF_EQ(u8_32, u8, u32) DEF_EQ(u16_32, u16, u32)
|
||||
DEF_EQ(s8_16, i8, i16)
|
||||
DEF_EQ(s8_32, i8, i32) DEF_EQ(s16_32, i16, i32)
|
||||
DEF_EQ(s8_f64, i8, f64) DEF_EQ(s16_f64, i16, f64) DEF_EQ(s32_f64, i32, f64)
|
||||
#undef DEF_EQ_I
|
||||
#undef DEF_EQ
|
||||
#endif
|
||||
bool notEq(void* a, void* b, u64 l, u64 data) { return false; }
|
||||
EqFn eqFns[] = {
|
||||
F(1_1), F(1_8), F(1_16), F(1_32), F(1_f64), notEq, notEq, notEq,
|
||||
F(1_8), F(8_8), F(s8_16), F(s8_32), F(s8_f64), notEq, notEq, notEq,
|
||||
F(1_16), F(s8_16), F(8_8), F(s16_32), F(s16_f64), notEq, notEq, notEq,
|
||||
F(1_32), F(s8_32), F(s16_32), F(8_8), F(s32_f64), notEq, notEq, notEq,
|
||||
F(1_f64), F(s8_f64), F(s16_f64), F(s32_f64), F(f64_f64), notEq, notEq, notEq,
|
||||
notEq, notEq, notEq, notEq, notEq, F(8_8), F(u8_16), F(u8_32),
|
||||
notEq, notEq, notEq, notEq, notEq, F(u8_16), F(8_8), F(u16_32),
|
||||
notEq, notEq, notEq, notEq, notEq, F(u8_32), F(u16_32), F(8_8),
|
||||
};
|
||||
#undef F
|
||||
|
||||
NOINLINE bool equalSlow(B w, B x, usz ia);
|
||||
NOINLINE bool equal(B w, B x) { // doesn't consume
|
||||
@ -455,29 +493,10 @@ NOINLINE bool equal(B w, B x) { // doesn't consume
|
||||
u8 we = TI(w,elType);
|
||||
u8 xe = TI(x,elType);
|
||||
|
||||
#if SINGELI
|
||||
if (we<=el_c32 && xe<=el_c32) { // remove & pass a(w) and a(x) to fn so it can do basic loop
|
||||
u64 idx = we*8 + xe;
|
||||
return eqFns[idx](tyany_ptr(w), tyany_ptr(x), ia, eqFnData[idx]);
|
||||
}
|
||||
#else
|
||||
if (((we==el_f64 | we==el_i32) && (xe==el_f64 | xe==el_i32))) {
|
||||
if (we==el_i32) { i32* wp = i32any_ptr(w);
|
||||
if(xe==el_i32) { i32* xp = i32any_ptr(x); for (usz i = 0; i < ia; i++) if(wp[i]!=xp[i]) return false; }
|
||||
else { f64* xp = f64any_ptr(x); for (usz i = 0; i < ia; i++) if(wp[i]!=xp[i]) return false; }
|
||||
} else { f64* wp = f64any_ptr(w);
|
||||
if(xe==el_i32) { i32* xp = i32any_ptr(x); for (usz i = 0; i < ia; i++) if(wp[i]!=xp[i]) return false; }
|
||||
else { f64* xp = f64any_ptr(x); for (usz i = 0; i < ia; i++) if(wp[i]!=xp[i]) return false; }
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (we==el_c32 && xe==el_c32) {
|
||||
u32* wp = c32any_ptr(w);
|
||||
u32* xp = c32any_ptr(x);
|
||||
for (usz i = 0; i < ia; i++) if(wp[i]!=xp[i]) return false;
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
if (we<=el_c32 && xe<=el_c32) { // remove & pass a(w) and a(x) to fn so it can do basic loop
|
||||
usz idx = EQFN_INDEX(we, xe);
|
||||
return eqFns[idx](tyany_ptr(w), tyany_ptr(x), ia, eqFnData[idx]);
|
||||
}
|
||||
return equalSlow(w, x, ia);
|
||||
}
|
||||
bool equalSlow(B w, B x, usz ia) {
|
||||
|
||||
@ -35,4 +35,13 @@ CMP_DEF(le, AS);
|
||||
#define CMP_AA_IMM(FN, ELT, WHERE, WP, XP, LEN) CMP_AA_CALL(CMP_AA_FN(FN, ELT), WHERE, WP, XP, LEN)
|
||||
#define CMP_AS_IMM(FN, ELT, WHERE, WP, X, LEN) CMP_AS_CALL(CMP_AS_FN(FN, ELT), WHERE, WP, X, LEN)
|
||||
|
||||
// Check if the l elements starting at a and b match
|
||||
typedef bool (*EqFn)(void* a, void* b, u64 l, u64 data);
|
||||
extern EqFn eqFns[];
|
||||
extern u8 eqFnData[];
|
||||
#define EQFN_INDEX(W_ELT, X_ELT) ((W_ELT)*8 + (X_ELT))
|
||||
typedef struct { EqFn fn; u8 data; } EqFnObj;
|
||||
#define EQFN_GET(W_ELT, X_ELT) ({ u8 eqfn_i_ = EQFN_INDEX(W_ELT, X_ELT); (EqFnObj){.fn=eqFns[eqfn_i_], .data=eqFnData[eqfn_i_]}; })
|
||||
#define EQFN_CALL(FN, W, X, L) (FN).fn(W, X, L, (FN).data)
|
||||
|
||||
void bit_negatePtr(u64* rp, u64* xp, usz count); // count is number of u64-s
|
||||
|
||||
Loading…
Reference in New Issue
Block a user