Merge pull request #135 from mlochbaum/scan

Optimize common high-rank scans
This commit is contained in:
dzaima 2025-03-06 05:18:07 +02:00 committed by GitHub
commit bb3bb1b1d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 254 additions and 38 deletions

View File

@ -1,22 +1,31 @@
// Scan (`) // Scan (`)
// Empty 𝕩, and length 1 if no 𝕨: return 𝕩 // Empty 𝕩, and length 1 if no 𝕨: return 𝕩
// Generic operand: // Generic argument:
// Constant: copy // Constant: copy
// ⊢ identity, ⊣ reshape 𝕨 or first cell // ⊢ identity, ⊣ reshape 𝕨 or first cell
// Boolean operand, rank 1: // Boolean argument, stride 1:
// + AVX2 expansion (SHOULD have better generic, add SSE, NEON) // + AVX2 expansion (SHOULD have better generic, add SSE, NEON)
// ∨⌈ ∧×⌊ search+copy, then memset (COULD vectorize search) // ∨⌈ ∧×⌊ search+copy, then memset (COULD vectorize search)
// ≠ SWAR/SIMD shifts, CLMUL, VPCLMUL (SHOULD add NEON polynomial mul) // ≠ SWAR/SIMD shifts, CLMUL, VPCLMUL (SHOULD add NEON polynomial mul)
// < SWAR // < SWAR
// =≤≥>- in terms of ≠<∨∧+ with adjustments // =≤≥>- in terms of ≠<∨∧+ with adjustments
// Arithmetic operand, rank 1: // Numeric argument, stride 1:
// ⌈⌊ Scalar, SIMD in log(vector width) steps // ⌈⌊ Scalar, SIMD in log(vector width) steps
// Check in 6-vector blocks to quickly write result if constant // Check in 6-vector blocks to quickly write result if constant
// + Overflow-checked scalar or AVX2 // + Overflow-checked scalar or AVX2
// Ad-hoc boolean-valued handling for ≠∨ // Ad-hoc boolean-valued handling for ≠∨
// SHOULD extend rank 1 special cases to cell bound 1 // Higher-rank arithmetic:
// Higher-rank arithmetic, non-tiny cells: apply operand cell-wise // Boolean ≠∨∧ and synonyms: SWAR; ⌊⌈+: SIMD with shuffle/permute
// SHOULD have dedicated high-rank scan optimizations // Stride <word/vector: power-of-two (times stride) shifts
// COULD vectorize small-stride boolean scans
// ≠, divisor of 64: CLMUL
// Stride 1 to 2 words: result in register instead of re-reading
// Read and write at same alignment unless stride is large
// Large-stride cases auto-vectorize (except + overflow check)
// Overflow check for +, widen and retry on failure
// =` as ≠`⌾¬
// SHOULD optimize high-rank dyadic scan for recognized operands
// Other arithmetic, non-tiny cells: apply operand cell-wise
// Scan with rank (`˘ or `⎉k) // Scan with rank (`˘ or `⎉k)
// SHOULD optimize dyadic scan with rank // SHOULD optimize dyadic scan with rank
@ -70,7 +79,7 @@ B mul_c2(B, B, B);
B scan_ne(B x, u64 p, u64 ia) { // consumes x B scan_ne(B x, u64 p, u64 ia) { // consumes x
u64* xp = bitany_ptr(x); u64* xp = bitany_ptr(x);
u64* rp; B r=m_bitarrv(&rp,ia); u64* rp; B r=m_bitarrc(&rp,x);
#if SINGELI #if SINGELI
si_scan_ne(p, xp, rp, BIT_N(ia)); si_scan_ne(p, xp, rp, BIT_N(ia));
#if USE_VALGRIND #if USE_VALGRIND
@ -97,14 +106,14 @@ B scan_eq(B x, u64 ia) { // consumes x
static B scan_or(B x, u64 ia) { // consumes x static B scan_or(B x, u64 ia) { // consumes x
u64* xp = bitany_ptr(x); u64* xp = bitany_ptr(x);
u64* rp; B r=m_bitarrv(&rp,ia); u64* rp; B r=m_bitarrc(&rp,x);
usz n=BIT_N(ia); u64 xi; usz i=0; usz n=BIT_N(ia); u64 xi; usz i=0;
while (i<n) if ((xi= vg_rand(xp[i]))!=0) { rp[i] = -(xi&-xi) ; i++; while(i<n) rp[i++] = ~0LL; break; } else rp[i++]=0; while (i<n) if ((xi= vg_rand(xp[i]))!=0) { rp[i] = -(xi&-xi) ; i++; while(i<n) rp[i++] = ~0LL; break; } else rp[i++]=0;
decG(x); return FL_SET(r, fl_asc|fl_squoze); decG(x); return FL_SET(r, fl_asc|fl_squoze);
} }
static B scan_and(B x, u64 ia) { // consumes x static B scan_and(B x, u64 ia) { // consumes x
u64* xp = bitany_ptr(x); u64* xp = bitany_ptr(x);
u64* rp; B r=m_bitarrv(&rp,ia); u64* rp; B r=m_bitarrc(&rp,x);
usz n=BIT_N(ia); u64 xi; usz i=0; usz n=BIT_N(ia); u64 xi; usz i=0;
while (i<n) if ((xi=~vg_rand(xp[i]))!=0) { rp[i] = (xi&-xi)-1; i++; while(i<n) rp[i++] = 0 ; break; } else rp[i++]=~0LL; while (i<n) if ((xi=~vg_rand(xp[i]))!=0) { rp[i] = (xi&-xi)-1; i++; while(i<n) rp[i++] = 0 ; break; } else rp[i++]=~0LL;
decG(x); return FL_SET(r, fl_dsc|fl_squoze); decG(x); return FL_SET(r, fl_dsc|fl_squoze);
@ -130,7 +139,7 @@ B scan_add_bool(B x, u64 ia) { // consumes x
decG(ones); decG(ones);
r = mut_fv(r0); r = mut_fv(r0);
} else { } else {
void* rp = m_tyarrv(&r, elWidth(re), ia, el2t(re)); void* rp = m_tyarrc(&r, elWidth(re), x, el2t(re));
#define SUM_BITWISE(T) { T c=0; for (usz i=0; i<ia; i++) { c+= bitp_get(xp,i); ((T*)rp)[i]=c; } } #define SUM_BITWISE(T) { T c=0; for (usz i=0; i<ia; i++) { c+= bitp_get(xp,i); ((T*)rp)[i]=c; } }
#if SINGELI #if SINGELI
#define SUM(W,T) si_bcs##W(xp, rp, ia); #define SUM(W,T) si_bcs##W(xp, rp, ia);
@ -156,7 +165,7 @@ B scan_add_bool(B x, u64 ia) { // consumes x
#define MINMAX_SCAN(T,NAME,C,I) T c=I; for (usz i=0; i<ia; i++) { if (xp[i] C c)c=xp[i]; rp[i]=c; } #define MINMAX_SCAN(T,NAME,C,I) T c=I; for (usz i=0; i<ia; i++) { if (xp[i] C c)c=xp[i]; rp[i]=c; }
#endif #endif
#define MM_CASE(T,N,C,I) \ #define MM_CASE(T,N,C,I) \
case el_##T : { T* xp=T##any_ptr(x); T* rp; r=m_##T##arrv(&rp, ia); MINMAX_SCAN(T,N,C,I); break; } case el_##T : { T* xp=T##any_ptr(x); T* rp; r=m_##T##arrc(&rp, x); MINMAX_SCAN(T,N,C,I); break; }
#define MINMAX(NAME,C,INIT,BIT,ORD) \ #define MINMAX(NAME,C,INIT,BIT,ORD) \
B r; switch (xe) { default:UD; \ B r; switch (xe) { default:UD; \
case el_bit: return scan_##BIT(x, ia); \ case el_bit: return scan_##BIT(x, ia); \
@ -174,7 +183,7 @@ B scan_max_num(B x, u8 xe, u64 ia) { MINMAX(max,>,MIN,or ,asc) }
#define MM2_ICASE(T,N,C,I) \ #define MM2_ICASE(T,N,C,I) \
case el_##T : { \ case el_##T : { \
if (wv!=(T)wv) { if (wv C 0) { r=C2(shape,m_f64(ia),w); break; } else wv=I; } \ if (wv!=(T)wv) { if (wv C 0) { r=C2(shape,m_f64(ia),w); break; } else wv=I; } \
T* xp=T##any_ptr(x); T* rp; r=m_##T##arrv(&rp, ia); MINMAX_SCAN(T,N,C,wv); \ T* xp=T##any_ptr(x); T* rp; r=m_##T##arrc(&rp, x); MINMAX_SCAN(T,N,C,wv); \
break; } break; }
#define MINMAX2(NAME,C,INIT,BIT,BI,ORD) \ #define MINMAX2(NAME,C,INIT,BIT,BI,ORD) \
i32 wv=0; if (q_i32(w)) { wv=o2fG(w); } else { x=taga(cpyF64Arr(x)); xe=el_f64; } \ i32 wv=0; if (q_i32(w)) { wv=o2fG(w); } else { x=taga(cpyF64Arr(x)); xe=el_f64; } \
@ -195,7 +204,7 @@ SHOULD_INLINE B scan2_max_num(B w, B x, u8 xe, usz ia) { MINMAX2(max,>,MIN,or ,0
static B scan_lt(B x, u64 p, usz ia) { static B scan_lt(B x, u64 p, usz ia) {
u64* xp = bitany_ptr(x); u64* xp = bitany_ptr(x);
u64* rp; B r=m_bitarrv(&rp,ia); usz n=BIT_N(ia); u64* rp; B r=m_bitarrc(&rp,x); usz n=BIT_N(ia);
u64 m = 0x5555555555555555; u64 m = 0x5555555555555555;
for (usz i=0; i<n; i++) { for (usz i=0; i<n; i++) {
u64 x = xp[i]; u64 x = xp[i];
@ -208,7 +217,7 @@ static B scan_lt(B x, u64 p, usz ia) {
static B scan_plus(f64 r0, B x, u8 xe, usz ia) { static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
assert(xe!=el_bit && elNum(xe)); assert(xe!=el_bit && elNum(xe));
B r; void* rp = m_tyarrv(&r, xe==el_f64? sizeof(f64) : sizeof(i32), ia, xe==el_f64? t_f64arr : t_i32arr); B r; void* rp = m_tyarrc(&r, xe==el_f64? sizeof(f64) : sizeof(i32), x, xe==el_f64? t_f64arr : t_i32arr);
#if SINGELI #if SINGELI
switch(xe) { default:UD; switch(xe) { default:UD;
case el_i8: { if (!q_fi32(r0) || si_scan_plus_i8_i32 (i8any_ptr(x), r0, rp, ia)!=ia) goto cs_i8_f64; decG(x); return r; } case el_i8: { if (!q_fi32(r0) || si_scan_plus_i8_i32 (i8any_ptr(x), r0, rp, ia)!=ia) goto cs_i8_f64; decG(x); return r; }
@ -217,8 +226,8 @@ static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
case el_f64: { f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; } case el_f64: { f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
} }
cs_i8_f64: { x=taga(cpyI16Arr(x)); goto cs_i16_f64; } cs_i8_f64: { x=taga(cpyI16Arr(x)); goto cs_i16_f64; }
cs_i16_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); si_scan_plus_i16_f64(i16any_ptr(x), r0, rp, ia); decG(x); return r; } cs_i16_f64: { decG(r); f64* rp; r = m_f64arrc(&rp, x); si_scan_plus_i16_f64(i16any_ptr(x), r0, rp, ia); decG(x); return r; }
cs_i32_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); si_scan_plus_i32_f64(i32any_ptr(x), r0, rp, ia); decG(x); return r; } cs_i32_f64: { decG(r); f64* rp; r = m_f64arrc(&rp, x); si_scan_plus_i32_f64(i32any_ptr(x), r0, rp, ia); decG(x); return r; }
#else #else
if (xe==el_i8 && q_fi32(r0)) { i8* xp=i8any_ptr (x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; } if (xe==el_i8 && q_fi32(r0)) { i8* xp=i8any_ptr (x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
if (xe==el_i16 && q_fi32(r0)) { i16* xp=i16any_ptr(x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; } if (xe==el_i16 && q_fi32(r0)) { i16* xp=i16any_ptr(x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
@ -226,7 +235,7 @@ static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
if (xe==el_f64) { res_float:; f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; } if (xe==el_f64) { res_float:; f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
base:; base:;
decG(r); decG(r);
f64* rp2; r = m_f64arrv(&rp2, ia); rp = rp2; f64* rp2; r = m_f64arrc(&rp2, x); rp = rp2;
x = toF64Any(x); x = toF64Any(x);
goto res_float; goto res_float;
#endif #endif
@ -234,10 +243,9 @@ static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
extern B scan_arith(B f, B w, B x, usz* xsh); // from cells.c extern B scan_arith(B f, B w, B x, usz* xsh); // from cells.c
B scan_c1(Md1D* d, B x) { B f = d->f; B scan_c1(Md1D* d, B x) { B f = d->f;
if (isAtm(x) || RNK(x)==0) thrM("𝔽`𝕩: 𝕩 cannot have rank 0"); if (isAtm(x)) { unit: thrM("𝔽`𝕩: 𝕩 cannot have rank 0"); }
ur xr = RNK(x); usz ia = IA(x); if (ia <= 1) { if (ia==1 && RNK(x)==0) goto unit; return x; }
usz ia = IA(x); usz n = *SH(x); if (n <= 1) return x;
if (*SH(x)<=1 || ia==0) return x;
if (RARE(!isFun(f))) { if (RARE(!isFun(f))) {
if (isMd(f)) thrM("Calling a modifier"); if (isMd(f)) thrM("Calling a modifier");
B xf = getFillR(x); B xf = getFillR(x);
@ -257,7 +265,50 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
Arr* r = TI(x,slice)(x, 0, csz); Arr* r = TI(x,slice)(x, 0, csz);
return C2(shape, s, taga(r)); return C2(shape, s, taga(r));
} }
if (!(xr==1 && xe<=el_f64)) goto base; if (xe > el_f64) goto base;
if (ia != n) { // csz != 1
#if SINGELI
usz csz = arr_csz(x);
i8 t = -1; bool neg = 0;
if (xe==el_bit) switch (rtid) {
CASE_N_OR: t=0; break;
CASE_N_AND: t=1; break;
case n_eq: neg=1; case n_ne: t=2; break;
}
if (t != -1) {
if (neg) x = bit_negate(x);
u64* rp; B r=m_bitarrc(&rp,x);
si_scan_bool_stride[t](bitany_ptr(x), rp, ia, csz);
if (neg) r = bit_negate(r);
decG(x); return r;
}
if (rtid==n_floor | rtid==n_ceil) {
// boolean was handled as CASE_N_AND
B r; void* rp = m_tyarrc(&r, elWidth(xe), x, el2t(xe));
void* xp = tyany_ptr(x);
si_scan_stride_minmax[4*(rtid==n_ceil) + xe-el_i8](xp, rp, ia, csz);
decG(x); return r;
}
if (rtid==n_add) {
if (xe==el_bit) { x = taga(cpyI8Arr(x)); xe=el_i8; }
restart:;
B r; void* rp = m_tyarrc(&r, elWidth(xe), x, el2t(xe));
void* xp = tyany_ptr(x);
bool done = si_scan_stride_add[xe-el_i8](xp, rp, ia, csz);
if (!done) {
decG(r);
switch (++xe) { default: UD;
case el_i16: x = taga(cpyI16Arr(x)); break;
case el_i32: x = taga(cpyI32Arr(x)); break;
case el_f64: x = taga(cpyF64Arr(x)); break;
}
goto restart;
}
decG(x); return r;
}
#endif
goto base;
}
if (xe==el_bit) switch (rtid) { default: goto base; if (xe==el_bit) switch (rtid) { default: goto base;
case n_add: return scan_add_bool(x, ia); // + case n_add: return scan_add_bool(x, ia); // +
@ -278,7 +329,7 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
if (!elInt(xe)) goto base; if (!elInt(xe)) goto base;
f64 x0 = o2fG(IGetU(x,0)); f64 x0 = o2fG(IGetU(x,0));
if (!q_fbit(x0)) goto base; if (!q_fbit(x0)) goto base;
u64* rp; B r = m_bitarrv(&rp, ia); u64* rp; B r = m_bitarrc(&rp, x);
bool c = x0; bool c = x0;
rp[0] = c; rp[0] = c;
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=1; i<ia; i++) { c = c!=xp[i]; bitp_set(rp,i,c); } decG(x); return r; } if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=1; i<ia; i++) { c = c!=xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
@ -289,7 +340,7 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
if (rtid==n_or) { x=num_squeezeChk(x); xe=TI(x,elType); if (xe==el_bit) return scan_or(x, ia); } if (rtid==n_or) { x=num_squeezeChk(x); xe=TI(x,elType); if (xe==el_bit) return scan_or(x, ia); }
} }
base:; base:;
if (xr>1 && ia >= 6 * (u64)*SH(x) && isPervasiveDy(f)) return scan_arith(f, m_f64(0), x, SH(x)); if (ia!=n && ia >= 6 * (u64)n && isPervasiveDy(f)) return scan_arith(f, m_f64(0), x, SH(x));
SLOW2("𝕎` 𝕩", f, x); SLOW2("𝕎` 𝕩", f, x);
B xf = getFillR(x); B xf = getFillR(x);
@ -297,7 +348,7 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
SGet(x) SGet(x)
FC2 fc2 = c2fn(f); FC2 fc2 = c2fn(f);
if (xr==1) { if (ia == n) {
r.a[0] = Get(x,0); r.a[0] = Get(x,0);
for (usz i=1; i<ia; i++) r.a[i] = fc2(f, inc(r.a[i-1]), Get(x,i)); for (usz i=1; i<ia; i++) r.a[i] = fc2(f, inc(r.a[i-1]), Get(x,i));
} else { } else {
@ -327,7 +378,8 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f;
u8 rtid = RTID(f); u8 rtid = RTID(f);
if (rtid==n_rtack) { dec(w); return x; } if (rtid==n_rtack) { dec(w); return x; }
if (rtid==n_ltack) return C2(shape, C1(fne, x), w); if (rtid==n_ltack) return C2(shape, C1(fne, x), w);
if (!(xr==1 && elNum(xe) && xe<=el_f64)) goto base; if (!(elNum(xe) && xe<=el_f64)) goto base;
if (xr!=1 && *SH(x)!=ia) goto base;
if (!isF64(w)) goto base; if (!isF64(w)) goto base;
if (rtid==n_floor) return scan2_min_num(w, x, xe, ia); // ⌊ if (rtid==n_floor) return scan2_min_num(w, x, xe, ia); // ⌊
@ -350,7 +402,7 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f;
if (xe==el_bit) return scan_ne(x, -(u64)(wBit? o2bG(w) : 1&~*bitany_ptr(x)), ia); if (xe==el_bit) return scan_ne(x, -(u64)(wBit? o2bG(w) : 1&~*bitany_ptr(x)), ia);
if (!wBit || !elInt(xe)) goto base; if (!wBit || !elInt(xe)) goto base;
bool c = o2bG(w); bool c = o2bG(w);
u64* rp; B r = m_bitarrv(&rp, ia); u64* rp; B r = m_bitarrc(&rp, x);
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; } if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
if (xe==el_i16) { i16* xp=i16any_ptr(x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; } if (xe==el_i16) { i16* xp=i16any_ptr(x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; } if (xe==el_i32) { i32* xp=i32any_ptr(x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }

View File

@ -96,7 +96,7 @@ local def ml_exec{i, iter, vars0, bulk, M} = {
# i0 - initial batch index; not used as begin because it's in a different scale compared to end # i0 - initial batch index; not used as begin because it's in a different scale compared to end
def for_masked{bulk, i0}{vars,begin==0,end,iter} = { def for_masked{bulk, i0}{vars,begin==0,end,iter} = {
l:u64 = end l:u64 = promote{u64, end}
m:u64 = l / bulk m:u64 = l / bulk
@for (i from i0 to m) ml_exec{i, iter, vars, bulk, mask_none} @for (i from i0 to m) ml_exec{i, iter, vars, bulk, mask_none}
@ -128,7 +128,7 @@ def for_masked_pos{bulk}{vars,begin==0,end:L,iter} = {
# end is scalar element count # end is scalar element count
# index given is a tuple of batch indexes to process # index given is a tuple of batch indexes to process
def for_mu{bulk, unr, fromunr}{vars,begin==0,end,iter} = { def for_mu{bulk, unr, fromunr}{vars,begin==0,end,iter} = {
l:u64 = end l:u64 = promote{u64, end}
m:u64 = l / bulk m:u64 = l / bulk
if (unr==1) { if (unr==1) {

View File

@ -98,6 +98,118 @@ def shift_first{c:V=[l]_, p:V} = {
else blend_first{c, rotate_right{p}} else blend_first{c, rotate_right{p}}
} }
# Strided scans
fn scan_stride_assoc{op, T, Ret, check_over}(xv:*void, rv:*void, ia:usz, l:usz) : Ret = {
def minvalue{(f64)} = -1/0; def maxvalue{(f64)} = 1/0
def id = match (op) {
{(min)} => maxvalue; {(max)} => minvalue
{(+)} => ({_}=>0)
}
x:= *T~~xv; r:= *T~~rv
# Architecture determination
# Use largest vector width with a full-width shuffle
def has_shuf = hasarch{'SSSE3'} or hasarch{'AARCH64'}
def I = if (hasarch{'AVX2'} and T>=i32) [8]i32 else [16]i8
def [il]IE = I; def selI = shuf{IE, ...}
def wT = width{T}
def f = wT/width{IE}
def vl = width{I}/wT
def V = [vl]T
if (has_shuf and l < vl) {
# Small stride: power-of-two shifts
def small{k} = {
iv:= iota{I}; j:= I**cast_i{IE,l*f}
spr:= I**il - j + iv
def inds = @collect (k) {
v:= iv - (j &~ I~~(iv<j))
spr = selI{spr, v}
js:= j; j+= j
if (not same{op, +}) selI{., v}
else if (same{IE,i8}) selI{., iv - js}
else { m:= V~~(iv >= js); {x} => selI{x, v} & m }
}
c:= V**id{T}
@for_masked{vl} (x in tup{V, x}, r in tup{V, r}, M in 'm' over ia) {
xs:= fold{{v, i} => op{i{v}, v}, x, inds}
r = c = op{shuf{IE, c, spr}, xs}
check_over{M, x, r} # For +, infers other argument as r-x
}
}
if (not (same{op,+} and V==[4]f64)) {
def max_k = lb{vl/2} # Divide by two from assuming l≥2
if (max_k<3 or l<4) small{max_k} else small{max_k-1} # l=2 and l=3 are the only cases needing the full max_k iterations; max_k<3 limits specialization to where it's significant
} else { # Non-associative!
c:= V**0
if (l==2) {
@for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) {
a:= c + shuf{x, 0,1,0,1}
c = a + shuf{x, 2,3,2,3}
r = blend{a, c, 0,0,1,1}
}
} else {
assert{l==3}
@for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) {
a:= shuf{c, 1,1,2,3} + blend{x, V**0, 0,1,1,1}
r = c = x + shuf{a, 1,2,3,0}
}
}
}
} else {
# Large stride: single shift, with saved register or memory
def op_chk{M, p, x} = { r:= op{p, x}; check_over{M, p, x, r}; r }
@for (r, x over l) r = x
if (has_shuf and l<256/(wT/8)) {
# Make sure to load the previous row data at the same alignment to not hit bad store-to-load forwarding
def [il]IE = I
q:= l%vl; fq:= cast_i{IE, q*f}
def rot = shuf{IE, ., (iota{I} - I**fq) & I**(il-1)}
bv:= iota{I} >= I**fq; def bl = blend_hom{..., bv}
c:= V**id{T}
o:= l - q
if (l == 2*vl) { o = vl; bv = ~bv }
if (o == vl) {
p:= load{*V~~x}; store{*V~~r, 0, p}
@for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o}, M in 'm' over ia-o) {
p = rot{p}
r = op_chk{M, bl{c, p}, x}
c = p; p = r
}
} else {
@for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o}, p in tup{V, r}, M in 'm' over ia-o) {
q:= rot{p}
r = op_chk{M, bl{c, q}, x}
c = q
}
}
} else if (same{op, +} and T<=i32 and has_simd and (has_shuf or l>=vl)) {
def vl = arch_defvw/wT; def V = [vl]T
@for_masked{vl} (x in tup{V, x+l}, r in tup{V, r+l}, p in tup{V, r}, M in 'm' over ia-l) {
r = op_chk{M, p, x}
}
} else {
@for (r, x, p in r-l over _ from l to ia) r = op_chk{0, p, x}
}
}
1
}
def scan_stride_assoc{op, T} = scan_stride_assoc{op, T, void, {..._}=>{}}
def check_add_over{_, w:T, x:T, r:T} = { if ((w^r) & (x^r) < 0) return{0} }
def check_add_over{M, w:V=[_]E, x:V, r:V} = {
o:= (if (not hasarch{'X86_64'} or width{E}<=16) any_hom{M, subs{r,w} != x}
else any_top{M, (w^r) & (x^r)})
if (o) return{0}
}
def check_add_over{M, x, r} = check_add_over{M, r-x, x, r}
export_tab{'si_scan_stride_minmax',
flat_table{scan_stride_assoc, tup{min,max}, tup{i8,i16,i32,f64}}
}
export_tab{'si_scan_stride_add', tup{
...each{scan_stride_assoc{+, ., u1, check_add_over}, tup{i8,i16,i32}},
scan_stride_assoc{+, f64, u1, {..._}=>{}}
}}
# xor scan # xor scan
def vec_prefix_byshift{op, sh} = { def vec_prefix_byshift{op, sh} = {
def pre{v:V, k} = if (k < elwidth{V}) pre{op{v, sh{v,k}}, 2*k} else v def pre{v:V, k} = if (k < elwidth{V}) pre{op{v, sh{v,k}}, 2*k} else v
@ -106,13 +218,13 @@ def vec_prefix_byshift{op, sh} = {
def scan_word_ne = prefix_byshift{^, <<} def scan_word_ne = prefix_byshift{^, <<}
def scan_words_ne = vec_prefix_byshift{^, <<} def scan_words_ne = vec_prefix_byshift{^, <<}
fn scan_neq{}(c:u64, x:*u64, r:*u64, nw:u64) : void = { fn scan_neq{}(c:u64, x:*u64, r:*u64, nw:usz) : void = {
@for (x, r over nw) { @for (x, r over nw) {
r = c ^ scan_word_ne{x} r = c ^ scan_word_ne{x}
c = -(r>>63) # repeat sign bit c = -(r>>63) # repeat sign bit
} }
} }
fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:u64) : void = { fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:usz) : void = {
def vl = arch_defvw / 64 def vl = arch_defvw / 64
def V = [vl]u64 def V = [vl]u64
c := V**c0 c := V**c0
@ -123,7 +235,7 @@ fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:u64) : void = {
c = broadcast_last{p} c = broadcast_last{p}
} }
} }
fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = { fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:usz, mark:u64) : void = {
def V = [2]u64 def V = [2]u64
m := V**mark m := V**mark
def xor64{a, i, carry} = { # carry is 64-bit broadcasted current total def xor64{a, i, carry} = { # carry is 64-bit broadcasted current total
@ -144,10 +256,10 @@ fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64
store{*u64~~(rv+e), clmul{load{V, *u64~~(xv+e), 1}, m, 0} ^ c, 1} store{*u64~~(rv+e), clmul{load{V, *u64~~(xv+e), 1}, m, 0} ^ c, 1}
} }
} }
fn scan_neq{if hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = { fn scan_neq{if hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:usz) : void = {
clmul_scan_ne_any{}(*void~~x, *void~~r, init, nw, -(u64~~1)) clmul_scan_ne_any{}(*void~~x, *void~~r, init, nw, -(u64~~1))
} }
fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = { fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u64, nw:usz) : void = {
def V = [8]u64 def V = [8]u64
def sse{a} = make{[2]u64, a, 0} def sse{a} = make{[2]u64, a, 0}
carry := sse{init} carry := sse{init}
@ -358,10 +470,11 @@ def loose_mask_gen{V=[vl]T, l} = { # Slow, for ≠` only
} }
def has_vecshift = hasarch{'AVX2'} or hasarch{'AARCH64'} def has_vecshift = hasarch{'AVX2'} or hasarch{'AARCH64'}
def loose_mask_gen{V=[vl](u64), l if has_vecshift} = { def loose_mask_gen{V=[vl](u64), l if has_vecshift} = {
l64 := promote{u64, l}
q := -make{V, 64*iota{vl}} # distance to next row boundary q := -make{V, 64*iota{vl}} # distance to next row boundary
def q_mod{} = { q+= V**l & -(q>>63) } def q_mod{} = { q+= V**l64 & -(q>>63) }
def q_mod{if hasarch{'SSE4.1'}} = { q = blend_top{q,q+V**l, q} } def q_mod{if hasarch{'SSE4.1'}} = { q = blend_top{q,q+V**l64, q} }
o:u64 = width{V}; while (o>l) { o-=l; q_mod{} } o:u64 = width{V}; while (o>l64) { o-=l64; q_mod{} }
{} => { {} => {
m:= V**1 << q; if (not hasarch{'AVX2'}) m&= q < V**64 m:= V**1 << q; if (not hasarch{'AVX2'}) m&= q < V**64
q-= V**o; q_mod{} q-= V**o; q_mod{}
@ -560,7 +673,7 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = {
c:u64 = 0 # carry c:u64 = 0 # carry
while (1) { while (1) {
i+= l; ii := iw; iw = cdiv{i, 64} i+= l; ii := iw; iw = cdiv{i, 64}
scan_neq{}(c, x+ii, r+ii, promote{u64,iw-ii}) scan_neq{}(c, x+ii, r+ii, iw-ii)
if (i == nl) return{} if (i == nl) return{}
s:= load{r, iw-1} s:= load{r, iw-1}
q := i%64 q := i%64
@ -619,3 +732,49 @@ export{'si_scan_rows_and', scan_rows_andor{0}}
export{'si_scan_rows_or', scan_rows_andor{1}} export{'si_scan_rows_or', scan_rows_andor{1}}
export{'si_scan_rows_ne', scan_rows_neq} export{'si_scan_rows_ne', scan_rows_neq}
export{'si_scan_rows_ltack', scan_rows_left} export{'si_scan_rows_ltack', scan_rows_left}
# Strided boolean scans
fn scan_stride_bool_assoc{op}(x:*u64, r:*u64, nl:usz, l:usz) : void = {
assert{l > 1}
def {flip,opf} = if (same{op, &}) tup{~,|} else tup{{x}=>x,op} # such that identity of opf is 0
nw:= cdiv{nl, 64}
if (l <= 64) {
if (same{op, ^} and hasarch{'PCLMUL'} and (l & (l-1)) == 0) {
clmul_scan_ne_any{}(*void~~x, *void~~r, 0, nw, aligned_spaced_mask{l})
return{}
}
c:u64 = 0 # carry l bits, no matter the alignment
@for (r, x over nw) {
c = opf{flip{x}, c >> (64-l)}
s:= l; while (s < 64) { c = opf{c, c<<s}; s += s }
r = flip{c}
}
} else if (l < 128) {
q:= l%64
c:= flip{u64~~0}
p:= load{x,0}
store{r,0, p}
@for (r, x over _ from 1 to nw) {
r = op{x, c>>(64-q) | p<<q}
c = p; p = r
}
} else {
fw:= cdiv{l, 64} # words in first cell
@for (r, x over fw) r = x
if (l%64 == 0) {
@for (r, x, p in r-fw over _ from fw to nw) r = op{x, p}
} else {
q:= l%64
c:= flip{u64~~0}
@for (r, x, p in r-(fw-1) over _ from fw-1 to nw) {
r = op{x, c>>(64-q) | p<<q}
c = p
}
}
}
}
export_tab{
'si_scan_bool_stride',
each{scan_stride_bool_assoc, tup{|, &, ^}}
}

2
test/cases/fuzz/scan.bqn Normal file
View File

@ -0,0 +1,2 @@
# tests scan on a matrix checking for overflow in tail elements
81632 {e𝕊nw: m2e-2 (m+(¯1,w|n) w(n+w+1)m) +` •internal.Squeeze w (wm)(n0)m} 1+20065

View File

@ -498,6 +498,9 @@ a←↕2 ⋄ ! "e" ≡ (↕10){b←a‿a‿a‿a‿a‿a‿a‿a‿a‿a ⋄
!"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (2‿2 ≡ ≢𝕨, 3‿2‿3 ≡ ≢𝕩)" % (221)+`323 !"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (2‿2 ≡ ≢𝕨, 3‿2‿3 ≡ ≢𝕩)" % (221)+`323
!"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (2‿2 ≡ ≢𝕨, 3‿3‿2 ≡ ≢𝕩)" % (221)+`332 !"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (2‿2 ≡ ≢𝕨, 3‿3‿2 ≡ ≢𝕩)" % (221)+`332
!"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (⟨⟩ ≡ ≢𝕨, 3‿3 ≡ ≢𝕩)" % 2+`33 !"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (⟨⟩ ≡ ≢𝕨, 3‿3 ≡ ≢𝕩)" % 2+`33
%USE eqvar %USE k +,-,×,÷,,,,,¬,,,<,>,,=,,,, {f𝕊arr: ! ( F _k`_k arr) _k F` _eqvar arr} {𝕩˘¨𝕩} ˘¨ 1, 101
%USE eqvar %USE k +,-,×,÷,,,,,¬,,,<,>,,=,,,, {f𝕊arr: ! (1 F _k`_k arr) _k (1¨𝕩)(F`) _eqvar arr} {𝕩˘¨𝕩} ˘¨ 1, 101
# ´ # ´
!"𝔽´𝕩: 𝕩 must be a list (⟨⟩ ≡ ≢𝕩)" % +´0 !"𝔽´𝕩: 𝕩 must be a list (⟨⟩ ≡ ≢𝕩)" % +´0