Unify power-of-two shift pattern for scans
This commit is contained in:
parent
d16ba6c3b1
commit
a040a14744
@ -55,24 +55,26 @@ def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
|
||||
scan_loop{T, init, x, r, len, scan, last}
|
||||
}
|
||||
|
||||
# Make prefix scan from op and shifter by applying the operation
|
||||
# at increasing power-of-two shifts
|
||||
def prefix_byshift{op, sh} = {
|
||||
def pre{v:V, k} = if (k < width{V}) pre{op{v, sh{v,k}}, 2*k} else v
|
||||
{v:T} => pre{v, if (isvec{T}) elwidth{T} else 1}
|
||||
}
|
||||
|
||||
# Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈
|
||||
def scan_idem = scan_scal
|
||||
fn scan_idem{T, op & hasarch{'SSE4.1'}}(x:*T, r:*T, len:u64, init:T) : void = {
|
||||
# Within each lane, scan using shifts by powers of 2. First k elements
|
||||
# when shifting by k don't need to change, so leave them alone.
|
||||
def w = width{T}
|
||||
def shift{k,l} = merge{iota{k},iota{l-k}}
|
||||
def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}}
|
||||
def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a)
|
||||
# Prefix op on entire AVX register
|
||||
def pre{a} = {
|
||||
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
|
||||
def shb{v, k} = sel8{v, shift{k/8,16}}
|
||||
def shb{v, k & k>=32} = shuf{[4]u32, v, shift{k/32,4}}
|
||||
def shb{v, k & k==128 & hasarch{'AVX2'}} = {
|
||||
# After lanewise scan, broadcast end of lane 0 to entire lane 1
|
||||
if (not hasarch{'AVX2'}) b
|
||||
else op{b, sel{[8]i32, spread{b}, make{[8]i32, 3*(3<iota{8})}}}
|
||||
sel{[8]i32, spread{v}, make{[8]i32, 3*(3<iota{8})}}
|
||||
}
|
||||
|
||||
scan_post{T, init, x, r, len, op, pre}
|
||||
scan_post{T, init, x, r, len, op, prefix_byshift{op, shb}}
|
||||
}
|
||||
fn scan_idem{T==f64, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = {
|
||||
def sc{a} = op{a, zipLo{a,a}}
|
||||
@ -96,18 +98,16 @@ export{'si_scan_min_i16', scan_idem_id{i16, min}}; export{'si_scan_max_i16', sca
|
||||
export{'si_scan_min_i32', scan_idem_id{i32, min}}; export{'si_scan_max_i32', scan_idem_id{i32, max}}
|
||||
|
||||
# Assumes identity is 0
|
||||
def scan_assoc{op, a:T} = {
|
||||
# Within each lane, scan using shifts by powers of 2
|
||||
def w = elwidth{T}
|
||||
def c32{k, a} = (if (w<=8*k) op{a, shl{[16]u8, a, k}}; else a)
|
||||
b:= c32{8, c32{4, c32{2, c32{1, a}}}}
|
||||
if (not hasarch{'AVX2'}) b else {
|
||||
# After lanewise scan, broadcast end of lane 0 to entire lane 1
|
||||
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
|
||||
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
|
||||
def scan_assoc{op} = {
|
||||
def shl0{v, k} = shl{[16]u8, v, k/8} # Lanewise
|
||||
def shl0{v:V, k==128 & hasarch{'AVX2'}} = {
|
||||
# Broadcast end of lane 0 to entire lane 1
|
||||
l:= V~~make{[8]i32,0,0,0,-1,0,0,0,0} & spread{v}
|
||||
sel{[8]i32, l, make{[8]i32, 3*(3<iota{8})}}
|
||||
}
|
||||
prefix_byshift{op, shl0}
|
||||
}
|
||||
def scan_plus = scan_assoc{+, .}
|
||||
def scan_plus = scan_assoc{+}
|
||||
|
||||
# Associative scan
|
||||
def scan_assoc_0 = scan_scal
|
||||
@ -122,8 +122,7 @@ export{'si_scan_pluswrap_u32', scan_assoc_0{u32, +}}
|
||||
# xor scan
|
||||
fn scan_neq{}(p:u64, x:*u64, r:*u64, nw:u64) : void = {
|
||||
@for (x, r over nw) {
|
||||
def sc{v, k} = if (k==64) v else sc{v ^ (v<<k), 2*k}
|
||||
r = p ^ sc{x, 1}
|
||||
r = p ^ prefix_byshift{^, <<}{x}
|
||||
p = -(r>>63) # repeat sign bit
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user