Unify power-of-two shift pattern for scans

This commit is contained in:
Marshall Lochbaum 2023-08-10 17:04:55 -04:00
parent d16ba6c3b1
commit a040a14744

View File

@ -55,24 +55,26 @@ def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
scan_loop{T, init, x, r, len, scan, last} scan_loop{T, init, x, r, len, scan, last}
} }
# Make prefix scan from op and shifter by applying the operation
# at increasing power-of-two shifts
def prefix_byshift{op, sh} = {
def pre{v:V, k} = if (k < width{V}) pre{op{v, sh{v,k}}, 2*k} else v
{v:T} => pre{v, if (isvec{T}) elwidth{T} else 1}
}
# Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈ # Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈
def scan_idem = scan_scal def scan_idem = scan_scal
fn scan_idem{T, op & hasarch{'SSE4.1'}}(x:*T, r:*T, len:u64, init:T) : void = { fn scan_idem{T, op & hasarch{'SSE4.1'}}(x:*T, r:*T, len:u64, init:T) : void = {
# Within each lane, scan using shifts by powers of 2. First k elements # Within each lane, scan using shifts by powers of 2. First k elements
# when shifting by k don't need to change, so leave them alone. # when shifting by k don't need to change, so leave them alone.
def w = width{T}
def shift{k,l} = merge{iota{k},iota{l-k}} def shift{k,l} = merge{iota{k},iota{l-k}}
def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}} def shb{v, k} = sel8{v, shift{k/8,16}}
def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a) def shb{v, k & k>=32} = shuf{[4]u32, v, shift{k/32,4}}
# Prefix op on entire AVX register def shb{v, k & k==128 & hasarch{'AVX2'}} = {
def pre{a} = {
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
# After lanewise scan, broadcast end of lane 0 to entire lane 1 # After lanewise scan, broadcast end of lane 0 to entire lane 1
if (not hasarch{'AVX2'}) b sel{[8]i32, spread{v}, make{[8]i32, 3*(3<iota{8})}}
else op{b, sel{[8]i32, spread{b}, make{[8]i32, 3*(3<iota{8})}}}
} }
scan_post{T, init, x, r, len, op, prefix_byshift{op, shb}}
scan_post{T, init, x, r, len, op, pre}
} }
fn scan_idem{T==f64, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = { fn scan_idem{T==f64, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = {
def sc{a} = op{a, zipLo{a,a}} def sc{a} = op{a, zipLo{a,a}}
@ -96,18 +98,16 @@ export{'si_scan_min_i16', scan_idem_id{i16, min}}; export{'si_scan_max_i16', sca
export{'si_scan_min_i32', scan_idem_id{i32, min}}; export{'si_scan_max_i32', scan_idem_id{i32, max}} export{'si_scan_min_i32', scan_idem_id{i32, min}}; export{'si_scan_max_i32', scan_idem_id{i32, max}}
# Assumes identity is 0 # Assumes identity is 0
def scan_assoc{op, a:T} = { def scan_assoc{op} = {
# Within each lane, scan using shifts by powers of 2 def shl0{v, k} = shl{[16]u8, v, k/8} # Lanewise
def w = elwidth{T} def shl0{v:V, k==128 & hasarch{'AVX2'}} = {
def c32{k, a} = (if (w<=8*k) op{a, shl{[16]u8, a, k}}; else a) # Broadcast end of lane 0 to entire lane 1
b:= c32{8, c32{4, c32{2, c32{1, a}}}} l:= V~~make{[8]i32,0,0,0,-1,0,0,0,0} & spread{v}
if (not hasarch{'AVX2'}) b else { sel{[8]i32, l, make{[8]i32, 3*(3<iota{8})}}
# After lanewise scan, broadcast end of lane 0 to entire lane 1
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
} }
prefix_byshift{op, shl0}
} }
def scan_plus = scan_assoc{+, .} def scan_plus = scan_assoc{+}
# Associative scan # Associative scan
def scan_assoc_0 = scan_scal def scan_assoc_0 = scan_scal
@ -122,8 +122,7 @@ export{'si_scan_pluswrap_u32', scan_assoc_0{u32, +}}
# xor scan # xor scan
fn scan_neq{}(p:u64, x:*u64, r:*u64, nw:u64) : void = { fn scan_neq{}(p:u64, x:*u64, r:*u64, nw:u64) : void = {
@for (x, r over nw) { @for (x, r over nw) {
def sc{v, k} = if (k==64) v else sc{v ^ (v<<k), 2*k} r = p ^ prefix_byshift{^, <<}{x}
r = p ^ sc{x, 1}
p = -(r>>63) # repeat sign bit p = -(r>>63) # repeat sign bit
} }
} }