Merge pull request #135 from mlochbaum/scan
Optimize common high-rank scans
This commit is contained in:
commit
bb3bb1b1d3
@ -1,22 +1,31 @@
|
||||
// Scan (`)
|
||||
// Empty 𝕩, and length 1 if no 𝕨: return 𝕩
|
||||
// Generic operand:
|
||||
// Generic argument:
|
||||
// Constant: copy
|
||||
// ⊢ identity, ⊣ reshape 𝕨 or first cell
|
||||
// Boolean operand, rank 1:
|
||||
// Boolean argument, stride 1:
|
||||
// + AVX2 expansion (SHOULD have better generic, add SSE, NEON)
|
||||
// ∨⌈ ∧×⌊ search+copy, then memset (COULD vectorize search)
|
||||
// ≠ SWAR/SIMD shifts, CLMUL, VPCLMUL (SHOULD add NEON polynomial mul)
|
||||
// < SWAR
|
||||
// =≤≥>- in terms of ≠<∨∧+ with adjustments
|
||||
// Arithmetic operand, rank 1:
|
||||
// Numeric argument, stride 1:
|
||||
// ⌈⌊ Scalar, SIMD in log(vector width) steps
|
||||
// Check in 6-vector blocks to quickly write result if constant
|
||||
// + Overflow-checked scalar or AVX2
|
||||
// Ad-hoc boolean-valued handling for ≠∨
|
||||
// SHOULD extend rank 1 special cases to cell bound 1
|
||||
// Higher-rank arithmetic, non-tiny cells: apply operand cell-wise
|
||||
// SHOULD have dedicated high-rank scan optimizations
|
||||
// Higher-rank arithmetic:
|
||||
// Boolean ≠∨∧ and synonyms: SWAR; ⌊⌈+: SIMD with shuffle/permute
|
||||
// Stride <word/vector: power-of-two (times stride) shifts
|
||||
// COULD vectorize small-stride boolean scans
|
||||
// ≠, divisor of 64: CLMUL
|
||||
// Stride 1 to 2 words: result in register instead of re-reading
|
||||
// Read and write at same alignment unless stride is large
|
||||
// Large-stride cases auto-vectorize (except + overflow check)
|
||||
// Overflow check for +, widen and retry on failure
|
||||
// =` as ≠`⌾¬
|
||||
// SHOULD optimize high-rank dyadic scan for recognized operands
|
||||
// Other arithmetic, non-tiny cells: apply operand cell-wise
|
||||
|
||||
// Scan with rank (`˘ or `⎉k)
|
||||
// SHOULD optimize dyadic scan with rank
|
||||
@ -70,7 +79,7 @@ B mul_c2(B, B, B);
|
||||
|
||||
B scan_ne(B x, u64 p, u64 ia) { // consumes x
|
||||
u64* xp = bitany_ptr(x);
|
||||
u64* rp; B r=m_bitarrv(&rp,ia);
|
||||
u64* rp; B r=m_bitarrc(&rp,x);
|
||||
#if SINGELI
|
||||
si_scan_ne(p, xp, rp, BIT_N(ia));
|
||||
#if USE_VALGRIND
|
||||
@ -97,14 +106,14 @@ B scan_eq(B x, u64 ia) { // consumes x
|
||||
|
||||
static B scan_or(B x, u64 ia) { // consumes x
|
||||
u64* xp = bitany_ptr(x);
|
||||
u64* rp; B r=m_bitarrv(&rp,ia);
|
||||
u64* rp; B r=m_bitarrc(&rp,x);
|
||||
usz n=BIT_N(ia); u64 xi; usz i=0;
|
||||
while (i<n) if ((xi= vg_rand(xp[i]))!=0) { rp[i] = -(xi&-xi) ; i++; while(i<n) rp[i++] = ~0LL; break; } else rp[i++]=0;
|
||||
decG(x); return FL_SET(r, fl_asc|fl_squoze);
|
||||
}
|
||||
static B scan_and(B x, u64 ia) { // consumes x
|
||||
u64* xp = bitany_ptr(x);
|
||||
u64* rp; B r=m_bitarrv(&rp,ia);
|
||||
u64* rp; B r=m_bitarrc(&rp,x);
|
||||
usz n=BIT_N(ia); u64 xi; usz i=0;
|
||||
while (i<n) if ((xi=~vg_rand(xp[i]))!=0) { rp[i] = (xi&-xi)-1; i++; while(i<n) rp[i++] = 0 ; break; } else rp[i++]=~0LL;
|
||||
decG(x); return FL_SET(r, fl_dsc|fl_squoze);
|
||||
@ -130,7 +139,7 @@ B scan_add_bool(B x, u64 ia) { // consumes x
|
||||
decG(ones);
|
||||
r = mut_fv(r0);
|
||||
} else {
|
||||
void* rp = m_tyarrv(&r, elWidth(re), ia, el2t(re));
|
||||
void* rp = m_tyarrc(&r, elWidth(re), x, el2t(re));
|
||||
#define SUM_BITWISE(T) { T c=0; for (usz i=0; i<ia; i++) { c+= bitp_get(xp,i); ((T*)rp)[i]=c; } }
|
||||
#if SINGELI
|
||||
#define SUM(W,T) si_bcs##W(xp, rp, ia);
|
||||
@ -156,7 +165,7 @@ B scan_add_bool(B x, u64 ia) { // consumes x
|
||||
#define MINMAX_SCAN(T,NAME,C,I) T c=I; for (usz i=0; i<ia; i++) { if (xp[i] C c)c=xp[i]; rp[i]=c; }
|
||||
#endif
|
||||
#define MM_CASE(T,N,C,I) \
|
||||
case el_##T : { T* xp=T##any_ptr(x); T* rp; r=m_##T##arrv(&rp, ia); MINMAX_SCAN(T,N,C,I); break; }
|
||||
case el_##T : { T* xp=T##any_ptr(x); T* rp; r=m_##T##arrc(&rp, x); MINMAX_SCAN(T,N,C,I); break; }
|
||||
#define MINMAX(NAME,C,INIT,BIT,ORD) \
|
||||
B r; switch (xe) { default:UD; \
|
||||
case el_bit: return scan_##BIT(x, ia); \
|
||||
@ -174,7 +183,7 @@ B scan_max_num(B x, u8 xe, u64 ia) { MINMAX(max,>,MIN,or ,asc) }
|
||||
#define MM2_ICASE(T,N,C,I) \
|
||||
case el_##T : { \
|
||||
if (wv!=(T)wv) { if (wv C 0) { r=C2(shape,m_f64(ia),w); break; } else wv=I; } \
|
||||
T* xp=T##any_ptr(x); T* rp; r=m_##T##arrv(&rp, ia); MINMAX_SCAN(T,N,C,wv); \
|
||||
T* xp=T##any_ptr(x); T* rp; r=m_##T##arrc(&rp, x); MINMAX_SCAN(T,N,C,wv); \
|
||||
break; }
|
||||
#define MINMAX2(NAME,C,INIT,BIT,BI,ORD) \
|
||||
i32 wv=0; if (q_i32(w)) { wv=o2fG(w); } else { x=taga(cpyF64Arr(x)); xe=el_f64; } \
|
||||
@ -195,7 +204,7 @@ SHOULD_INLINE B scan2_max_num(B w, B x, u8 xe, usz ia) { MINMAX2(max,>,MIN,or ,0
|
||||
|
||||
static B scan_lt(B x, u64 p, usz ia) {
|
||||
u64* xp = bitany_ptr(x);
|
||||
u64* rp; B r=m_bitarrv(&rp,ia); usz n=BIT_N(ia);
|
||||
u64* rp; B r=m_bitarrc(&rp,x); usz n=BIT_N(ia);
|
||||
u64 m = 0x5555555555555555;
|
||||
for (usz i=0; i<n; i++) {
|
||||
u64 x = xp[i];
|
||||
@ -208,7 +217,7 @@ static B scan_lt(B x, u64 p, usz ia) {
|
||||
|
||||
static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
|
||||
assert(xe!=el_bit && elNum(xe));
|
||||
B r; void* rp = m_tyarrv(&r, xe==el_f64? sizeof(f64) : sizeof(i32), ia, xe==el_f64? t_f64arr : t_i32arr);
|
||||
B r; void* rp = m_tyarrc(&r, xe==el_f64? sizeof(f64) : sizeof(i32), x, xe==el_f64? t_f64arr : t_i32arr);
|
||||
#if SINGELI
|
||||
switch(xe) { default:UD;
|
||||
case el_i8: { if (!q_fi32(r0) || si_scan_plus_i8_i32 (i8any_ptr(x), r0, rp, ia)!=ia) goto cs_i8_f64; decG(x); return r; }
|
||||
@ -217,8 +226,8 @@ static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
|
||||
case el_f64: { f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
|
||||
}
|
||||
cs_i8_f64: { x=taga(cpyI16Arr(x)); goto cs_i16_f64; }
|
||||
cs_i16_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); si_scan_plus_i16_f64(i16any_ptr(x), r0, rp, ia); decG(x); return r; }
|
||||
cs_i32_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); si_scan_plus_i32_f64(i32any_ptr(x), r0, rp, ia); decG(x); return r; }
|
||||
cs_i16_f64: { decG(r); f64* rp; r = m_f64arrc(&rp, x); si_scan_plus_i16_f64(i16any_ptr(x), r0, rp, ia); decG(x); return r; }
|
||||
cs_i32_f64: { decG(r); f64* rp; r = m_f64arrc(&rp, x); si_scan_plus_i32_f64(i32any_ptr(x), r0, rp, ia); decG(x); return r; }
|
||||
#else
|
||||
if (xe==el_i8 && q_fi32(r0)) { i8* xp=i8any_ptr (x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
|
||||
if (xe==el_i16 && q_fi32(r0)) { i16* xp=i16any_ptr(x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
|
||||
@ -226,7 +235,7 @@ static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
|
||||
if (xe==el_f64) { res_float:; f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
|
||||
base:;
|
||||
decG(r);
|
||||
f64* rp2; r = m_f64arrv(&rp2, ia); rp = rp2;
|
||||
f64* rp2; r = m_f64arrc(&rp2, x); rp = rp2;
|
||||
x = toF64Any(x);
|
||||
goto res_float;
|
||||
#endif
|
||||
@ -234,10 +243,9 @@ static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
|
||||
|
||||
extern B scan_arith(B f, B w, B x, usz* xsh); // from cells.c
|
||||
B scan_c1(Md1D* d, B x) { B f = d->f;
|
||||
if (isAtm(x) || RNK(x)==0) thrM("𝔽`𝕩: 𝕩 cannot have rank 0");
|
||||
ur xr = RNK(x);
|
||||
usz ia = IA(x);
|
||||
if (*SH(x)<=1 || ia==0) return x;
|
||||
if (isAtm(x)) { unit: thrM("𝔽`𝕩: 𝕩 cannot have rank 0"); }
|
||||
usz ia = IA(x); if (ia <= 1) { if (ia==1 && RNK(x)==0) goto unit; return x; }
|
||||
usz n = *SH(x); if (n <= 1) return x;
|
||||
if (RARE(!isFun(f))) {
|
||||
if (isMd(f)) thrM("Calling a modifier");
|
||||
B xf = getFillR(x);
|
||||
@ -257,7 +265,50 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
|
||||
Arr* r = TI(x,slice)(x, 0, csz);
|
||||
return C2(shape, s, taga(r));
|
||||
}
|
||||
if (!(xr==1 && xe<=el_f64)) goto base;
|
||||
if (xe > el_f64) goto base;
|
||||
if (ia != n) { // csz != 1
|
||||
#if SINGELI
|
||||
usz csz = arr_csz(x);
|
||||
i8 t = -1; bool neg = 0;
|
||||
if (xe==el_bit) switch (rtid) {
|
||||
CASE_N_OR: t=0; break;
|
||||
CASE_N_AND: t=1; break;
|
||||
case n_eq: neg=1; case n_ne: t=2; break;
|
||||
}
|
||||
if (t != -1) {
|
||||
if (neg) x = bit_negate(x);
|
||||
u64* rp; B r=m_bitarrc(&rp,x);
|
||||
si_scan_bool_stride[t](bitany_ptr(x), rp, ia, csz);
|
||||
if (neg) r = bit_negate(r);
|
||||
decG(x); return r;
|
||||
}
|
||||
if (rtid==n_floor | rtid==n_ceil) {
|
||||
// boolean was handled as CASE_N_AND
|
||||
B r; void* rp = m_tyarrc(&r, elWidth(xe), x, el2t(xe));
|
||||
void* xp = tyany_ptr(x);
|
||||
si_scan_stride_minmax[4*(rtid==n_ceil) + xe-el_i8](xp, rp, ia, csz);
|
||||
decG(x); return r;
|
||||
}
|
||||
if (rtid==n_add) {
|
||||
if (xe==el_bit) { x = taga(cpyI8Arr(x)); xe=el_i8; }
|
||||
restart:;
|
||||
B r; void* rp = m_tyarrc(&r, elWidth(xe), x, el2t(xe));
|
||||
void* xp = tyany_ptr(x);
|
||||
bool done = si_scan_stride_add[xe-el_i8](xp, rp, ia, csz);
|
||||
if (!done) {
|
||||
decG(r);
|
||||
switch (++xe) { default: UD;
|
||||
case el_i16: x = taga(cpyI16Arr(x)); break;
|
||||
case el_i32: x = taga(cpyI32Arr(x)); break;
|
||||
case el_f64: x = taga(cpyF64Arr(x)); break;
|
||||
}
|
||||
goto restart;
|
||||
}
|
||||
decG(x); return r;
|
||||
}
|
||||
#endif
|
||||
goto base;
|
||||
}
|
||||
|
||||
if (xe==el_bit) switch (rtid) { default: goto base;
|
||||
case n_add: return scan_add_bool(x, ia); // +
|
||||
@ -278,7 +329,7 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
|
||||
if (!elInt(xe)) goto base;
|
||||
f64 x0 = o2fG(IGetU(x,0));
|
||||
if (!q_fbit(x0)) goto base;
|
||||
u64* rp; B r = m_bitarrv(&rp, ia);
|
||||
u64* rp; B r = m_bitarrc(&rp, x);
|
||||
bool c = x0;
|
||||
rp[0] = c;
|
||||
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=1; i<ia; i++) { c = c!=xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
|
||||
@ -289,7 +340,7 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
|
||||
if (rtid==n_or) { x=num_squeezeChk(x); xe=TI(x,elType); if (xe==el_bit) return scan_or(x, ia); }
|
||||
}
|
||||
base:;
|
||||
if (xr>1 && ia >= 6 * (u64)*SH(x) && isPervasiveDy(f)) return scan_arith(f, m_f64(0), x, SH(x));
|
||||
if (ia!=n && ia >= 6 * (u64)n && isPervasiveDy(f)) return scan_arith(f, m_f64(0), x, SH(x));
|
||||
SLOW2("𝕎` 𝕩", f, x);
|
||||
B xf = getFillR(x);
|
||||
|
||||
@ -297,7 +348,7 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
|
||||
SGet(x)
|
||||
FC2 fc2 = c2fn(f);
|
||||
|
||||
if (xr==1) {
|
||||
if (ia == n) {
|
||||
r.a[0] = Get(x,0);
|
||||
for (usz i=1; i<ia; i++) r.a[i] = fc2(f, inc(r.a[i-1]), Get(x,i));
|
||||
} else {
|
||||
@ -327,7 +378,8 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f;
|
||||
u8 rtid = RTID(f);
|
||||
if (rtid==n_rtack) { dec(w); return x; }
|
||||
if (rtid==n_ltack) return C2(shape, C1(fne, x), w);
|
||||
if (!(xr==1 && elNum(xe) && xe<=el_f64)) goto base;
|
||||
if (!(elNum(xe) && xe<=el_f64)) goto base;
|
||||
if (xr!=1 && *SH(x)!=ia) goto base;
|
||||
if (!isF64(w)) goto base;
|
||||
|
||||
if (rtid==n_floor) return scan2_min_num(w, x, xe, ia); // ⌊
|
||||
@ -350,7 +402,7 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f;
|
||||
if (xe==el_bit) return scan_ne(x, -(u64)(wBit? o2bG(w) : 1&~*bitany_ptr(x)), ia);
|
||||
if (!wBit || !elInt(xe)) goto base;
|
||||
bool c = o2bG(w);
|
||||
u64* rp; B r = m_bitarrv(&rp, ia);
|
||||
u64* rp; B r = m_bitarrc(&rp, x);
|
||||
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
|
||||
if (xe==el_i16) { i16* xp=i16any_ptr(x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
|
||||
if (xe==el_i32) { i32* xp=i32any_ptr(x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
|
||||
|
||||
@ -96,7 +96,7 @@ local def ml_exec{i, iter, vars0, bulk, M} = {
|
||||
|
||||
# i0 - initial batch index; not used as begin because it's in a different scale compared to end
|
||||
def for_masked{bulk, i0}{vars,begin==0,end,iter} = {
|
||||
l:u64 = end
|
||||
l:u64 = promote{u64, end}
|
||||
|
||||
m:u64 = l / bulk
|
||||
@for (i from i0 to m) ml_exec{i, iter, vars, bulk, mask_none}
|
||||
@ -128,7 +128,7 @@ def for_masked_pos{bulk}{vars,begin==0,end:L,iter} = {
|
||||
# end is scalar element count
|
||||
# index given is a tuple of batch indexes to process
|
||||
def for_mu{bulk, unr, fromunr}{vars,begin==0,end,iter} = {
|
||||
l:u64 = end
|
||||
l:u64 = promote{u64, end}
|
||||
|
||||
m:u64 = l / bulk
|
||||
if (unr==1) {
|
||||
|
||||
@ -98,6 +98,118 @@ def shift_first{c:V=[l]_, p:V} = {
|
||||
else blend_first{c, rotate_right{p}}
|
||||
}
|
||||
|
||||
|
||||
# Strided scans
|
||||
fn scan_stride_assoc{op, T, Ret, check_over}(xv:*void, rv:*void, ia:usz, l:usz) : Ret = {
|
||||
def minvalue{(f64)} = -1/0; def maxvalue{(f64)} = 1/0
|
||||
def id = match (op) {
|
||||
{(min)} => maxvalue; {(max)} => minvalue
|
||||
{(+)} => ({_}=>0)
|
||||
}
|
||||
x:= *T~~xv; r:= *T~~rv
|
||||
# Architecture determination
|
||||
# Use largest vector width with a full-width shuffle
|
||||
def has_shuf = hasarch{'SSSE3'} or hasarch{'AARCH64'}
|
||||
def I = if (hasarch{'AVX2'} and T>=i32) [8]i32 else [16]i8
|
||||
def [il]IE = I; def selI = shuf{IE, ...}
|
||||
def wT = width{T}
|
||||
def f = wT/width{IE}
|
||||
def vl = width{I}/wT
|
||||
def V = [vl]T
|
||||
if (has_shuf and l < vl) {
|
||||
# Small stride: power-of-two shifts
|
||||
def small{k} = {
|
||||
iv:= iota{I}; j:= I**cast_i{IE,l*f}
|
||||
spr:= I**il - j + iv
|
||||
def inds = @collect (k) {
|
||||
v:= iv - (j &~ I~~(iv<j))
|
||||
spr = selI{spr, v}
|
||||
js:= j; j+= j
|
||||
if (not same{op, +}) selI{., v}
|
||||
else if (same{IE,i8}) selI{., iv - js}
|
||||
else { m:= V~~(iv >= js); {x} => selI{x, v} & m }
|
||||
}
|
||||
c:= V**id{T}
|
||||
@for_masked{vl} (x in tup{V, x}, r in tup{V, r}, M in 'm' over ia) {
|
||||
xs:= fold{{v, i} => op{i{v}, v}, x, inds}
|
||||
r = c = op{shuf{IE, c, spr}, xs}
|
||||
check_over{M, x, r} # For +, infers other argument as r-x
|
||||
}
|
||||
}
|
||||
if (not (same{op,+} and V==[4]f64)) {
|
||||
def max_k = lb{vl/2} # Divide by two from assuming l≥2
|
||||
if (max_k<3 or l<4) small{max_k} else small{max_k-1} # l=2 and l=3 are the only cases needing the full max_k iterations; max_k<3 limits specialization to where it's significant
|
||||
} else { # Non-associative!
|
||||
c:= V**0
|
||||
if (l==2) {
|
||||
@for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) {
|
||||
a:= c + shuf{x, 0,1,0,1}
|
||||
c = a + shuf{x, 2,3,2,3}
|
||||
r = blend{a, c, 0,0,1,1}
|
||||
}
|
||||
} else {
|
||||
assert{l==3}
|
||||
@for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) {
|
||||
a:= shuf{c, 1,1,2,3} + blend{x, V**0, 0,1,1,1}
|
||||
r = c = x + shuf{a, 1,2,3,0}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
# Large stride: single shift, with saved register or memory
|
||||
def op_chk{M, p, x} = { r:= op{p, x}; check_over{M, p, x, r}; r }
|
||||
@for (r, x over l) r = x
|
||||
if (has_shuf and l<256/(wT/8)) {
|
||||
# Make sure to load the previous row data at the same alignment to not hit bad store-to-load forwarding
|
||||
def [il]IE = I
|
||||
q:= l%vl; fq:= cast_i{IE, q*f}
|
||||
def rot = shuf{IE, ., (iota{I} - I**fq) & I**(il-1)}
|
||||
bv:= iota{I} >= I**fq; def bl = blend_hom{..., bv}
|
||||
c:= V**id{T}
|
||||
o:= l - q
|
||||
if (l == 2*vl) { o = vl; bv = ~bv }
|
||||
if (o == vl) {
|
||||
p:= load{*V~~x}; store{*V~~r, 0, p}
|
||||
@for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o}, M in 'm' over ia-o) {
|
||||
p = rot{p}
|
||||
r = op_chk{M, bl{c, p}, x}
|
||||
c = p; p = r
|
||||
}
|
||||
} else {
|
||||
@for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o}, p in tup{V, r}, M in 'm' over ia-o) {
|
||||
q:= rot{p}
|
||||
r = op_chk{M, bl{c, q}, x}
|
||||
c = q
|
||||
}
|
||||
}
|
||||
} else if (same{op, +} and T<=i32 and has_simd and (has_shuf or l>=vl)) {
|
||||
def vl = arch_defvw/wT; def V = [vl]T
|
||||
@for_masked{vl} (x in tup{V, x+l}, r in tup{V, r+l}, p in tup{V, r}, M in 'm' over ia-l) {
|
||||
r = op_chk{M, p, x}
|
||||
}
|
||||
} else {
|
||||
@for (r, x, p in r-l over _ from l to ia) r = op_chk{0, p, x}
|
||||
}
|
||||
}
|
||||
1
|
||||
}
|
||||
def scan_stride_assoc{op, T} = scan_stride_assoc{op, T, void, {..._}=>{}}
|
||||
def check_add_over{_, w:T, x:T, r:T} = { if ((w^r) & (x^r) < 0) return{0} }
|
||||
def check_add_over{M, w:V=[_]E, x:V, r:V} = {
|
||||
o:= (if (not hasarch{'X86_64'} or width{E}<=16) any_hom{M, subs{r,w} != x}
|
||||
else any_top{M, (w^r) & (x^r)})
|
||||
if (o) return{0}
|
||||
}
|
||||
def check_add_over{M, x, r} = check_add_over{M, r-x, x, r}
|
||||
export_tab{'si_scan_stride_minmax',
|
||||
flat_table{scan_stride_assoc, tup{min,max}, tup{i8,i16,i32,f64}}
|
||||
}
|
||||
export_tab{'si_scan_stride_add', tup{
|
||||
...each{scan_stride_assoc{+, ., u1, check_add_over}, tup{i8,i16,i32}},
|
||||
scan_stride_assoc{+, f64, u1, {..._}=>{}}
|
||||
}}
|
||||
|
||||
|
||||
# xor scan
|
||||
def vec_prefix_byshift{op, sh} = {
|
||||
def pre{v:V, k} = if (k < elwidth{V}) pre{op{v, sh{v,k}}, 2*k} else v
|
||||
@ -106,13 +218,13 @@ def vec_prefix_byshift{op, sh} = {
|
||||
def scan_word_ne = prefix_byshift{^, <<}
|
||||
def scan_words_ne = vec_prefix_byshift{^, <<}
|
||||
|
||||
fn scan_neq{}(c:u64, x:*u64, r:*u64, nw:u64) : void = {
|
||||
fn scan_neq{}(c:u64, x:*u64, r:*u64, nw:usz) : void = {
|
||||
@for (x, r over nw) {
|
||||
r = c ^ scan_word_ne{x}
|
||||
c = -(r>>63) # repeat sign bit
|
||||
}
|
||||
}
|
||||
fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:u64) : void = {
|
||||
fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:usz) : void = {
|
||||
def vl = arch_defvw / 64
|
||||
def V = [vl]u64
|
||||
c := V**c0
|
||||
@ -123,7 +235,7 @@ fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:u64) : void = {
|
||||
c = broadcast_last{p}
|
||||
}
|
||||
}
|
||||
fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = {
|
||||
fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:usz, mark:u64) : void = {
|
||||
def V = [2]u64
|
||||
m := V**mark
|
||||
def xor64{a, i, carry} = { # carry is 64-bit broadcasted current total
|
||||
@ -144,10 +256,10 @@ fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64
|
||||
store{*u64~~(rv+e), clmul{load{V, *u64~~(xv+e), 1}, m, 0} ^ c, 1}
|
||||
}
|
||||
}
|
||||
fn scan_neq{if hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = {
|
||||
fn scan_neq{if hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:usz) : void = {
|
||||
clmul_scan_ne_any{}(*void~~x, *void~~r, init, nw, -(u64~~1))
|
||||
}
|
||||
fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = {
|
||||
fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u64, nw:usz) : void = {
|
||||
def V = [8]u64
|
||||
def sse{a} = make{[2]u64, a, 0}
|
||||
carry := sse{init}
|
||||
@ -358,10 +470,11 @@ def loose_mask_gen{V=[vl]T, l} = { # Slow, for ≠` only
|
||||
}
|
||||
def has_vecshift = hasarch{'AVX2'} or hasarch{'AARCH64'}
|
||||
def loose_mask_gen{V=[vl](u64), l if has_vecshift} = {
|
||||
l64 := promote{u64, l}
|
||||
q := -make{V, 64*iota{vl}} # distance to next row boundary
|
||||
def q_mod{} = { q+= V**l & -(q>>63) }
|
||||
def q_mod{if hasarch{'SSE4.1'}} = { q = blend_top{q,q+V**l, q} }
|
||||
o:u64 = width{V}; while (o>l) { o-=l; q_mod{} }
|
||||
def q_mod{} = { q+= V**l64 & -(q>>63) }
|
||||
def q_mod{if hasarch{'SSE4.1'}} = { q = blend_top{q,q+V**l64, q} }
|
||||
o:u64 = width{V}; while (o>l64) { o-=l64; q_mod{} }
|
||||
{} => {
|
||||
m:= V**1 << q; if (not hasarch{'AVX2'}) m&= q < V**64
|
||||
q-= V**o; q_mod{}
|
||||
@ -560,7 +673,7 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = {
|
||||
c:u64 = 0 # carry
|
||||
while (1) {
|
||||
i+= l; ii := iw; iw = cdiv{i, 64}
|
||||
scan_neq{}(c, x+ii, r+ii, promote{u64,iw-ii})
|
||||
scan_neq{}(c, x+ii, r+ii, iw-ii)
|
||||
if (i == nl) return{}
|
||||
s:= load{r, iw-1}
|
||||
q := i%64
|
||||
@ -619,3 +732,49 @@ export{'si_scan_rows_and', scan_rows_andor{0}}
|
||||
export{'si_scan_rows_or', scan_rows_andor{1}}
|
||||
export{'si_scan_rows_ne', scan_rows_neq}
|
||||
export{'si_scan_rows_ltack', scan_rows_left}
|
||||
|
||||
|
||||
# Strided boolean scans
|
||||
fn scan_stride_bool_assoc{op}(x:*u64, r:*u64, nl:usz, l:usz) : void = {
|
||||
assert{l > 1}
|
||||
def {flip,opf} = if (same{op, &}) tup{~,|} else tup{{x}=>x,op} # such that identity of opf is 0
|
||||
nw:= cdiv{nl, 64}
|
||||
if (l <= 64) {
|
||||
if (same{op, ^} and hasarch{'PCLMUL'} and (l & (l-1)) == 0) {
|
||||
clmul_scan_ne_any{}(*void~~x, *void~~r, 0, nw, aligned_spaced_mask{l})
|
||||
return{}
|
||||
}
|
||||
c:u64 = 0 # carry l bits, no matter the alignment
|
||||
@for (r, x over nw) {
|
||||
c = opf{flip{x}, c >> (64-l)}
|
||||
s:= l; while (s < 64) { c = opf{c, c<<s}; s += s }
|
||||
r = flip{c}
|
||||
}
|
||||
} else if (l < 128) {
|
||||
q:= l%64
|
||||
c:= flip{u64~~0}
|
||||
p:= load{x,0}
|
||||
store{r,0, p}
|
||||
@for (r, x over _ from 1 to nw) {
|
||||
r = op{x, c>>(64-q) | p<<q}
|
||||
c = p; p = r
|
||||
}
|
||||
} else {
|
||||
fw:= cdiv{l, 64} # words in first cell
|
||||
@for (r, x over fw) r = x
|
||||
if (l%64 == 0) {
|
||||
@for (r, x, p in r-fw over _ from fw to nw) r = op{x, p}
|
||||
} else {
|
||||
q:= l%64
|
||||
c:= flip{u64~~0}
|
||||
@for (r, x, p in r-(fw-1) over _ from fw-1 to nw) {
|
||||
r = op{x, c>>(64-q) | p<<q}
|
||||
c = p
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
export_tab{
|
||||
'si_scan_bool_stride',
|
||||
each{scan_stride_bool_assoc, tup{|, &, ^}}
|
||||
}
|
||||
|
||||
2
test/cases/fuzz/scan.bqn
Normal file
2
test/cases/fuzz/scan.bqn
Normal file
@ -0,0 +1,2 @@
|
||||
# tests scan on a matrix checking for overflow in tail elements
|
||||
8‿16‿32 {e𝕊n‿w: m←2⋆e-2 ⋄ (m⊸+⌾(⟨¯1,w|n⟩⊸⊑) ⌽‿w⥊(n+w+1)⥊m) ≡ +` •internal.Squeeze ↑‿w⥊ (w⥊m)∾(n⥊0)∾m}⌜ 1+↕200‿65
|
||||
@ -498,6 +498,9 @@ a←↕2 ⋄ ! "e" ≡ (↕10){b←a‿a‿a‿a‿a‿a‿a‿a‿a‿a ⋄
|
||||
!"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (2‿2 ≡ ≢𝕨, 3‿2‿3 ≡ ≢𝕩)" % (2‿2⥊1)+`↕3‿2‿3
|
||||
!"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (2‿2 ≡ ≢𝕨, 3‿3‿2 ≡ ≢𝕩)" % (2‿2⥊1)+`↕3‿3‿2
|
||||
!"𝕨𝔽`𝕩: Shape of 𝕨 must match the cell of 𝕩 (⟨⟩ ≡ ≢𝕨, 3‿3 ≡ ≢𝕩)" % 2+`↕3‿3
|
||||
%USE eqvar ⋄ %USE k ⋄ ⟨+,-,×,÷,⋆,√,⌊,⌈,¬,∧,∨,<,>,≠,=,≤,≥,⊣,⊢⟩ {f𝕊arr: ! ( F _k`_k ⥊arr) ≡ ⥊_k F` _eqvar arr}⌜ {𝕩∾≍˘¨𝕩} ≍˘¨ ⟨⋈1, 1‿0‿1⟩
|
||||
%USE eqvar ⋄ %USE k ⋄ ⟨+,-,×,÷,⋆,√,⌊,⌈,¬,∧,∨,<,>,≠,=,≤,≥,⊣,⊢⟩ {f𝕊arr: ! (1 F _k`_k ⥊arr) ≡ ⥊_k (1¨⊏𝕩)⊸(F`) _eqvar arr}⌜ {𝕩∾≍˘¨𝕩} ≍˘¨ ⟨⋈1, 1‿0‿1⟩
|
||||
|
||||
|
||||
# ´
|
||||
!"𝔽´𝕩: 𝕩 must be a list (⟨⟩ ≡ ≢𝕩)" % +´0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user