Unify power-of-two shift pattern for scans

This commit is contained in:
Marshall Lochbaum 2023-08-10 17:04:55 -04:00
parent d16ba6c3b1
commit a040a14744

View File

@ -55,24 +55,26 @@ def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
scan_loop{T, init, x, r, len, scan, last}
}
# Make prefix scan from op and shifter by applying the operation
# at increasing power-of-two shifts
def prefix_byshift{op, sh} = {
def pre{v:V, k} = if (k < width{V}) pre{op{v, sh{v,k}}, 2*k} else v
{v:T} => pre{v, if (isvec{T}) elwidth{T} else 1}
}
# Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈
def scan_idem = scan_scal
fn scan_idem{T, op & hasarch{'SSE4.1'}}(x:*T, r:*T, len:u64, init:T) : void = {
# Within each lane, scan using shifts by powers of 2. First k elements
# when shifting by k don't need to change, so leave them alone.
def w = width{T}
def shift{k,l} = merge{iota{k},iota{l-k}}
def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}}
def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a)
# Prefix op on entire AVX register
def pre{a} = {
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
def shb{v, k} = sel8{v, shift{k/8,16}}
def shb{v, k & k>=32} = shuf{[4]u32, v, shift{k/32,4}}
def shb{v, k & k==128 & hasarch{'AVX2'}} = {
# After lanewise scan, broadcast end of lane 0 to entire lane 1
if (not hasarch{'AVX2'}) b
else op{b, sel{[8]i32, spread{b}, make{[8]i32, 3*(3<iota{8})}}}
sel{[8]i32, spread{v}, make{[8]i32, 3*(3<iota{8})}}
}
scan_post{T, init, x, r, len, op, pre}
scan_post{T, init, x, r, len, op, prefix_byshift{op, shb}}
}
fn scan_idem{T==f64, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = {
def sc{a} = op{a, zipLo{a,a}}
@ -96,18 +98,16 @@ export{'si_scan_min_i16', scan_idem_id{i16, min}}; export{'si_scan_max_i16', sca
export{'si_scan_min_i32', scan_idem_id{i32, min}}; export{'si_scan_max_i32', scan_idem_id{i32, max}}
# Assumes identity is 0
def scan_assoc{op, a:T} = {
# Within each lane, scan using shifts by powers of 2
def w = elwidth{T}
def c32{k, a} = (if (w<=8*k) op{a, shl{[16]u8, a, k}}; else a)
b:= c32{8, c32{4, c32{2, c32{1, a}}}}
if (not hasarch{'AVX2'}) b else {
# After lanewise scan, broadcast end of lane 0 to entire lane 1
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
def scan_assoc{op} = {
def shl0{v, k} = shl{[16]u8, v, k/8} # Lanewise
def shl0{v:V, k==128 & hasarch{'AVX2'}} = {
# Broadcast end of lane 0 to entire lane 1
l:= V~~make{[8]i32,0,0,0,-1,0,0,0,0} & spread{v}
sel{[8]i32, l, make{[8]i32, 3*(3<iota{8})}}
}
prefix_byshift{op, shl0}
}
def scan_plus = scan_assoc{+, .}
def scan_plus = scan_assoc{+}
# Associative scan
def scan_assoc_0 = scan_scal
@ -122,8 +122,7 @@ export{'si_scan_pluswrap_u32', scan_assoc_0{u32, +}}
# xor scan
fn scan_neq{}(p:u64, x:*u64, r:*u64, nw:u64) : void = {
@for (x, r over nw) {
def sc{v, k} = if (k==64) v else sc{v ^ (v<<k), 2*k}
r = p ^ sc{x, 1}
r = p ^ prefix_byshift{^, <<}{x}
p = -(r>>63) # repeat sign bit
}
}