move non-Singeli comparisons to function lookup

This commit is contained in:
dzaima 2022-10-27 18:49:56 +03:00
parent 8d6172126c
commit bed2708537
3 changed files with 117 additions and 118 deletions

View File

@ -1,6 +1,16 @@
#include "../core.h"
#include "../utils/each.h"
static NOINLINE void fillBits(u64* dst, u64 sz, bool v) {
u64 x = 0-(u64)v;
u64 am = (sz+63)/64; assert(am>0);
for (usz i = 0; i < am; i++) dst[i] = x;
}
static NOINLINE void fillBitsDec(u64* dst, u64 sz, bool v, u64 x) {
dec(b(x));
fillBits(dst, sz, v);
}
static NOINLINE u8 aMakeEq(B* w, B* x, u8 we, u8 xe) { // returns el_MAX if failed
B* p = we<xe?w:x;
B s = *p;
@ -38,68 +48,118 @@ CMP_REC(ne, ne, swapped=0;)
#define gt_rec(S, W, X) lt_rec(!S, X, W)
#undef CMP_REC
#if SINGELI
#include "../singeli/c/cmp.c"
#else
typedef void (*CmpAAFn)(u64*, void*, void*, u64);
typedef void (*CmpASFn)(u64*, void*, u64, u64);
#define CMPFN(A,F,S,T) A##_##F##S##_##T
#define FN_LUT(B,A,F,S) static const Cmp##S##Fn B##_##F##S[] = {CMPFN(A,F,S,u1), CMPFN(A,F,S,i8), CMPFN(A,F,S,i16), CMPFN(A,F,S,i32), CMPFN(A,F,S,f64), CMPFN(A,F,S,u8), CMPFN(A,F,S,u16), CMPFN(A,F,S,u32)}
#define AL(X) u64* rp; B r = m_bitarrc(&rp, X); usz ria=IA(r); usz bia = BIT_N(ria);
#define CMP_AA(CN, CR, NAME, OP, BX, PRE) NOINLINE B NAME##_AA(i32 swapped, B w, B x) { PRE \
#if SINGELI
#include "../singeli/c/cmp.c"
#else
#define CMP_AA0(N, T, BODY) void base_##N##AA##_##T(u64* r, void* w, void* x, u64 l) { BODY }
#define CMP_AA1(N, T, OP) CMP_AA0(N, T, for (usz i=0; i<l; i++) bitp_set(r, i, ((T*)w)[i] OP ((T*)x)[i]);)
#define CMP_AA(N, OP, BX) \
CMP_AA0(N, u1, ({usz bia = BIT_N(l); for (usz i=0; i<bia; i++) { u64 wv=((u64*)w)[i], xv=((u64*)x)[i]; ((u64*)r)[i] = BX; }});) \
CMP_AA1(N, i8, OP) CMP_AA1(N, i16, OP) CMP_AA1(N, i32, OP) CMP_AA1(N, f64, OP) \
const CmpAAFn base_##N##AA##_u32 = base_##N##AA##_i32;
#define CMP_AA_C0(N, OP) const CmpAAFn base_##N##AA##_u8 = base_##N##AA##_i8; const CmpAAFn base_##N##AA##_u16 = base_##N##AA##_i16;
#define CMP_AA_C1(N, OP) CMP_AA1(N, u8, OP) CMP_AA1(N, u16, OP)
CMP_AA(eq, ==, ~wv ^ xv) CMP_AA_C0(eq, ==)
CMP_AA(ne, !=, wv ^ xv) CMP_AA_C0(ne, !=)
CMP_AA(gt, > , wv & ~xv) CMP_AA_C1(gt, > )
CMP_AA(ge, >=, wv | ~xv) CMP_AA_C1(ge, >=)
#undef CMP_AA
#define CMP_SLOW(T, GW) void cmp_slow_##T(void* r, void* w, B x, u64 l, BBB2B fn) { \
assert(l>0); incBy(x,l-1); \
for (usz i=0; i<l; i++) bitp_set(r, i, o2bG(fn(m_f64(0), GW, x))); \
}
#define CMP_SLOWi(T,M) CMP_SLOW(T, m_##M(((T*)w)[i]))
CMP_SLOW(u1,m_i32(bitp_get(w,i)))
CMP_SLOWi(i8,i32) CMP_SLOWi(i16,i32) CMP_SLOWi(i32,i32) CMP_SLOWi(f64,f64)
CMP_SLOWi(u8,c32) CMP_SLOWi(u16,c32) CMP_SLOWi(u32,c32)
static inline void cmp_fill_eq(u64* r, u64 l, u64 x) { fillBitsDec(r, l, 0, x); }
static inline void cmp_fill_ne(u64* r, u64 l, u64 x) { fillBitsDec(r, l, 1, x); }
#define CMP_TO_SLOW(N, T) cmp_slow_##T(r, w, x, l, N##_c2)
#define CMP_TO_FILL(N, T) cmp_fill_##N(r, l, xr)
#define CMP_SA0(N, T, Q, SLOW, BODY) void base_##N##AS##_##T(u64* r, void* w, u64 xr, u64 l) { \
assert(l>0); B x=b(xr); \
if (LIKELY(q_##Q(x))) BODY; \
else SLOW(N, T); \
}
#define CMP_SA1(N, T, Q, C, SLOW, OP) CMP_SA0(N, T, Q, SLOW, ({T xv = C(x); for (usz i=0; i<l; i++) bitp_set(r, i, ((T*)w)[i] OP xv);}))
#define CMP_SA(N, OP, SLOW, BX) \
CMP_SA0(N, u1, bit, SLOW, ({usz bia = BIT_N(l); u64 xv=bitx(x); for (usz i=0; i<bia; i++) { u64 wv=((u64*)w)[i]; ((u64*)r)[i] = BX; }})) \
CMP_SA1(N,i8,i8,o2iG,SLOW,OP) CMP_SA1(N,i16,i16,o2iG,SLOW,OP) CMP_SA1(N,i32,i32,o2iG,SLOW,OP) CMP_SA1(N,f64,f64,o2fG,SLOW,OP) \
CMP_SA1(N,u8,c8,o2cG,SLOW,OP) CMP_SA1(N,u16,c16,o2cG,SLOW,OP) CMP_SA1(N,u32,c32,o2cG,SLOW,OP)
CMP_SA(eq, ==, CMP_TO_FILL, ~wv^xv)
CMP_SA(ne, !=, CMP_TO_FILL, wv^xv)
CMP_SA(le, <=, CMP_TO_SLOW, ~wv | xv)
CMP_SA(ge, >=, CMP_TO_SLOW, wv | ~xv)
CMP_SA(lt, < , CMP_TO_SLOW, ~wv & xv)
CMP_SA(gt, > , CMP_TO_SLOW, wv & ~xv)
#undef CMP_SA
FN_LUT(cmp_fns, base, eq, AS); FN_LUT(cmp_fns, base, eq, AA);
FN_LUT(cmp_fns, base, ne, AS); FN_LUT(cmp_fns, base, ne, AA);
FN_LUT(cmp_fns, base, gt, AS); FN_LUT(cmp_fns, base, gt, AA);
FN_LUT(cmp_fns, base, ge, AS); FN_LUT(cmp_fns, base, ge, AA);
FN_LUT(cmp_fns, base, lt, AS);
FN_LUT(cmp_fns, base, le, AS);
#endif
#undef FN_LUT
#define AL(X) u64* rp; B r = m_bitarrc(&rp, X); usz ria=IA(r)
#define CMP_AA(CN, CR, NAME, PRE) NOINLINE B NAME##_AA(i32 swapped, B w, B x) { PRE \
u8 xe = TI(x, elType); if (xe==el_B) goto bad; \
u8 we = TI(w, elType); if (we==el_B) goto bad; \
if (RNK(w)==RNK(x)) { if (!eqShape(w, x)) thrF("%U: Expected equal shape prefix (%H ≡ ≢𝕨, %H ≡ ≢𝕩)", swapped?CR:CN, swapped?x:w, swapped?w:x); \
if (we!=xe) { B tw=w,tx=x; \
we = aMakeEq(&tw, &tx, we, xe); \
if (we==el_MAX) goto bad; \
w=tw; x=tx; \
} \
AL(x) \
switch(we) { default: UD; \
case el_bit: { u64* wp=bitarr_ptr(w); u64* xp=bitarr_ptr(x); for(usz i=0;i<bia;i++) { u64 wv=wp[i]; u64 xv=xp[i]; rp[i]=BX; } break; } \
case el_i8 : { i8* wp=i8any_ptr (w); i8* xp=i8any_ptr (x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wp[i] OP xp[i]); break; } \
case el_i16: { i16* wp=i16any_ptr(w); i16* xp=i16any_ptr(x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wp[i] OP xp[i]); break; } \
case el_i32: { i32* wp=i32any_ptr(w); i32* xp=i32any_ptr(x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wp[i] OP xp[i]); break; } \
case el_c8 : { u8* wp=c8any_ptr (w); u8* xp=c8any_ptr (x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wp[i] OP xp[i]); break; } \
case el_c16: { u16* wp=c16any_ptr(w); u16* xp=c16any_ptr(x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wp[i] OP xp[i]); break; } \
case el_c32: { u32* wp=c32any_ptr(w); u32* xp=c32any_ptr(x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wp[i] OP xp[i]); break; } \
case el_f64: { f64* wp=f64any_ptr(w); f64* xp=f64any_ptr(x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wp[i] OP xp[i]); break; } \
} \
decG(w);decG(x); return r; \
} \
if (we!=xe) { B tw=w,tx=x; \
we = aMakeEq(&tw, &tx, we, xe); \
if (we==el_MAX) goto bad; \
w=tw; x=tx; \
} \
AL(x); \
if (ria) cmp_fns_##NAME##AA[we](rp, tyany_ptr(w), tyany_ptr(x), ria); \
decG(w);decG(x); return r; \
} \
bad: return NAME##_rec(swapped, w, x); \
}
CMP_AA("", "", le, <=, ~wv | xv, )
CMP_AA("<", ">", lt, < , ~wv & xv, )
CMP_AA("=", "?", eq, ==, ~wv^xv, swapped=0;)
CMP_AA("", "?", ne, !=, wv^xv, swapped=0;)
#define ge_AA(T, W, X) le_AA(!T, X, W)
#define gt_AA(T, W, X) lt_AA(!T, X, W)
CMP_AA("", "", ge, )
CMP_AA(">", "<", gt, )
CMP_AA("=", "?", eq, swapped=0;)
CMP_AA("", "?", ne, swapped=0;)
#define le_AA(T, W, X) ge_AA(!T, X, W)
#define lt_AA(T, W, X) gt_AA(!T, X, W)
#undef CMP_AA
#define CMP_SA(NAME, OP, BX, PRE) NOINLINE B NAME##_SA(i32 swapped, B w, B x) { PRE \
u8 xe = TI(x, elType); if (xe==el_B) goto bad; AL(x) \
switch(xe) { default: UD; \
case el_bit: { if (!q_bit(w)) break; u64 wv=bitx(w); u64* xp=bitarr_ptr(x); for(usz i=0;i<bia;i++) { u64 xv=xp[i]; rp[i]=BX; } decG(x); return r; } \
case el_i8: { if (!q_i8 (w)) break; i8 wv=o2iG(w); i8* xp=i8any_ptr (x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wv OP xp[i]); decG(x); return r; } \
case el_i16: { if (!q_i16(w)) break; i16 wv=o2iG(w); i16* xp=i16any_ptr(x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wv OP xp[i]); decG(x); return r; } \
case el_i32: { if (!q_i32(w)) break; i32 wv=o2iG(w); i32* xp=i32any_ptr(x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wv OP xp[i]); decG(x); return r; } \
case el_c8: { if (!q_c8 (w)) break; u8 wv=o2cG(w); u8* xp=c8any_ptr (x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wv OP xp[i]); decG(x); return r; } \
case el_c16: { if (!q_c16(w)) break; u16 wv=o2cG(w); u16* xp=c16any_ptr(x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wv OP xp[i]); decG(x); return r; } \
case el_c32: { if (!q_c32(w)) break; u32 wv=o2cG(w); u32* xp=c32any_ptr(x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wv OP xp[i]); decG(x); return r; } \
case el_f64: { if (!q_f64(w)) break; f64 wv=o2fG(w); f64* xp=f64any_ptr(x); for(usz i=0;i<ria;i++) bitp_set(rp,i,wv OP xp[i]); decG(x); return r; } \
} \
decG(r); \
#define CMP_SA(NAME, RNAME, PRE) B NAME##_SA(i32 swapped, B w, B x) { PRE \
u8 xe = TI(x, elType); if (xe==el_B) goto bad; \
AL(x); \
if (ria) cmp_fns_##RNAME##AS[xe](rp, tyany_ptr(x), w.u, ria); \
else dec(w); \
decG(x); return r; \
bad: return NAME##_rec(swapped, w, x); \
}
CMP_SA(eq, ==, ~wv^xv, swapped=0;)
CMP_SA(ne, !=, wv^xv, swapped=0;)
CMP_SA(le, <=, ~wv | xv, )
CMP_SA(ge, >=, wv | ~xv, )
CMP_SA(lt, < , ~wv & xv, )
CMP_SA(gt, > , wv & ~xv, )
CMP_SA(eq, eq, swapped=0;)
CMP_SA(ne, ne, swapped=0;)
CMP_SA(le, ge, )
CMP_SA(ge, le, )
CMP_SA(lt, gt, )
CMP_SA(gt, lt, )
#undef CMP_SA
#endif
#undef AL

View File

@ -2,82 +2,21 @@
#include "../../core.h"
#include "../../builtins.h"
static NOINLINE void fillBits(u64* dst, u64 sz, bool v) {
u64 x = 0-(u64)v;
u64 am = (sz+63)/64; assert(am>0);
for (usz i = 0; i < am; i++) dst[i] = x;
}
static NOINLINE void fillBitsDec(u64* dst, u64 sz, bool v, u64 x) {
dec(b(x));
fillBits(dst, sz, v);
}
extern bool please_tail_call_err;
static NOINLINE void cmp_err() { if (please_tail_call_err) thrM("Invalid comparison"); }
#define BCALL(N, X) N(b(X))
#define interp_f64(X) b(X).f
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-variable"
#include "../gen/cmp.c"
#pragma GCC diagnostic pop
typedef void (*CmpAAFn)(u64*, void*, void*, u64);
typedef void (*CmpASFn)(u64*, void*, u64, u64);
#define CMPFN(A,F,S,T) A##_##F##S##_##T
#define FN_LUT(A,F,S) static const Cmp##S##Fn lut_##A##_##F##S[] = {CMPFN(A,F,S,u1), CMPFN(A,F,S,i8), CMPFN(A,F,S,i16), CMPFN(A,F,S,i32), CMPFN(A,F,S,f64), CMPFN(A,F,S,u8), CMPFN(A,F,S,u16), CMPFN(A,F,S,u32)}
FN_LUT(cmp_fns, avx2, eq, AS); FN_LUT(cmp_fns, avx2, eq, AA);
FN_LUT(cmp_fns, avx2, ne, AS); FN_LUT(cmp_fns, avx2, ne, AA);
FN_LUT(cmp_fns, avx2, gt, AS); FN_LUT(cmp_fns, avx2, gt, AA);
FN_LUT(cmp_fns, avx2, ge, AS); FN_LUT(cmp_fns, avx2, ge, AA);
FN_LUT(cmp_fns, avx2, lt, AS);
FN_LUT(cmp_fns, avx2, le, AS);
FN_LUT(avx2, eq, AS); FN_LUT(avx2, eq, AA);
FN_LUT(avx2, ne, AS); FN_LUT(avx2, ne, AA);
FN_LUT(avx2, gt, AS); FN_LUT(avx2, gt, AA);
FN_LUT(avx2, ge, AS); FN_LUT(avx2, ge, AA);
FN_LUT(avx2, lt, AS);
FN_LUT(avx2, le, AS);
#undef FN_LUT
#define AL(X) u64* rp; B r = m_bitarrc(&rp, X); usz ria=IA(r)
#define CMP_AA(CN, CR, NAME, PRE) NOINLINE B NAME##_AA(i32 swapped, B w, B x) { PRE \
u8 xe = TI(x, elType); if (xe==el_B) goto bad; \
u8 we = TI(w, elType); if (we==el_B) goto bad; \
if (RNK(w)==RNK(x)) { if (!eqShape(w, x)) thrF("%U: Expected equal shape prefix (%H ≡ ≢𝕨, %H ≡ ≢𝕩)", swapped?CR:CN, swapped?x:w, swapped?w:x); \
if (we!=xe) { B tw=w,tx=x; \
we = aMakeEq(&tw, &tx, we, xe); \
if (we==el_MAX) goto bad; \
w=tw; x=tx; \
} \
AL(x); \
if (ria) lut_avx2_##NAME##AA[we](rp, tyany_ptr(w), tyany_ptr(x), ria); \
decG(w);decG(x); return r; \
} \
bad: return NAME##_rec(swapped, w, x); \
}
CMP_AA("", "", ge, )
CMP_AA(">", "<", gt, )
CMP_AA("=", "?", eq, swapped=0;)
CMP_AA("", "?", ne, swapped=0;)
#define le_AA(T, W, X) ge_AA(!T, X, W)
#define lt_AA(T, W, X) gt_AA(!T, X, W)
#undef CMP_AA
#define CMP_SA(NAME, RNAME, PRE) B NAME##_SA(i32 swapped, B w, B x) { PRE \
u8 xe = TI(x, elType); if (xe==el_B) goto bad; \
AL(x); \
if (ria) lut_avx2_##RNAME##AS[xe](rp, tyany_ptr(x), w.u, ria); \
else dec(w); \
decG(x); return r; \
bad: return NAME##_rec(swapped, w, x); \
}
CMP_SA(eq, eq, swapped=0;)
CMP_SA(ne, ne, swapped=0;)
CMP_SA(le, ge, )
CMP_SA(ge, le, )
CMP_SA(lt, gt, )
CMP_SA(gt, lt, )
#undef CMP_SA
#undef AL

View File

@ -19,7 +19,7 @@ arrs ← •internal.Squeeze¨ ⟨
65 (•MakeRand 2).Range 2
90 (•MakeRand 2).Range 2
atms 1¯10¯0¯∞nnn@'l''⍉''𝕩'{a1}
atms 1¯10¯0¯∞nnn@'l''⍉''𝕩'{a1}+{𝔽}
atms - 1020
atms - (¯0.9¯0.100.10.9+2-5) + 235
atms @+12865536+2-5