diff --git a/src/builtins/scan.c b/src/builtins/scan.c index 8b8402f6..128aa726 100644 --- a/src/builtins/scan.c +++ b/src/builtins/scan.c @@ -6,11 +6,11 @@ // Boolean operand, rank 1: // + AVX2 expansion (SHOULD have better generic, add SSE, NEON) // ∨⌈ ∧×⌊ search+copy, then memset (COULD vectorize search) -// ≠ SWAR shifts, CLMUL, VPCLMUL (SHOULD add SSE, NEON) +// ≠ SWAR/SIMD shifts, CLMUL, VPCLMUL (SHOULD add NEON polynomial mul) // < SWAR // =≤≥>- in terms of ≠<∨∧+ with adjustments // Arithmetic operand, rank 1: -// ⌈⌊ Scalar, SSE, AVX in log(vector width) steps (SHOULD add NEON) +// ⌈⌊ Scalar, SIMD in log(vector width) steps // Check in 6-vector blocks to quickly write result if constant // + Overflow-checked scalar or AVX2 // Ad-hoc boolean-valued handling for ≠∨ @@ -22,15 +22,19 @@ // SHOULD optimize dyadic scan with rank // Empty 𝕩, length 1, ⊢: return 𝕩 // Boolean operand, cell size 1: -// ≠∨∧⊣ and synonyms, rows <64: SWAR, AVX2 (SHOULD add SSE, NEON) +// ≠∨∧⊣ (and synonyms), rows <64: SWAR, SIMD // Power of two row size: autovectorized // COULD have dedicated SIMD for CPU widths, little improvement -// ⊣ SWAR for <64, select for ≥ -// ∨⌈ ∧×⌊ SWAR with addition for small rows, search for large -// Rows 64≤l<160: SWAR specialized for ≤1 boundary -// Large rows: word-at-a-time search -// ≠ power-of-two shifts for <64, rank-1 scans and boundary corrections if ≥ -// SHOULD have a better intermediate-size (< ~256) SIMD method +// COULD get unaligned row boundaries in 4x groups with & +// ≠∨∧⊣ medium rows (upper bound varies, <320): SIMD +// Generate boundary masks with index tracking and shifts +// Scan within words, propagate carries stopping at masks +// ≠ small and medium rows uses power-of-two shifts +// COULD try CLMUL +// ≠∨∧⊣ large rows: per-row loops +// ∨∧: word-at-a-time search +// ≠: rank-1 scans and boundary corrections +// ⊣: branchless boundary plus fixed-size loop // + scan in blocks, correct with mask, ⌊`, subtract // = as ≠`⌾¬, - as (2×⊣`)-+` // SHOULD optimize non-boolean scan with rank @@ -192,11 +196,12 @@ SHOULD_INLINE B scan2_max_num(B w, B x, u8 xe, usz ia) { MINMAX2(max,>,MIN,or ,0 static B scan_lt(B x, u64 p, usz ia) { u64* xp = bitany_ptr(x); u64* rp; B r=m_bitarrv(&rp,ia); usz n=BIT_N(ia); - u64 m10 = 0x5555555555555555; + u64 m = 0x5555555555555555; for (usz i=0; i>63); - rp[i] = p = x & (m10 ^ (x + c)); + u64 u = -(p>>63) &~ (x+1); + u64 c = ((x<<1) | m) - x; + rp[i] = p = x & (m ^ c ^ u); } decG(x); return r; } diff --git a/src/singeli/src/base.singeli b/src/singeli/src/base.singeli index da2419e4..215ce648 100644 --- a/src/singeli/src/base.singeli +++ b/src/singeli/src/base.singeli @@ -177,7 +177,7 @@ def { all_hom,any_hom,blend_hom,hom_to_int,store_masked_hom,store_blended_hom, all_top,any_top,blend_top,top_to_int,store_masked_top,store_blended_top, load_expand_bits,make,mask_to_hom,mulw_split,mulh,narrow,narrow_trunc,narrow_pair, - pair,pdep,pext,rbit,sel,shuf_ind,reverse_units, + pair,pdep,pext,rbit,sel,shuf_ind,reverse_units,broadcast_sel, unord,unzip,vfold,vec_select,vec_shuffle,widen,widen_upper,multishift, } diff --git a/src/singeli/src/scan.singeli b/src/singeli/src/scan.singeli index 55be386c..3cff784c 100644 --- a/src/singeli/src/scan.singeli +++ b/src/singeli/src/scan.singeli @@ -6,6 +6,10 @@ include './mask' include './f64' include './spaced' include './scan_common' +if_inline (hasarch{'AARCH64'}) { + def __shl{a:V=[_]T, b:U if not isvec{U}} = a << V**cast_i{T,b} + def __shr{a:V=[_]T, b:U if not isvec{U}} = a << V**cast_i{T,-b} +} # Initialized scan, generic implementation fn scan_scal{T, op}(x:*T, r:*T, len:u64, m:T) : void = { @@ -35,7 +39,7 @@ def get_scan_last{op, pre} = { # Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈ def scan_idem = scan_scal -fn scan_idem{T, op if hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = { +fn scan_idem{T, op if has_simd}(x:*T, r:*T, len:u64, init:T) : void = { def {scan, last} = get_scan_last{op, make_scan_idem{T, op}} def cmp = match (op) { {(min)} => (>); {(max)} => (<) } def step = arch_defvw/width{T} @@ -78,19 +82,46 @@ def scan_plus = scan_assoc_id0{+} # Associative scan def scan_assoc_0 = scan_scal -fn scan_assoc_0{T, op if hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = { - # Prefix op on entire AVX register +fn scan_assoc_0{T, op if has_simd}(x:*T, r:*T, len:u64, init:T) : void = { + # Prefix op on entire SIMD register scan_loop{init, x, r, len, ...get_scan_last{op, scan_plus}} } export{'si_scan_pluswrap_u8', scan_assoc_0{u8 , +}} export{'si_scan_pluswrap_u16', scan_assoc_0{u16, +}} export{'si_scan_pluswrap_u32', scan_assoc_0{u32, +}} +def rotate_right{x:[l]_} = shuf{x, (iota{l}-1)%l} +def broadcast_last{x:[l]_} = shuf{x, l**(l-1)} +def broadcast_last{x:[l]_ if hasarch{'AARCH64'}} = broadcast_sel{x, l-1} +def blend_first{x:V=[l]_, y:V} = blend{x, y, 0 < iota{l}} +def shift_first{c:V=[l]_, p:V} = { + if (l==2) zip{c, p, 0} + else blend_first{c, rotate_right{p}} +} + # xor scan -fn scan_neq{}(p:u64, x:*u64, r:*u64, nw:u64) : void = { +def vec_prefix_byshift{op, sh} = { + def pre{v:V, k} = if (k < elwidth{V}) pre{op{v, sh{v,k}}, 2*k} else v + {v:T} => pre{v, 1} +} +def scan_word_ne = prefix_byshift{^, <<} +def scan_words_ne = vec_prefix_byshift{^, <<} + +fn scan_neq{}(c:u64, x:*u64, r:*u64, nw:u64) : void = { @for (x, r over nw) { - r = p ^ prefix_byshift{^, <<}{x} - p = -(r>>63) # repeat sign bit + r = c ^ scan_word_ne{x} + c = -(r>>63) # repeat sign bit + } +} +fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:u64) : void = { + def vl = arch_defvw / 64 + def V = [vl]u64 + c := V**c0 + @for_masked{vl} (x in tup{V, x}, r in tup{V, r} over nw) { + s:= scan_words_ne{x} + p:= scan_assoc_id0{^}{-(s>>63)} ^ c + r = s ^ shift_first{c, p} + c = broadcast_last{p} } } fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = { @@ -312,6 +343,33 @@ export{'si_scan_plus_i32_f64', plus_scanG{i32, f64}} # Row-wise boolean scan +# Create masks of the given type with spacing l>=64 +def loose_mask_gen{(u64), l} = { + q:usz = 0 # distance to next row boundary + {} => { + b:= q<64 # whether there's a boundary + p:= q%64 # its position + q-= 64 - (l &- b) + promote{u64, b} << p + } +} +def loose_mask_gen{V=[vl]T, l} = { # Slow, for ≠` only + def get = loose_mask_gen{T, l} + {} => make{V, @collect (vl) get{}} +} +def has_vecshift = hasarch{'AVX2'} or hasarch{'AARCH64'} +def loose_mask_gen{V=[vl](u64), l if has_vecshift} = { + q := -make{V, 64*iota{vl}} # distance to next row boundary + def q_mod{} = { q+= V**l & -(q>>63) } + def q_mod{if hasarch{'SSE4.1'}} = { q = blend_top{q,q+V**l, q} } + o:u64 = width{V}; while (o>l) { o-=l; q_mod{} } + {} => { + m:= V**1 << q; if (not hasarch{'AVX2'}) m&= q < V**64 + q-= V**o; q_mod{} + m + } +} + def loop_with_unaligned_mask{x, r, nw, l, step} = { {m, d} := unaligned_spaced_mask_mod{l} c:u64 = 0 # carry (initial value never matters) @@ -323,29 +381,68 @@ def loop_with_unaligned_mask{x, r, nw, l, step} = { m = m>>d | m<<(l-d) } } -def avx2_loop_with_unaligned_mask{xp, rp, nw, l, scan_words, apply_carry} = { +def vec_loop_with_unaligned_mask{xp, rp, nw, l, scan_words, apply_carry} = { + def vl = arch_defvw / 64 + def V = [vl]u64 {ms, d} := unaligned_spaced_mask_mod{l} - def V = [4]u64 d4:usz = width{V} % l - m:= make{V, scan{{a,_} => a>>d | a<<(l-d), tup{ms, ...iota{3}}}} + m:= make{V, scan{{a,_} => a>>d | a<<(l-d), tup{ms, ...iota{vl-1}}}} c:= V**0 - @for_masked{4} (x in tup{V, xp}, - r in tup{V, rp} over promote{u64,nw}) { + @for_masked{vl} (x in tup{V, xp}, r in tup{V, rp} over promote{u64,nw}) { s := scan_words{x, m} - pc:= c; c = shuf{-(s>>63), 3,0,1,2} - r = apply_carry{s, blend{c, pc, 1,0,0,0}, (m-V**1)&~m} + # Each result word can be modified based on top bit of previous + t := -(s>>63) + pc:= c; c = rotate_right{t} + b := (if (vl==2) zip{pc, t, 0} else blend_first{pc, c}) + # Carry applies to bits below any mask bit + r = apply_carry{s, b, (m-V**1)&~m} m = m>>d4 | m<<(l-d4) } } +def vec_loop_with_loose_mask{xp, rp, nw, l, id, scan_words, propagate, fix_carry, apply_carry} = { + assert{l >= 64} + def vl = arch_defvw / 64 + def V = [vl]u64 + def get_m = loose_mask_gen{V, l} + c := V**id # carry, 0 or 1 + @for_masked{vl} (x in tup{V, xp}, r in tup{V, rp} over nw) { + # Get mask; <=1 bit per word + m:= get_m{} + # Within-word scan and carry info + ml:= m - V**1 + {s, k}:= scan_words{x, m, ml} + # Propagate carries and adjust result + p:= propagate{k, c} + t:= shift_first{c, p} + r = apply_carry{s, -fix_carry{t}, ml} + c = broadcast_last{p} + } +} +def vec_loop_with_loose_mask{...a={xp, rp, nw, l, id, scan_words}, apply_carry} = { + def passthrough{k, c} = { + def bl{b,a} = b ^ ((b^a) & -(b>>63)) + def bl{b:B,a if hasarch{'AARCH64'}} = blend_bit{b,a, ty_s{b} < 0} + def bl{b,a if hasarch{'SSE4.1'}} = blend_top{b,a, b} + bl{make_scan_idem{f64, bl}{k}, c} # Can't be -1 now + } + vec_loop_with_loose_mask{...a, passthrough, {k}=>k, apply_carry} +} fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = { def qand = not id assert{l > 0} nw := cdiv{nl, 64} + def scan_mask{x:T, m:T} = { + if (qand) { p:= (x &~ m) >> 1; (x - p) ^ p } + else { p:= (x | m) >> 1; (p - x) ^ p } + } def res_m1{x,c,m} = { # result word with carry c, popc{m}<=1 if (qand) x &~ ((x+c) & (x+m)) else x | ((-x-c) &~ (x-m)) } + def apply_carry{s, c, f} = { + if (qand) s & (~f | c) else s | (f & c) + } if (l < 64) { if ((l & (l-1)) == 0) { if (l == 2) { @@ -354,25 +451,12 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = { } } else { m:u64 = aligned_spaced_mask{l} - t := m << (l-1) - @for (r in dst, x in src over nw) { - r = (if (qand) x &~ ((t&x) ^ ((x&~t) + m)) - else x | ~((t&~x) ^ ((x|t) - m))) - } + @for (r in dst, x in src over nw) r = scan_mask{x, m} } # could use for l>=8; not much faster and takes up space # def rowwise{T} = @for (r in *T~~dst, x in *T~~src over (64/width{T})*nw) r = x &~ (x+1) - } else if (hasarch{'AVX2'}) { - def scan_words{x, m:V} = { - mb:= m | V**1 - p:= if (qand) (x &~ m) >> 1 else ~(x | m) >> 1 - a:= if (qand) p + (mb & x) else p + (mb &~ x) - if (qand) p ^ a else ~(p ^ a) - } - def apply_carry{s, c, f} = { - if (qand) s & (~f | c) else s | (f & c) - } - avx2_loop_with_unaligned_mask{src, dst, nw, l, scan_words, apply_carry} + } else if (has_simd) { + vec_loop_with_unaligned_mask{src, dst, nw, l, scan_mask, apply_carry} } else { loop_with_unaligned_mask{src, dst, nw, l, {x, c, m} => { s:= (if (qand) (x &~ m) >> 1 else ~(x | m) >> 1 ) @@ -381,15 +465,24 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = { a >> 63} # new c }} } - } else if (l < 160) { - q:usz = 0 # distance to next row boundary - c:u64 = id # carry - @for (r in dst, x in src over nw) { - b:= q<64 # whether there's a boundary - p:= q%64 # its position - q-= 64 - (l &- b) - r = res_m1{x, c, promote{u64, b} << p} - c = r >> 63 + } else if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and l < (if (hasarch{'AVX2'}) 256 else 160)) { + if (hasarch{'SSE4.1'}) { + def scan_words{x:V, m:V, _} = { + s:= (if (qand) x &~ ((x+V**1) & (x+m)) + else x | ((-x) &~ (x-m))) + p:= (if (qand) x&~m == ~V**0 + else x| m == V**0) + k:= s>>63 | p # Carry of 0 or 1, but -1 to propagate previous + tup{s, k} + } + vec_loop_with_loose_mask{src, dst, nw, l, id, scan_words, apply_carry} + } else { + def get_m = loose_mask_gen{u64, l} + c:u64 = id # carry + @for (r in dst, x in src over nw) { + r = res_m1{x, c, get_m{}} + c = r >> 63 + } } } else { i :usz = 0 # row bit index @@ -422,29 +515,22 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = { } fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = { - def scan_word = prefix_byshift{^, <<} + def scan_word = scan_word_ne + def scan_words = scan_words_ne + def apply_carry{s, c, f} = s ^ (f & c) assert{l > 0} nw := cdiv{nl, 64} if (l < 64) { + def apply_mask{s, m} = { + b:= s<<1 & m # last bit of previous row + s ^ (b< pre{v, 1} - } - s:= vec_prefix_byshift{^, <<}{x} - b:= s<<1 & m # last bit of previous row - s ^ (b< { s:= scan_word{x} @@ -453,6 +539,22 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = { s ^ ((c & f) | (b<>63 | (V**(1<<63) &~ ml) # Top bit 1 to stop, so 0 is identity + tup{s, k} + } + def propagate{k:V=[vl]_, c:V} = { + def bl{b,a} = b ^ (a &~ -(b>>63)) + def bl{b,a if hasarch{'AVX2'}} = blend_top{a^b,b, b} + k = bl{k, vec_shift_right_128{k, 1}} + if (vl>2) k = bl{k, shuf{V, blend{V**0, k, 0,1,0,1}, 0,0,1,1}} + bl{k, c} + } + def fix_carry{t:V} = t & V**1 + vec_loop_with_loose_mask{x, r, nw, l, 0, scan_words, propagate, fix_carry, apply_carry} } else { i :usz = 0 # row bit index iw:usz = 0 # starting word @@ -471,36 +573,44 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = { } fn scan_rows_left(x:*u64, r:*u64, nl:usz, l:usz) : void = { - def scan_word = prefix_byshift{^, <<} assert{l > 0} nw := cdiv{nl, 64} + def apply_carry{s, c, f} = s | (f & c) if (l < 64) { + def apply_mask{x, m} = { b:= x & m; b< { f:= (m-1)&~m # bits before first full row - b:= x & m - (c & f) | (b<>63 | (m == V**0)} + } + vec_loop_with_loose_mask{x, r, nw, l, 0, scan_words, apply_carry} } else { - i :usz = 0 # row bit index - wn:usz = 0 # starting word of next row - c:u64 = 0 # carry + assert{l >= 64} + k:= l/64 - 1 # at least k full aligned words in a row + i :usz = 0 # row bit index + wn:usz = 0 # starting word of next row + c:u64 = 0 # carry we:= nl/64; while (wn < we) { iw:= wn m := u64~~1 << (i%64) xw:= -(load{x, iw} & m) - store{r, iw, (c & (m-1)) | xw} + r0:= (c & (m-1)) | xw c = -(xw>>63) i+= l; wn = i/64 - @for (r in r over _ from iw+1 to wn) r = c + store{r, wn-1, c} + store{r, iw, r0} + @for (r in r+iw+1 over k) r = c } if (i%64 != 0) store{r, wn, c} } diff --git a/src/singeli/src/scan_common.singeli b/src/singeli/src/scan_common.singeli index c320351d..b2f425e8 100644 --- a/src/singeli/src/scan_common.singeli +++ b/src/singeli/src/scan_common.singeli @@ -1,5 +1,6 @@ # Used by scan.singeli and bins.singeli +def has_sel8 = hasarch{'SSSE3'} or hasarch{'AARCH64'} def sel8{v:V, t} = sel{[16]u8, v, make{re_el{i8,V}, t}} def sel8{v:V, t if w256{V} and istup{t} and length{t}==16} = sel8{v, merge{t,t}} @@ -17,9 +18,9 @@ def spread{a:[_]T, ...up} = { } # Set all elements with the last element of the input -def toLast{n:VT, up if hasarch{'X86_64'} and w128{VT}} = { +def toLast{n:VT, up if has_simd and w128{VT}} = { def l{v, w} = l{zip{up,v}, 2*w} - def l{v, w if hasarch{'SSSE3'}} = sel8{v, up*(16-w/8)+iota{16}%(w/8)} + def l{v, w if has_sel8} = sel8{v, up*(16-w/8)+iota{16}%(w/8)} def l{v, w==32} = shuf{[4]i32, v, 4**(up*3)} def l{v, w==64} = shuf{[2]i64, v, 2** up } l{n, elwidth{VT}} @@ -49,7 +50,7 @@ def make_scan_idem{T, op, up} = { def id = make{V, merger{c**get_id{op,T}, (width{V}/w-c)**0}} (if (up) vec_shift_right_128 else vec_shift_left_128){v, c} | id } - def shb{v, k if hasarch{'SSSE3'}} = sel8{v, shift{k/8,16}} + def shb{v, k if has_sel8} = sel8{v, shift{k/8,16}} def shb{v, k if k>=32} = shuf{[4]u32, v, shift{k/32,4}} def shb{v, k if k==128 and hasarch{'AVX2'}} = { # After lanewise scan, broadcast end of lane 0 to entire lane 1 @@ -69,10 +70,14 @@ def make_scan_idem{T, op} = make_scan_idem{T, op, 1} def scan_assoc_id0{op} = { def shl0{v:[_]T, k} = vec_shift_right_128{v, k/width{T}} # Lanewise - def shl0{v:V, k==128 if hasarch{'AVX2'}} = { + def shl0{v:V=[_]T, k==128 if hasarch{'AVX2'}} = { # Broadcast end of lane 0 to entire lane 1 - l:= V~~make{[8]i32,0,0,0,-1,0,0,0,0} & spread{v} - sel{[8]i32, l, make{[8]i32, 3*(3