diff --git a/src/builtins/scan.c b/src/builtins/scan.c index 128aa726..cbc3f74d 100644 --- a/src/builtins/scan.c +++ b/src/builtins/scan.c @@ -1,22 +1,31 @@ // Scan (`) // Empty 𝕩, and length 1 if no 𝕨: return 𝕩 -// Generic operand: +// Generic argument: // Constant: copy // ⊢ identity, ⊣ reshape 𝕨 or first cell -// Boolean operand, rank 1: +// Boolean argument, stride 1: // + AVX2 expansion (SHOULD have better generic, add SSE, NEON) // ∨⌈ ∧×⌊ search+copy, then memset (COULD vectorize search) // ≠ SWAR/SIMD shifts, CLMUL, VPCLMUL (SHOULD add NEON polynomial mul) // < SWAR // =≤≥>- in terms of ≠<∨∧+ with adjustments -// Arithmetic operand, rank 1: +// Numeric argument, stride 1: // ⌈⌊ Scalar, SIMD in log(vector width) steps // Check in 6-vector blocks to quickly write result if constant // + Overflow-checked scalar or AVX2 // Ad-hoc boolean-valued handling for ≠∨ -// SHOULD extend rank 1 special cases to cell bound 1 -// Higher-rank arithmetic, non-tiny cells: apply operand cell-wise -// SHOULD have dedicated high-rank scan optimizations +// Higher-rank arithmetic: +// Boolean ≠∨∧ and synonyms: SWAR; ⌊⌈+: SIMD with shuffle/permute +// Stride ,MIN,or ,asc) } #define MM2_ICASE(T,N,C,I) \ case el_##T : { \ if (wv!=(T)wv) { if (wv C 0) { r=C2(shape,m_f64(ia),w); break; } else wv=I; } \ - T* xp=T##any_ptr(x); T* rp; r=m_##T##arrv(&rp, ia); MINMAX_SCAN(T,N,C,wv); \ + T* xp=T##any_ptr(x); T* rp; r=m_##T##arrc(&rp, x); MINMAX_SCAN(T,N,C,wv); \ break; } #define MINMAX2(NAME,C,INIT,BIT,BI,ORD) \ i32 wv=0; if (q_i32(w)) { wv=o2fG(w); } else { x=taga(cpyF64Arr(x)); xe=el_f64; } \ @@ -195,7 +204,7 @@ SHOULD_INLINE B scan2_max_num(B w, B x, u8 xe, usz ia) { MINMAX2(max,>,MIN,or ,0 static B scan_lt(B x, u64 p, usz ia) { u64* xp = bitany_ptr(x); - u64* rp; B r=m_bitarrv(&rp,ia); usz n=BIT_N(ia); + u64* rp; B r=m_bitarrc(&rp,x); usz n=BIT_N(ia); u64 m = 0x5555555555555555; for (usz i=0; if; - if (isAtm(x) || RNK(x)==0) thrM("𝔽`𝕩: 𝕩 cannot have rank 0"); - ur xr = RNK(x); - usz ia = IA(x); - if (*SH(x)<=1 || ia==0) return x; + if (isAtm(x)) { unit: thrM("𝔽`𝕩: 𝕩 cannot have rank 0"); } + usz ia = IA(x); if (ia <= 1) { if (ia==1 && RNK(x)==0) goto unit; return x; } + usz n = *SH(x); if (n <= 1) return x; if (RARE(!isFun(f))) { if (isMd(f)) thrM("Calling a modifier"); B xf = getFillR(x); @@ -257,7 +265,50 @@ B scan_c1(Md1D* d, B x) { B f = d->f; Arr* r = TI(x,slice)(x, 0, csz); return C2(shape, s, taga(r)); } - if (!(xr==1 && xe<=el_f64)) goto base; + if (xe > el_f64) goto base; + if (ia != n) { // csz != 1 + #if SINGELI + usz csz = arr_csz(x); + i8 t = -1; bool neg = 0; + if (xe==el_bit) switch (rtid) { + CASE_N_OR: t=0; break; + CASE_N_AND: t=1; break; + case n_eq: neg=1; case n_ne: t=2; break; + } + if (t != -1) { + if (neg) x = bit_negate(x); + u64* rp; B r=m_bitarrc(&rp,x); + si_scan_bool_stride[t](bitany_ptr(x), rp, ia, csz); + if (neg) r = bit_negate(r); + decG(x); return r; + } + if (rtid==n_floor | rtid==n_ceil) { + // boolean was handled as CASE_N_AND + B r; void* rp = m_tyarrc(&r, elWidth(xe), x, el2t(xe)); + void* xp = tyany_ptr(x); + si_scan_stride_minmax[4*(rtid==n_ceil) + xe-el_i8](xp, rp, ia, csz); + decG(x); return r; + } + if (rtid==n_add) { + if (xe==el_bit) { x = taga(cpyI8Arr(x)); xe=el_i8; } + restart:; + B r; void* rp = m_tyarrc(&r, elWidth(xe), x, el2t(xe)); + void* xp = tyany_ptr(x); + bool done = si_scan_stride_add[xe-el_i8](xp, rp, ia, csz); + if (!done) { + decG(r); + switch (++xe) { default: UD; + case el_i16: x = taga(cpyI16Arr(x)); break; + case el_i32: x = taga(cpyI32Arr(x)); break; + case el_f64: x = taga(cpyF64Arr(x)); break; + } + goto restart; + } + decG(x); return r; + } + #endif + goto base; + } if (xe==el_bit) switch (rtid) { default: goto base; case n_add: return scan_add_bool(x, ia); // + @@ -278,7 +329,7 @@ B scan_c1(Md1D* d, B x) { B f = d->f; if (!elInt(xe)) goto base; f64 x0 = o2fG(IGetU(x,0)); if (!q_fbit(x0)) goto base; - u64* rp; B r = m_bitarrv(&rp, ia); + u64* rp; B r = m_bitarrc(&rp, x); bool c = x0; rp[0] = c; if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=1; if; if (rtid==n_or) { x=num_squeezeChk(x); xe=TI(x,elType); if (xe==el_bit) return scan_or(x, ia); } } base:; - if (xr>1 && ia >= 6 * (u64)*SH(x) && isPervasiveDy(f)) return scan_arith(f, m_f64(0), x, SH(x)); + if (ia!=n && ia >= 6 * (u64)n && isPervasiveDy(f)) return scan_arith(f, m_f64(0), x, SH(x)); SLOW2("𝕎` 𝕩", f, x); B xf = getFillR(x); @@ -297,7 +348,7 @@ B scan_c1(Md1D* d, B x) { B f = d->f; SGet(x) FC2 fc2 = c2fn(f); - if (xr==1) { + if (ia == n) { r.a[0] = Get(x,0); for (usz i=1; if; u8 rtid = RTID(f); if (rtid==n_rtack) { dec(w); return x; } if (rtid==n_ltack) return C2(shape, C1(fne, x), w); - if (!(xr==1 && elNum(xe) && xe<=el_f64)) goto base; + if (!(elNum(xe) && xe<=el_f64)) goto base; + if (xr!=1 && *SH(x)!=ia) goto base; if (!isF64(w)) goto base; if (rtid==n_floor) return scan2_min_num(w, x, xe, ia); // ⌊ @@ -350,7 +402,7 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f; if (xe==el_bit) return scan_ne(x, -(u64)(wBit? o2bG(w) : 1&~*bitany_ptr(x)), ia); if (!wBit || !elInt(xe)) goto base; bool c = o2bG(w); - u64* rp; B r = m_bitarrv(&rp, ia); + u64* rp; B r = m_bitarrc(&rp, x); if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=0; i maxvalue; {(max)} => minvalue + {(+)} => ({_}=>0) + } + x:= *T~~xv; r:= *T~~rv + # Architecture determination + # Use largest vector width with a full-width shuffle + def has_shuf = hasarch{'SSSE3'} or hasarch{'AARCH64'} + def I = if (hasarch{'AVX2'} and T>=i32) [8]i32 else [16]i8 + def [il]IE = I; def selI = shuf{IE, ...} + def wT = width{T} + def f = wT/width{IE} + def vl = width{I}/wT + def V = [vl]T + if (has_shuf and l < vl) { + # Small stride: power-of-two shifts + def small{k} = { + iv:= iota{I}; j:= I**cast_i{IE,l*f} + spr:= I**il - j + iv + def inds = @collect (k) { + v:= iv - (j &~ I~~(iv= js); {x} => selI{x, v} & m } + } + c:= V**id{T} + @for_masked{vl} (x in tup{V, x}, r in tup{V, r}, M in 'm' over ia) { + xs:= fold{{v, i} => op{i{v}, v}, x, inds} + r = c = op{shuf{IE, c, spr}, xs} + check_over{M, x, r} # For +, infers other argument as r-x + } + } + if (not (same{op,+} and V==[4]f64)) { + def max_k = lb{vl/2} # Divide by two from assuming l≥2 + if (max_k<3 or l<4) small{max_k} else small{max_k-1} # l=2 and l=3 are the only cases needing the full max_k iterations; max_k<3 limits specialization to where it's significant + } else { # Non-associative! + c:= V**0 + if (l==2) { + @for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) { + a:= c + shuf{x, 0,1,0,1} + c = a + shuf{x, 2,3,2,3} + r = blend{a, c, 0,0,1,1} + } + } else { + assert{l==3} + @for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) { + a:= shuf{c, 1,1,2,3} + blend{x, V**0, 0,1,1,1} + r = c = x + shuf{a, 1,2,3,0} + } + } + } + } else { + # Large stride: single shift, with saved register or memory + def op_chk{M, p, x} = { r:= op{p, x}; check_over{M, p, x, r}; r } + @for (r, x over l) r = x + if (has_shuf and l<256/(wT/8)) { + # Make sure to load the previous row data at the same alignment to not hit bad store-to-load forwarding + def [il]IE = I + q:= l%vl; fq:= cast_i{IE, q*f} + def rot = shuf{IE, ., (iota{I} - I**fq) & I**(il-1)} + bv:= iota{I} >= I**fq; def bl = blend_hom{..., bv} + c:= V**id{T} + o:= l - q + if (l == 2*vl) { o = vl; bv = ~bv } + if (o == vl) { + p:= load{*V~~x}; store{*V~~r, 0, p} + @for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o}, M in 'm' over ia-o) { + p = rot{p} + r = op_chk{M, bl{c, p}, x} + c = p; p = r + } + } else { + @for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o}, p in tup{V, r}, M in 'm' over ia-o) { + q:= rot{p} + r = op_chk{M, bl{c, q}, x} + c = q + } + } + } else if (same{op, +} and T<=i32 and has_simd and (has_shuf or l>=vl)) { + def vl = arch_defvw/wT; def V = [vl]T + @for_masked{vl} (x in tup{V, x+l}, r in tup{V, r+l}, p in tup{V, r}, M in 'm' over ia-l) { + r = op_chk{M, p, x} + } + } else { + @for (r, x, p in r-l over _ from l to ia) r = op_chk{0, p, x} + } + } + 1 +} +def scan_stride_assoc{op, T} = scan_stride_assoc{op, T, void, {..._}=>{}} +def check_add_over{_, w:T, x:T, r:T} = { if ((w^r) & (x^r) < 0) return{0} } +def check_add_over{M, w:V=[_]E, x:V, r:V} = { + o:= (if (not hasarch{'X86_64'} or width{E}<=16) any_hom{M, subs{r,w} != x} + else any_top{M, (w^r) & (x^r)}) + if (o) return{0} +} +def check_add_over{M, x, r} = check_add_over{M, r-x, x, r} +export_tab{'si_scan_stride_minmax', + flat_table{scan_stride_assoc, tup{min,max}, tup{i8,i16,i32,f64}} +} +export_tab{'si_scan_stride_add', tup{ + ...each{scan_stride_assoc{+, ., u1, check_add_over}, tup{i8,i16,i32}}, + scan_stride_assoc{+, f64, u1, {..._}=>{}} +}} + + # xor scan def vec_prefix_byshift{op, sh} = { def pre{v:V, k} = if (k < elwidth{V}) pre{op{v, sh{v,k}}, 2*k} else v @@ -106,13 +218,13 @@ def vec_prefix_byshift{op, sh} = { def scan_word_ne = prefix_byshift{^, <<} def scan_words_ne = vec_prefix_byshift{^, <<} -fn scan_neq{}(c:u64, x:*u64, r:*u64, nw:u64) : void = { +fn scan_neq{}(c:u64, x:*u64, r:*u64, nw:usz) : void = { @for (x, r over nw) { r = c ^ scan_word_ne{x} c = -(r>>63) # repeat sign bit } } -fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:u64) : void = { +fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:usz) : void = { def vl = arch_defvw / 64 def V = [vl]u64 c := V**c0 @@ -123,7 +235,7 @@ fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:u64) : void = { c = broadcast_last{p} } } -fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = { +fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:usz, mark:u64) : void = { def V = [2]u64 m := V**mark def xor64{a, i, carry} = { # carry is 64-bit broadcasted current total @@ -144,10 +256,10 @@ fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64 store{*u64~~(rv+e), clmul{load{V, *u64~~(xv+e), 1}, m, 0} ^ c, 1} } } -fn scan_neq{if hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = { +fn scan_neq{if hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:usz) : void = { clmul_scan_ne_any{}(*void~~x, *void~~r, init, nw, -(u64~~1)) } -fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = { +fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u64, nw:usz) : void = { def V = [8]u64 def sse{a} = make{[2]u64, a, 0} carry := sse{init} @@ -358,10 +470,11 @@ def loose_mask_gen{V=[vl]T, l} = { # Slow, for ≠` only } def has_vecshift = hasarch{'AVX2'} or hasarch{'AARCH64'} def loose_mask_gen{V=[vl](u64), l if has_vecshift} = { + l64 := promote{u64, l} q := -make{V, 64*iota{vl}} # distance to next row boundary - def q_mod{} = { q+= V**l & -(q>>63) } - def q_mod{if hasarch{'SSE4.1'}} = { q = blend_top{q,q+V**l, q} } - o:u64 = width{V}; while (o>l) { o-=l; q_mod{} } + def q_mod{} = { q+= V**l64 & -(q>>63) } + def q_mod{if hasarch{'SSE4.1'}} = { q = blend_top{q,q+V**l64, q} } + o:u64 = width{V}; while (o>l64) { o-=l64; q_mod{} } {} => { m:= V**1 << q; if (not hasarch{'AVX2'}) m&= q < V**64 q-= V**o; q_mod{} @@ -560,7 +673,7 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = { c:u64 = 0 # carry while (1) { i+= l; ii := iw; iw = cdiv{i, 64} - scan_neq{}(c, x+ii, r+ii, promote{u64,iw-ii}) + scan_neq{}(c, x+ii, r+ii, iw-ii) if (i == nl) return{} s:= load{r, iw-1} q := i%64 @@ -619,3 +732,49 @@ export{'si_scan_rows_and', scan_rows_andor{0}} export{'si_scan_rows_or', scan_rows_andor{1}} export{'si_scan_rows_ne', scan_rows_neq} export{'si_scan_rows_ltack', scan_rows_left} + + +# Strided boolean scans +fn scan_stride_bool_assoc{op}(x:*u64, r:*u64, nl:usz, l:usz) : void = { + assert{l > 1} + def {flip,opf} = if (same{op, &}) tup{~,|} else tup{{x}=>x,op} # such that identity of opf is 0 + nw:= cdiv{nl, 64} + if (l <= 64) { + if (same{op, ^} and hasarch{'PCLMUL'} and (l & (l-1)) == 0) { + clmul_scan_ne_any{}(*void~~x, *void~~r, 0, nw, aligned_spaced_mask{l}) + return{} + } + c:u64 = 0 # carry l bits, no matter the alignment + @for (r, x over nw) { + c = opf{flip{x}, c >> (64-l)} + s:= l; while (s < 64) { c = opf{c, c<>(64-q) | p<>(64-q) | p<,≠,=,≤,≥,⊣,⊢⟩ {f𝕊arr: ! ( F _k`_k ⥊arr) ≡ ⥊_k F` _eqvar arr}⌜ {𝕩∾≍˘¨𝕩} ≍˘¨ ⟨⋈1, 1‿0‿1⟩ +%USE eqvar ⋄ %USE k ⋄ ⟨+,-,×,÷,⋆,√,⌊,⌈,¬,∧,∨,<,>,≠,=,≤,≥,⊣,⊢⟩ {f𝕊arr: ! (1 F _k`_k ⥊arr) ≡ ⥊_k (1¨⊏𝕩)⊸(F`) _eqvar arr}⌜ {𝕩∾≍˘¨𝕩} ≍˘¨ ⟨⋈1, 1‿0‿1⟩ + # ´ !"𝔽´𝕩: 𝕩 must be a list (⟨⟩ ≡ ≢𝕩)" % +´0