Merge pull request #135 from mlochbaum/scan
Optimize common high-rank scans
This commit is contained in:
commit
bb3bb1b1d3
@ -1,22 +1,31 @@
|
|||||||
// Scan (`)
|
// Scan (`)
|
||||||
// Empty 𝕩, and length 1 if no 𝕨: return 𝕩
|
// Empty 𝕩, and length 1 if no 𝕨: return 𝕩
|
||||||
// Generic operand:
|
// Generic argument:
|
||||||
// Constant: copy
|
// Constant: copy
|
||||||
// ⊢ identity, ⊣ reshape 𝕨 or first cell
|
// ⊢ identity, ⊣ reshape 𝕨 or first cell
|
||||||
// Boolean operand, rank 1:
|
// Boolean argument, stride 1:
|
||||||
// + AVX2 expansion (SHOULD have better generic, add SSE, NEON)
|
// + AVX2 expansion (SHOULD have better generic, add SSE, NEON)
|
||||||
// ∨⌈ ∧×⌊ search+copy, then memset (COULD vectorize search)
|
// ∨⌈ ∧×⌊ search+copy, then memset (COULD vectorize search)
|
||||||
// ≠ SWAR/SIMD shifts, CLMUL, VPCLMUL (SHOULD add NEON polynomial mul)
|
// ≠ SWAR/SIMD shifts, CLMUL, VPCLMUL (SHOULD add NEON polynomial mul)
|
||||||
// < SWAR
|
// < SWAR
|
||||||
// =≤≥>- in terms of ≠<∨∧+ with adjustments
|
// =≤≥>- in terms of ≠<∨∧+ with adjustments
|
||||||
// Arithmetic operand, rank 1:
|
// Numeric argument, stride 1:
|
||||||
// ⌈⌊ Scalar, SIMD in log(vector width) steps
|
// ⌈⌊ Scalar, SIMD in log(vector width) steps
|
||||||
// Check in 6-vector blocks to quickly write result if constant
|
// Check in 6-vector blocks to quickly write result if constant
|
||||||
// + Overflow-checked scalar or AVX2
|
// + Overflow-checked scalar or AVX2
|
||||||
// Ad-hoc boolean-valued handling for ≠∨
|
// Ad-hoc boolean-valued handling for ≠∨
|
||||||
// SHOULD extend rank 1 special cases to cell bound 1
|
// Higher-rank arithmetic:
|
||||||
// Higher-rank arithmetic, non-tiny cells: apply operand cell-wise
|
// Boolean ≠∨∧ and synonyms: SWAR; ⌊⌈+: SIMD with shuffle/permute
|
||||||
// SHOULD have dedicated high-rank scan optimizations
|
// Stride <word/vector: power-of-two (times stride) shifts
|
||||||
|
// COULD vectorize small-stride boolean scans
|
||||||
|
// ≠, divisor of 64: CLMUL
|
||||||
|
// Stride 1 to 2 words: result in register instead of re-reading
|
||||||
|
// Read and write at same alignment unless stride is large
|
||||||
|
// Large-stride cases auto-vectorize (except + overflow check)
|
||||||
|
// Overflow check for +, widen and retry on failure
|
||||||
|
// =` as ≠`⌾¬
|
||||||
|
// SHOULD optimize high-rank dyadic scan for recognized operands
|
||||||
|
// Other arithmetic, non-tiny cells: apply operand cell-wise
|
||||||
|
|
||||||
// Scan with rank (`˘ or `⎉k)
|
// Scan with rank (`˘ or `⎉k)
|
||||||
// SHOULD optimize dyadic scan with rank
|
// SHOULD optimize dyadic scan with rank
|
||||||
@ -70,7 +79,7 @@ B mul_c2(B, B, B);
|
|||||||
|
|
||||||
B scan_ne(B x, u64 p, u64 ia) { // consumes x
|
B scan_ne(B x, u64 p, u64 ia) { // consumes x
|
||||||
u64* xp = bitany_ptr(x);
|
u64* xp = bitany_ptr(x);
|
||||||
u64* rp; B r=m_bitarrv(&rp,ia);
|
u64* rp; B r=m_bitarrc(&rp,x);
|
||||||
#if SINGELI
|
#if SINGELI
|
||||||
si_scan_ne(p, xp, rp, BIT_N(ia));
|
si_scan_ne(p, xp, rp, BIT_N(ia));
|
||||||
#if USE_VALGRIND
|
#if USE_VALGRIND
|
||||||
@ -97,14 +106,14 @@ B scan_eq(B x, u64 ia) { // consumes x
|
|||||||
|
|
||||||
static B scan_or(B x, u64 ia) { // consumes x
|
static B scan_or(B x, u64 ia) { // consumes x
|
||||||
u64* xp = bitany_ptr(x);
|
u64* xp = bitany_ptr(x);
|
||||||
u64* rp; B r=m_bitarrv(&rp,ia);
|
u64* rp; B r=m_bitarrc(&rp,x);
|
||||||
usz n=BIT_N(ia); u64 xi; usz i=0;
|
usz n=BIT_N(ia); u64 xi; usz i=0;
|
||||||
while (i<n) if ((xi= vg_rand(xp[i]))!=0) { rp[i] = -(xi&-xi) ; i++; while(i<n) rp[i++] = ~0LL; break; } else rp[i++]=0;
|
while (i<n) if ((xi= vg_rand(xp[i]))!=0) { rp[i] = -(xi&-xi) ; i++; while(i<n) rp[i++] = ~0LL; break; } else rp[i++]=0;
|
||||||
decG(x); return FL_SET(r, fl_asc|fl_squoze);
|
decG(x); return FL_SET(r, fl_asc|fl_squoze);
|
||||||
}
|
}
|
||||||
static B scan_and(B x, u64 ia) { // consumes x
|
static B scan_and(B x, u64 ia) { // consumes x
|
||||||
u64* xp = bitany_ptr(x);
|
u64* xp = bitany_ptr(x);
|
||||||
u64* rp; B r=m_bitarrv(&rp,ia);
|
u64* rp; B r=m_bitarrc(&rp,x);
|
||||||
usz n=BIT_N(ia); u64 xi; usz i=0;
|
usz n=BIT_N(ia); u64 xi; usz i=0;
|
||||||
while (i<n) if ((xi=~vg_rand(xp[i]))!=0) { rp[i] = (xi&-xi)-1; i++; while(i<n) rp[i++] = 0 ; break; } else rp[i++]=~0LL;
|
while (i<n) if ((xi=~vg_rand(xp[i]))!=0) { rp[i] = (xi&-xi)-1; i++; while(i<n) rp[i++] = 0 ; break; } else rp[i++]=~0LL;
|
||||||
decG(x); return FL_SET(r, fl_dsc|fl_squoze);
|
decG(x); return FL_SET(r, fl_dsc|fl_squoze);
|
||||||
@ -130,7 +139,7 @@ B scan_add_bool(B x, u64 ia) { // consumes x
|
|||||||
decG(ones);
|
decG(ones);
|
||||||
r = mut_fv(r0);
|
r = mut_fv(r0);
|
||||||
} else {
|
} else {
|
||||||
void* rp = m_tyarrv(&r, elWidth(re), ia, el2t(re));
|
void* rp = m_tyarrc(&r, elWidth(re), x, el2t(re));
|
||||||
#define SUM_BITWISE(T) { T c=0; for (usz i=0; i<ia; i++) { c+= bitp_get(xp,i); ((T*)rp)[i]=c; } }
|
#define SUM_BITWISE(T) { T c=0; for (usz i=0; i<ia; i++) { c+= bitp_get(xp,i); ((T*)rp)[i]=c; } }
|
||||||
#if SINGELI
|
#if SINGELI
|
||||||
#define SUM(W,T) si_bcs##W(xp, rp, ia);
|
#define SUM(W,T) si_bcs##W(xp, rp, ia);
|
||||||
@ -156,7 +165,7 @@ B scan_add_bool(B x, u64 ia) { // consumes x
|
|||||||
#define MINMAX_SCAN(T,NAME,C,I) T c=I; for (usz i=0; i<ia; i++) { if (xp[i] C c)c=xp[i]; rp[i]=c; }
|
#define MINMAX_SCAN(T,NAME,C,I) T c=I; for (usz i=0; i<ia; i++) { if (xp[i] C c)c=xp[i]; rp[i]=c; }
|
||||||
#endif
|
#endif
|
||||||
#define MM_CASE(T,N,C,I) \
|
#define MM_CASE(T,N,C,I) \
|
||||||
case el_##T : { T* xp=T##any_ptr(x); T* rp; r=m_##T##arrv(&rp, ia); MINMAX_SCAN(T,N,C,I); break; }
|
case el_##T : { T* xp=T##any_ptr(x); T* rp; r=m_##T##arrc(&rp, x); MINMAX_SCAN(T,N,C,I); break; }
|
||||||
#define MINMAX(NAME,C,INIT,BIT,ORD) \
|
#define MINMAX(NAME,C,INIT,BIT,ORD) \
|
||||||
B r; switch (xe) { default:UD; \
|
B r; switch (xe) { default:UD; \
|
||||||
case el_bit: return scan_##BIT(x, ia); \
|
case el_bit: return scan_##BIT(x, ia); \
|
||||||
@ -174,7 +183,7 @@ B scan_max_num(B x, u8 xe, u64 ia) { MINMAX(max,>,MIN,or ,asc) }
|
|||||||
#define MM2_ICASE(T,N,C,I) \
|
#define MM2_ICASE(T,N,C,I) \
|
||||||
case el_##T : { \
|
case el_##T : { \
|
||||||
if (wv!=(T)wv) { if (wv C 0) { r=C2(shape,m_f64(ia),w); break; } else wv=I; } \
|
if (wv!=(T)wv) { if (wv C 0) { r=C2(shape,m_f64(ia),w); break; } else wv=I; } \
|
||||||
T* xp=T##any_ptr(x); T* rp; r=m_##T##arrv(&rp, ia); MINMAX_SCAN(T,N,C,wv); \
|
T* xp=T##any_ptr(x); T* rp; r=m_##T##arrc(&rp, x); MINMAX_SCAN(T,N,C,wv); \
|
||||||
break; }
|
break; }
|
||||||
#define MINMAX2(NAME,C,INIT,BIT,BI,ORD) \
|
#define MINMAX2(NAME,C,INIT,BIT,BI,ORD) \
|
||||||
i32 wv=0; if (q_i32(w)) { wv=o2fG(w); } else { x=taga(cpyF64Arr(x)); xe=el_f64; } \
|
i32 wv=0; if (q_i32(w)) { wv=o2fG(w); } else { x=taga(cpyF64Arr(x)); xe=el_f64; } \
|
||||||
@ -195,7 +204,7 @@ SHOULD_INLINE B scan2_max_num(B w, B x, u8 xe, usz ia) { MINMAX2(max,>,MIN,or ,0
|
|||||||
|
|
||||||
static B scan_lt(B x, u64 p, usz ia) {
|
static B scan_lt(B x, u64 p, usz ia) {
|
||||||
u64* xp = bitany_ptr(x);
|
u64* xp = bitany_ptr(x);
|
||||||
u64* rp; B r=m_bitarrv(&rp,ia); usz n=BIT_N(ia);
|
u64* rp; B r=m_bitarrc(&rp,x); usz n=BIT_N(ia);
|
||||||
u64 m = 0x5555555555555555;
|
u64 m = 0x5555555555555555;
|
||||||
for (usz i=0; i<n; i++) {
|
for (usz i=0; i<n; i++) {
|
||||||
u64 x = xp[i];
|
u64 x = xp[i];
|
||||||
@ -208,7 +217,7 @@ static B scan_lt(B x, u64 p, usz ia) {
|
|||||||
|
|
||||||
static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
|
static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
|
||||||
assert(xe!=el_bit && elNum(xe));
|
assert(xe!=el_bit && elNum(xe));
|
||||||
B r; void* rp = m_tyarrv(&r, xe==el_f64? sizeof(f64) : sizeof(i32), ia, xe==el_f64? t_f64arr : t_i32arr);
|
B r; void* rp = m_tyarrc(&r, xe==el_f64? sizeof(f64) : sizeof(i32), x, xe==el_f64? t_f64arr : t_i32arr);
|
||||||
#if SINGELI
|
#if SINGELI
|
||||||
switch(xe) { default:UD;
|
switch(xe) { default:UD;
|
||||||
case el_i8: { if (!q_fi32(r0) || si_scan_plus_i8_i32 (i8any_ptr(x), r0, rp, ia)!=ia) goto cs_i8_f64; decG(x); return r; }
|
case el_i8: { if (!q_fi32(r0) || si_scan_plus_i8_i32 (i8any_ptr(x), r0, rp, ia)!=ia) goto cs_i8_f64; decG(x); return r; }
|
||||||
@ -217,8 +226,8 @@ static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
|
|||||||
case el_f64: { f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
|
case el_f64: { f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
|
||||||
}
|
}
|
||||||
cs_i8_f64: { x=taga(cpyI16Arr(x)); goto cs_i16_f64; }
|
cs_i8_f64: { x=taga(cpyI16Arr(x)); goto cs_i16_f64; }
|
||||||
cs_i16_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); si_scan_plus_i16_f64(i16any_ptr(x), r0, rp, ia); decG(x); return r; }
|
cs_i16_f64: { decG(r); f64* rp; r = m_f64arrc(&rp, x); si_scan_plus_i16_f64(i16any_ptr(x), r0, rp, ia); decG(x); return r; }
|
||||||
cs_i32_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); si_scan_plus_i32_f64(i32any_ptr(x), r0, rp, ia); decG(x); return r; }
|
cs_i32_f64: { decG(r); f64* rp; r = m_f64arrc(&rp, x); si_scan_plus_i32_f64(i32any_ptr(x), r0, rp, ia); decG(x); return r; }
|
||||||
#else
|
#else
|
||||||
if (xe==el_i8 && q_fi32(r0)) { i8* xp=i8any_ptr (x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
|
if (xe==el_i8 && q_fi32(r0)) { i8* xp=i8any_ptr (x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
|
||||||
if (xe==el_i16 && q_fi32(r0)) { i16* xp=i16any_ptr(x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
|
if (xe==el_i16 && q_fi32(r0)) { i16* xp=i16any_ptr(x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
|
||||||
@ -226,7 +235,7 @@ static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
|
|||||||
if (xe==el_f64) { res_float:; f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
|
if (xe==el_f64) { res_float:; f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
|
||||||
base:;
|
base:;
|
||||||
decG(r);
|
decG(r);
|
||||||
f64* rp2; r = m_f64arrv(&rp2, ia); rp = rp2;
|
f64* rp2; r = m_f64arrc(&rp2, x); rp = rp2;
|
||||||
x = toF64Any(x);
|
x = toF64Any(x);
|
||||||
goto res_float;
|
goto res_float;
|
||||||
#endif
|
#endif
|
||||||
@ -234,10 +243,9 @@ static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
|
|||||||
|
|
||||||
extern B scan_arith(B f, B w, B x, usz* xsh); // from cells.c
|
extern B scan_arith(B f, B w, B x, usz* xsh); // from cells.c
|
||||||
B scan_c1(Md1D* d, B x) { B f = d->f;
|
B scan_c1(Md1D* d, B x) { B f = d->f;
|
||||||
if (isAtm(x) || RNK(x)==0) thrM("𝔽`𝕩: 𝕩 cannot have rank 0");
|
if (isAtm(x)) { unit: thrM("𝔽`𝕩: 𝕩 cannot have rank 0"); }
|
||||||
ur xr = RNK(x);
|
usz ia = IA(x); if (ia <= 1) { if (ia==1 && RNK(x)==0) goto unit; return x; }
|
||||||
usz ia = IA(x);
|
usz n = *SH(x); if (n <= 1) return x;
|
||||||
if (*SH(x)<=1 || ia==0) return x;
|
|
||||||
if (RARE(!isFun(f))) {
|
if (RARE(!isFun(f))) {
|
||||||
if (isMd(f)) thrM("Calling a modifier");
|
if (isMd(f)) thrM("Calling a modifier");
|
||||||
B xf = getFillR(x);
|
B xf = getFillR(x);
|
||||||
@ -257,7 +265,50 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
|
|||||||
Arr* r = TI(x,slice)(x, 0, csz);
|
Arr* r = TI(x,slice)(x, 0, csz);
|
||||||
return C2(shape, s, taga(r));
|
return C2(shape, s, taga(r));
|
||||||
}
|
}
|
||||||
if (!(xr==1 && xe<=el_f64)) goto base;
|
if (xe > el_f64) goto base;
|
||||||
|
if (ia != n) { // csz != 1
|
||||||
|
#if SINGELI
|
||||||
|
usz csz = arr_csz(x);
|
||||||
|
i8 t = -1; bool neg = 0;
|
||||||
|
if (xe==el_bit) switch (rtid) {
|
||||||
|
CASE_N_OR: t=0; break;
|
||||||
|
CASE_N_AND: t=1; break;
|
||||||
|
case n_eq: neg=1; case n_ne: t=2; break;
|
||||||
|
}
|
||||||
|
if (t != -1) {
|
||||||
|
if (neg) x = bit_negate(x);
|
||||||
|
u64* rp; B r=m_bitarrc(&rp,x);
|
||||||
|
si_scan_bool_stride[t](bitany_ptr(x), rp, ia, csz);
|
||||||
|
if (neg) r = bit_negate(r);
|
||||||
|
decG(x); return r;
|
||||||
|
}
|
||||||
|
if (rtid==n_floor | rtid==n_ceil) {
|
||||||
|
// boolean was handled as CASE_N_AND
|
||||||
|
B r; void* rp = m_tyarrc(&r, elWidth(xe), x, el2t(xe));
|
||||||
|
void* xp = tyany_ptr(x);
|
||||||
|
si_scan_stride_minmax[4*(rtid==n_ceil) + xe-el_i8](xp, rp, ia, csz);
|
||||||
|
decG(x); return r;
|
||||||
|
}
|
||||||
|
if (rtid==n_add) {
|
||||||
|
if (xe==el_bit) { x = taga(cpyI8Arr(x)); xe=el_i8; }
|
||||||
|
restart:;
|
||||||
|
B r; void* rp = m_tyarrc(&r, elWidth(xe), x, el2t(xe));
|
||||||
|
void* xp = tyany_ptr(x);
|
||||||
|
bool done = si_scan_stride_add[xe-el_i8](xp, rp, ia, csz);
|
||||||
|
if (!done) {
|
||||||
|
decG(r);
|
||||||
|
switch (++xe) { default: UD;
|
||||||
|
case el_i16: x = taga(cpyI16Arr(x)); break;
|
||||||
|
case el_i32: x = taga(cpyI32Arr(x)); break;
|
||||||
|
case el_f64: x = taga(cpyF64Arr(x)); break;
|
||||||
|
}
|
||||||
|
goto restart;
|
||||||
|
}
|
||||||
|
decG(x); return r;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
goto base;
|
||||||
|
}
|
||||||
|
|
||||||
if (xe==el_bit) switch (rtid) { default: goto base;
|
if (xe==el_bit) switch (rtid) { default: goto base;
|
||||||
case n_add: return scan_add_bool(x, ia); // +
|
case n_add: return scan_add_bool(x, ia); // +
|
||||||
@ -278,7 +329,7 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
|
|||||||
if (!elInt(xe)) goto base;
|
if (!elInt(xe)) goto base;
|
||||||
f64 x0 = o2fG(IGetU(x,0));
|
f64 x0 = o2fG(IGetU(x,0));
|
||||||
if (!q_fbit(x0)) goto base;
|
if (!q_fbit(x0)) goto base;
|
||||||
u64* rp; B r = m_bitarrv(&rp, ia);
|
u64* rp; B r = m_bitarrc(&rp, x);
|
||||||
bool c = x0;
|
bool c = x0;
|
||||||
rp[0] = c;
|
rp[0] = c;
|
||||||
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=1; i<ia; i++) { c = c!=xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
|
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=1; i<ia; i++) { c = c!=xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
|
||||||
@ -289,7 +340,7 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
|
|||||||
if (rtid==n_or) { x=num_squeezeChk(x); xe=TI(x,elType); if (xe==el_bit) return scan_or(x, ia); }
|
if (rtid==n_or) { x=num_squeezeChk(x); xe=TI(x,elType); if (xe==el_bit) return scan_or(x, ia); }
|
||||||
}
|
}
|
||||||
base:;
|
base:;
|
||||||
if (xr>1 && ia >= 6 * (u64)*SH(x) && isPervasiveDy(f)) return scan_arith(f, m_f64(0), x, SH(x));
|
if (ia!=n && ia >= 6 * (u64)n && isPervasiveDy(f)) return scan_arith(f, m_f64(0), x, SH(x));
|
||||||
SLOW2("𝕎` 𝕩", f, x);
|
SLOW2("𝕎` 𝕩", f, x);
|
||||||
B xf = getFillR(x);
|
B xf = getFillR(x);
|
||||||
|
|
||||||
@ -297,7 +348,7 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
|
|||||||
SGet(x)
|
SGet(x)
|
||||||
FC2 fc2 = c2fn(f);
|
FC2 fc2 = c2fn(f);
|
||||||
|
|
||||||
if (xr==1) {
|
if (ia == n) {
|
||||||
r.a[0] = Get(x,0);
|
r.a[0] = Get(x,0);
|
||||||
for (usz i=1; i<ia; i++) r.a[i] = fc2(f, inc(r.a[i-1]), Get(x,i));
|
for (usz i=1; i<ia; i++) r.a[i] = fc2(f, inc(r.a[i-1]), Get(x,i));
|
||||||
} else {
|
} else {
|
||||||
@ -327,7 +378,8 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f;
|
|||||||
u8 rtid = RTID(f);
|
u8 rtid = RTID(f);
|
||||||
if (rtid==n_rtack) { dec(w); return x; }
|
if (rtid==n_rtack) { dec(w); return x; }
|
||||||
if (rtid==n_ltack) return C2(shape, C1(fne, x), w);
|
if (rtid==n_ltack) return C2(shape, C1(fne, x), w);
|
||||||
if (!(xr==1 && elNum(xe) && xe<=el_f64)) goto base;
|
if (!(elNum(xe) && xe<=el_f64)) goto base;
|
||||||
|
if (xr!=1 && *SH(x)!=ia) goto base;
|
||||||
if (!isF64(w)) goto base;
|
if (!isF64(w)) goto base;
|
||||||
|
|
||||||
if (rtid==n_floor) return scan2_min_num(w, x, xe, ia); // ⌊
|
if (rtid==n_floor) return scan2_min_num(w, x, xe, ia); // ⌊
|
||||||
@ -350,7 +402,7 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f;
|
|||||||
if (xe==el_bit) return scan_ne(x, -(u64)(wBit? o2bG(w) : 1&~*bitany_ptr(x)), ia);
|
if (xe==el_bit) return scan_ne(x, -(u64)(wBit? o2bG(w) : 1&~*bitany_ptr(x)), ia);
|
||||||
if (!wBit || !elInt(xe)) goto base;
|
if (!wBit || !elInt(xe)) goto base;
|
||||||
bool c = o2bG(w);
|
bool c = o2bG(w);
|
||||||
u64* rp; B r = m_bitarrv(&rp, ia);
|
u64* rp; B r = m_bitarrc(&rp, x);
|
||||||
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
|
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
|
||||||
if (xe==el_i16) { i16* xp=i16any_ptr(x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
|
if (xe==el_i16) { i16* xp=i16any_ptr(x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
|
||||||
if (xe==el_i32) { i32* xp=i32any_ptr(x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
|
if (xe==el_i32) { i32* xp=i32any_ptr(x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
|
||||||
|
|||||||
@ -96,7 +96,7 @@ local def ml_exec{i, iter, vars0, bulk, M} = {
|
|||||||
|
|
||||||
# i0 - initial batch index; not used as begin because it's in a different scale compared to end
|
# i0 - initial batch index; not used as begin because it's in a different scale compared to end
|
||||||
def for_masked{bulk, i0}{vars,begin==0,end,iter} = {
|
def for_masked{bulk, i0}{vars,begin==0,end,iter} = {
|
||||||
l:u64 = end
|
l:u64 = promote{u64, end}
|
||||||
|
|
||||||
m:u64 = l / bulk
|
m:u64 = l / bulk
|
||||||
@for (i from i0 to m) ml_exec{i, iter, vars, bulk, mask_none}
|
@for (i from i0 to m) ml_exec{i, iter, vars, bulk, mask_none}
|
||||||
@ -128,7 +128,7 @@ def for_masked_pos{bulk}{vars,begin==0,end:L,iter} = {
|
|||||||
# end is scalar element count
|
# end is scalar element count
|
||||||
# index given is a tuple of batch indexes to process
|
# index given is a tuple of batch indexes to process
|
||||||
def for_mu{bulk, unr, fromunr}{vars,begin==0,end,iter} = {
|
def for_mu{bulk, unr, fromunr}{vars,begin==0,end,iter} = {
|
||||||
l:u64 = end
|
l:u64 = promote{u64, end}
|
||||||
|
|
||||||
m:u64 = l / bulk
|
m:u64 = l / bulk
|
||||||
if (unr==1) {
|
if (unr==1) {
|
||||||
|
|||||||
@ -98,6 +98,118 @@ def shift_first{c:V=[l]_, p:V} = {
|
|||||||
else blend_first{c, rotate_right{p}}
|
else blend_first{c, rotate_right{p}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Strided scans
|
||||||
|
fn scan_stride_assoc{op, T, Ret, check_over}(xv:*void, rv:*void, ia:usz, l:usz) : Ret = {
|
||||||
|
def minvalue{(f64)} = -1/0; def maxvalue{(f64)} = 1/0
|
||||||
|
def id = match (op) {
|
||||||
|
{(min)} => maxvalue; {(max)} => minvalue
|
||||||
|
{(+)} => ({_}=>0)
|
||||||
|
}
|
||||||
|
x:= *T~~xv; r:= *T~~rv
|
||||||
|
# Architecture determination
|
||||||
|
# Use largest vector width with a full-width shuffle
|
||||||
|
def has_shuf = hasarch{'SSSE3'} or hasarch{'AARCH64'}
|
||||||
|
def I = if (hasarch{'AVX2'} and T>=i32) [8]i32 else [16]i8
|
||||||
|
def [il]IE = I; def selI = shuf{IE, ...}
|
||||||
|
def wT = width{T}
|
||||||
|
def f = wT/width{IE}
|
||||||
|
def vl = width{I}/wT
|
||||||
|
def V = [vl]T
|
||||||
|
if (has_shuf and l < vl) {
|
||||||
|
# Small stride: power-of-two shifts
|
||||||
|
def small{k} = {
|
||||||
|
iv:= iota{I}; j:= I**cast_i{IE,l*f}
|
||||||
|
spr:= I**il - j + iv
|
||||||
|
def inds = @collect (k) {
|
||||||
|
v:= iv - (j &~ I~~(iv<j))
|
||||||
|
spr = selI{spr, v}
|
||||||
|
js:= j; j+= j
|
||||||
|
if (not same{op, +}) selI{., v}
|
||||||
|
else if (same{IE,i8}) selI{., iv - js}
|
||||||
|
else { m:= V~~(iv >= js); {x} => selI{x, v} & m }
|
||||||
|
}
|
||||||
|
c:= V**id{T}
|
||||||
|
@for_masked{vl} (x in tup{V, x}, r in tup{V, r}, M in 'm' over ia) {
|
||||||
|
xs:= fold{{v, i} => op{i{v}, v}, x, inds}
|
||||||
|
r = c = op{shuf{IE, c, spr}, xs}
|
||||||
|
check_over{M, x, r} # For +, infers other argument as r-x
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (not (same{op,+} and V==[4]f64)) {
|
||||||
|
def max_k = lb{vl/2} # Divide by two from assuming l≥2
|
||||||
|
if (max_k<3 or l<4) small{max_k} else small{max_k-1} # l=2 and l=3 are the only cases needing the full max_k iterations; max_k<3 limits specialization to where it's significant
|
||||||
|
} else { # Non-associative!
|
||||||
|
c:= V**0
|
||||||
|
if (l==2) {
|
||||||
|
@for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) {
|
||||||
|
a:= c + shuf{x, 0,1,0,1}
|
||||||
|
c = a + shuf{x, 2,3,2,3}
|
||||||
|
r = blend{a, c, 0,0,1,1}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert{l==3}
|
||||||
|
@for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) {
|
||||||
|
a:= shuf{c, 1,1,2,3} + blend{x, V**0, 0,1,1,1}
|
||||||
|
r = c = x + shuf{a, 1,2,3,0}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
# Large stride: single shift, with saved register or memory
|
||||||
|
def op_chk{M, p, x} = { r:= op{p, x}; check_over{M, p, x, r}; r }
|
||||||
|
@for (r, x over l) r = x
|
||||||
|
if (has_shuf and l<256/(wT/8)) {
|
||||||
|
# Make sure to load the previous row data at the same alignment to not hit bad store-to-load forwarding
|
||||||
|
def [il]IE = I
|
||||||
|
q:= l%vl; fq:= cast_i{IE, q*f}
|
||||||
|
def rot = shuf{IE, ., (iota{I} - I**fq) & I**(il-1)}
|
||||||
|
bv:= iota{I} >= I**fq; def bl = blend_hom{..., bv}
|
||||||
|
c:= V**id{T}
|
||||||
|
o:= l - q
|
||||||
|
if (l == 2*vl) { o = vl; bv = ~bv }
|
||||||
|
if (o == vl) {
|
||||||
|
p:= load{*V~~x}; store{*V~~r, 0, p}
|
||||||
|
@for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o}, M in 'm' over ia-o) {
|
||||||
|
p = rot{p}
|
||||||
|
r = op_chk{M, bl{c, p}, x}
|
||||||
|
c = p; p = r
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
@for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o}, p in tup{V, r}, M in 'm' over ia-o) {
|
||||||
|
q:= rot{p}
|
||||||
|
r = op_chk{M, bl{c, q}, x}
|
||||||
|
c = q
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (same{op, +} and T<=i32 and has_simd and (has_shuf or l>=vl)) {
|
||||||
|
def vl = arch_defvw/wT; def V = [vl]T
|
||||||
|
@for_masked{vl} (x in tup{V, x+l}, r in tup{V, r+l}, p in tup{V, r}, M in 'm' over ia-l) {
|
||||||
|
r = op_chk{M, p, x}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
@for (r, x, p in r-l over _ from l to ia) r = op_chk{0, p, x}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
1
|
||||||
|
}
|
||||||
|
def scan_stride_assoc{op, T} = scan_stride_assoc{op, T, void, {..._}=>{}}
|
||||||
|
def check_add_over{_, w:T, x:T, r:T} = { if ((w^r) & (x^r) < 0) return{0} }
|
||||||
|
def check_add_over{M, w:V=[_]E, x:V, r:V} = {
|
||||||
|
o:= (if (not hasarch{'X86_64'} or width{E}<=16) any_hom{M, subs{r,w} != x}
|
||||||
|
else any_top{M, (w^r) & (x^r)})
|
||||||
|
if (o) return{0}
|
||||||
|
}
|
||||||
|
def check_add_over{M, x, r} = check_add_over{M, r-x, x, r}
|
||||||
|
export_tab{'si_scan_stride_minmax',
|
||||||
|
flat_table{scan_stride_assoc, tup{min,max}, tup{i8,i16,i32,f64}}
|
||||||
|
}
|
||||||
|
export_tab{'si_scan_stride_add', tup{
|
||||||
|
...each{scan_stride_assoc{+, ., u1, check_add_over}, tup{i8,i16,i32}},
|
||||||
|
scan_stride_assoc{+, f64, u1, {..._}=>{}}
|
||||||
|
}}
|
||||||
|
|
||||||
|
|
||||||
# xor scan
|
# xor scan
|
||||||
def vec_prefix_byshift{op, sh} = {
|
def vec_prefix_byshift{op, sh} = {
|
||||||
def pre{v:V, k} = if (k < elwidth{V}) pre{op{v, sh{v,k}}, 2*k} else v
|
def pre{v:V, k} = if (k < elwidth{V}) pre{op{v, sh{v,k}}, 2*k} else v
|
||||||
@ -106,13 +218,13 @@ def vec_prefix_byshift{op, sh} = {
|
|||||||
def scan_word_ne = prefix_byshift{^, <<}
|
def scan_word_ne = prefix_byshift{^, <<}
|
||||||
def scan_words_ne = vec_prefix_byshift{^, <<}
|
def scan_words_ne = vec_prefix_byshift{^, <<}
|
||||||
|
|
||||||
fn scan_neq{}(c:u64, x:*u64, r:*u64, nw:u64) : void = {
|
fn scan_neq{}(c:u64, x:*u64, r:*u64, nw:usz) : void = {
|
||||||
@for (x, r over nw) {
|
@for (x, r over nw) {
|
||||||
r = c ^ scan_word_ne{x}
|
r = c ^ scan_word_ne{x}
|
||||||
c = -(r>>63) # repeat sign bit
|
c = -(r>>63) # repeat sign bit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:u64) : void = {
|
fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:usz) : void = {
|
||||||
def vl = arch_defvw / 64
|
def vl = arch_defvw / 64
|
||||||
def V = [vl]u64
|
def V = [vl]u64
|
||||||
c := V**c0
|
c := V**c0
|
||||||
@ -123,7 +235,7 @@ fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:u64) : void = {
|
|||||||
c = broadcast_last{p}
|
c = broadcast_last{p}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = {
|
fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:usz, mark:u64) : void = {
|
||||||
def V = [2]u64
|
def V = [2]u64
|
||||||
m := V**mark
|
m := V**mark
|
||||||
def xor64{a, i, carry} = { # carry is 64-bit broadcasted current total
|
def xor64{a, i, carry} = { # carry is 64-bit broadcasted current total
|
||||||
@ -144,10 +256,10 @@ fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64
|
|||||||
store{*u64~~(rv+e), clmul{load{V, *u64~~(xv+e), 1}, m, 0} ^ c, 1}
|
store{*u64~~(rv+e), clmul{load{V, *u64~~(xv+e), 1}, m, 0} ^ c, 1}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fn scan_neq{if hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = {
|
fn scan_neq{if hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:usz) : void = {
|
||||||
clmul_scan_ne_any{}(*void~~x, *void~~r, init, nw, -(u64~~1))
|
clmul_scan_ne_any{}(*void~~x, *void~~r, init, nw, -(u64~~1))
|
||||||
}
|
}
|
||||||
fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = {
|
fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u64, nw:usz) : void = {
|
||||||
def V = [8]u64
|
def V = [8]u64
|
||||||
def sse{a} = make{[2]u64, a, 0}
|
def sse{a} = make{[2]u64, a, 0}
|
||||||
carry := sse{init}
|
carry := sse{init}
|
||||||
@ -358,10 +470,11 @@ def loose_mask_gen{V=[vl]T, l} = { # Slow, for ≠` only
|
|||||||
}
|
}
|
||||||
def has_vecshift = hasarch{'AVX2'} or hasarch{'AARCH64'}
|
def has_vecshift = hasarch{'AVX2'} or hasarch{'AARCH64'}
|
||||||
def loose_mask_gen{V=[vl](u64), l if has_vecshift} = {
|
def loose_mask_gen{V=[vl](u64), l if has_vecshift} = {
|
||||||
|
l64 := promote{u64, l}
|
||||||
q := -make{V, 64*iota{vl}} # distance to next row boundary
|
q := -make{V, 64*iota{vl}} # distance to next row boundary
|
||||||
def q_mod{} = { q+= V**l & -(q>>63) }
|
def q_mod{} = { q+= V**l64 & -(q>>63) }
|
||||||
def q_mod{if hasarch{'SSE4.1'}} = { q = blend_top{q,q+V**l, q} }
|
def q_mod{if hasarch{'SSE4.1'}} = { q = blend_top{q,q+V**l64, q} }
|
||||||
o:u64 = width{V}; while (o>l) { o-=l; q_mod{} }
|
o:u64 = width{V}; while (o>l64) { o-=l64; q_mod{} }
|
||||||
{} => {
|
{} => {
|
||||||
m:= V**1 << q; if (not hasarch{'AVX2'}) m&= q < V**64
|
m:= V**1 << q; if (not hasarch{'AVX2'}) m&= q < V**64
|
||||||
q-= V**o; q_mod{}
|
q-= V**o; q_mod{}
|
||||||
@ -560,7 +673,7 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = {
|
|||||||
c:u64 = 0 # carry
|
c:u64 = 0 # carry
|
||||||
while (1) {
|
while (1) {
|
||||||
i+= l; ii := iw; iw = cdiv{i, 64}
|
i+= l; ii := iw; iw = cdiv{i, 64}
|
||||||
scan_neq{}(c, x+ii, r+ii, promote{u64,iw-ii})
|
scan_neq{}(c, x+ii, r+ii, iw-ii)
|
||||||
if (i == nl) return{}
|
if (i == nl) return{}
|
||||||
s:= load{r, iw-1}
|
s:= load{r, iw-1}
|
||||||
q := i%64
|
q := i%64
|
||||||
@ -619,3 +732,49 @@ export{'si_scan_rows_and', scan_rows_andor{0}}
|
|||||||
export{'si_scan_rows_or', scan_rows_andor{1}}
|
export{'si_scan_rows_or', scan_rows_andor{1}}
|
||||||
export{'si_scan_rows_ne', scan_rows_neq}
|
export{'si_scan_rows_ne', scan_rows_neq}
|
||||||
export{'si_scan_rows_ltack', scan_rows_left}
|
export{'si_scan_rows_ltack', scan_rows_left}
|
||||||
|
|
||||||
|
|
||||||
|
# Strided boolean scans
|
||||||
|
fn scan_stride_bool_assoc{op}(x:*u64, r:*u64, nl:usz, l:usz) : void = {
|
||||||
|
assert{l > 1}
|
||||||
|
def {flip,opf} = if (same{op, &}) tup{~,|} else tup{{x}=>x,op} # such that identity of opf is 0
|
||||||
|
nw:= cdiv{nl, 64}
|
||||||
|
if (l <= 64) {
|
||||||
|
if (same{op, ^} and hasarch{'PCLMUL'} and (l & (l-1)) == 0) {
|
||||||
|
clmul_scan_ne_any{}(*void~~x, *void~~r, 0, nw, aligned_spaced_mask{l})
|
||||||
|
return{}
|
||||||
|
}
|
||||||
|
c:u64 = 0 # carry l bits, no matter the alignment
|
||||||
|
@for (r, x over nw) {
|
||||||
|
c = opf{flip{x}, c >> (64-l)}
|
||||||
|
s:= l; while (s < 64) { c = opf{c, c<<s}; s += s }
|
||||||
|
r = flip{c}
|
||||||
|
}
|
||||||
|
} else if (l < 128) {
|
||||||
|
q:= l%64
|
||||||
|
c:= flip{u64~~0}
|
||||||
|
p:= load{x,0}
|
||||||
|
store{r,0, p}
|
||||||
|
@for (r, x over _ from 1 to nw) {
|
||||||
|
r = op{x, c>>(64-q) | p<<q}
|
||||||
|
c = p; p = r
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fw:= cdiv{l, 64} # words in first cell
|
||||||
|
@for (r, x over fw) r = x
|
||||||
|
if (l%64 == 0) {
|
||||||
|
@for (r, x, p in r-fw over _ from fw to nw) r = op{x, p}
|
||||||
|
} else {
|
||||||
|
q:= l%64
|
||||||
|
c:= flip{u64~~0}
|
||||||
|
@for (r, x, p in r-(fw-1) over _ from fw-1 to nw) {
|
||||||
|
r = op{x, c>>(64-q) | p<<q}
|
||||||
|
c = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
export_tab{
|
||||||
|
'si_scan_bool_stride',
|
||||||
|
each{scan_stride_bool_assoc, tup{|, &, ^}}
|
||||||
|
}
|
||||||
|
|||||||
2
test/cases/fuzz/scan.bqn
Normal file
2
test/cases/fuzz/scan.bqn
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# tests scan on a matrix checking for overflow in tail elements
|
||||||
|
8‿16‿32 {e𝕊n‿w: m←2⋆e-2 ⋄ (m⊸+⌾(⟨¯1,w|n⟩⊸⊑) ⌽‿w⥊(n+w+1)⥊m) ≡ +` •internal.Squeeze ↑‿w⥊ (w⥊m)∾(n⥊0)∾m}⌜ 1+↕200‿65
|
||||||
@ -498,6 +498,9 @@ a←↕2 ⋄ ! "e" ≡ (↕10){b←a‿a‿a‿a‿a‿a‿a‿a‿a‿a ⋄
|
|||||||
!"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (2‿2 ≡ ≢𝕨, 3‿2‿3 ≡ ≢𝕩)" % (2‿2⥊1)+`↕3‿2‿3
|
!"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (2‿2 ≡ ≢𝕨, 3‿2‿3 ≡ ≢𝕩)" % (2‿2⥊1)+`↕3‿2‿3
|
||||||
!"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (2‿2 ≡ ≢𝕨, 3‿3‿2 ≡ ≢𝕩)" % (2‿2⥊1)+`↕3‿3‿2
|
!"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (2‿2 ≡ ≢𝕨, 3‿3‿2 ≡ ≢𝕩)" % (2‿2⥊1)+`↕3‿3‿2
|
||||||
!"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (⟨⟩ ≡ ≢𝕨, 3‿3 ≡ ≢𝕩)" % 2+`↕3‿3
|
!"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (⟨⟩ ≡ ≢𝕨, 3‿3 ≡ ≢𝕩)" % 2+`↕3‿3
|
||||||
|
%USE eqvar ⋄ %USE k ⋄ ⟨+,-,×,÷,⋆,√,⌊,⌈,¬,∧,∨,<,>,≠,=,≤,≥,⊣,⊢⟩ {f𝕊arr: ! ( F _k`_k ⥊arr) ≡ ⥊_k F` _eqvar arr}⌜ {𝕩∾≍˘¨𝕩} ≍˘¨ ⟨⋈1, 1‿0‿1⟩
|
||||||
|
%USE eqvar ⋄ %USE k ⋄ ⟨+,-,×,÷,⋆,√,⌊,⌈,¬,∧,∨,<,>,≠,=,≤,≥,⊣,⊢⟩ {f𝕊arr: ! (1 F _k`_k ⥊arr) ≡ ⥊_k (1¨⊏𝕩)⊸(F`) _eqvar arr}⌜ {𝕩∾≍˘¨𝕩} ≍˘¨ ⟨⋈1, 1‿0‿1⟩
|
||||||
|
|
||||||
|
|
||||||
# ´
|
# ´
|
||||||
!"𝔽´𝕩: 𝕩 must be a list (⟨⟩ ≡ ≢𝕩)" % +´0
|
!"𝔽´𝕩: 𝕩 must be a list (⟨⟩ ≡ ≢𝕩)" % +´0
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user