diff --git a/src/singeli/src/scan.singeli b/src/singeli/src/scan.singeli index 0d7fc7fd..566eabca 100644 --- a/src/singeli/src/scan.singeli +++ b/src/singeli/src/scan.singeli @@ -55,24 +55,26 @@ def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = { scan_loop{T, init, x, r, len, scan, last} } +# Make prefix scan from op and shifter by applying the operation +# at increasing power-of-two shifts +def prefix_byshift{op, sh} = { + def pre{v:V, k} = if (k < width{V}) pre{op{v, sh{v,k}}, 2*k} else v + {v:T} => pre{v, if (isvec{T}) elwidth{T} else 1} +} + # Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈ def scan_idem = scan_scal fn scan_idem{T, op & hasarch{'SSE4.1'}}(x:*T, r:*T, len:u64, init:T) : void = { # Within each lane, scan using shifts by powers of 2. First k elements # when shifting by k don't need to change, so leave them alone. - def w = width{T} def shift{k,l} = merge{iota{k},iota{l-k}} - def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}} - def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a) - # Prefix op on entire AVX register - def pre{a} = { - b:= c8{2, c8{1, c32{2, c32{1, a}}}} + def shb{v, k} = sel8{v, shift{k/8,16}} + def shb{v, k & k>=32} = shuf{[4]u32, v, shift{k/32,4}} + def shb{v, k & k==128 & hasarch{'AVX2'}} = { # After lanewise scan, broadcast end of lane 0 to entire lane 1 - if (not hasarch{'AVX2'}) b - else op{b, sel{[8]i32, spread{b}, make{[8]i32, 3*(3>63) # repeat sign bit } }