unify toLast and broadcast_last

This commit is contained in:
dzaima 2025-03-02 02:03:28 +02:00
parent 9826c4ce0e
commit 445007b38c
3 changed files with 9 additions and 9 deletions

View File

@ -34,7 +34,7 @@ fn max_scan{T, up}(x:*T, len:u64) : void = {
p := V**0 p := V**0
@for_dir{up} (v in *V~~x over len/vl) { @for_dir{up} (v in *V~~x over len/vl) {
v = op{pre{v}, p} v = op{pre{v}, p}
p = toLast{v, up} p = broadcast_last{v, up}
} }
} else { } else {
m:T=0; @for_dir{up} (x over len) { if (x > m) m = x; x = m } m:T=0; @for_dir{up} (x over len) { if (x > m) m = x; x = m }

View File

@ -31,7 +31,7 @@ def get_scan_last{op, pre} = {
def last{v, p} = op{pre{v}, p} def last{v, p} = op{pre{v}, p}
def scan{v, p} = { def scan{v, p} = {
n:= last{v, p} n:= last{v, p}
p = toLast{n} p = broadcast_last{n}
n n
} }
tup{scan, last} tup{scan, last}
@ -91,8 +91,7 @@ export{'si_scan_pluswrap_u16', scan_assoc_0{u16, +}}
export{'si_scan_pluswrap_u32', scan_assoc_0{u32, +}} export{'si_scan_pluswrap_u32', scan_assoc_0{u32, +}}
def rotate_right{x:[l]_} = shuf{x, (iota{l}-1)%l} def rotate_right{x:[l]_} = shuf{x, (iota{l}-1)%l}
def broadcast_last{x:[l]_} = shuf{x, l**(l-1)}
def broadcast_last{x:[l]_ if hasarch{'AARCH64'}} = broadcast_sel{x, l-1}
def blend_first{x:V=[l]_, y:V} = blend{x, y, 0 < iota{l}} def blend_first{x:V=[l]_, y:V} = blend{x, y, 0 < iota{l}}
def shift_first{c:V=[l]_, p:V} = { def shift_first{c:V=[l]_, p:V} = {
if (l==2) zip{c, p, 0} if (l==2) zip{c, p, 0}
@ -315,10 +314,10 @@ def simd_plus_scan_part{x:*X, c:R, r:*R, len:(u64), i:(u64)} = {
def s0 = each{scan_plus, cx} def s0 = each{scan_plus, cx}
def s1{v0} = tup{v0} def s1{v0} = tup{v0}
def s1{v0,v1} = tup{v0,v1+toLast{v0}} def s1{v0,v1} = tup{v0,v1+broadcast_last{v0}}
def cr = eachx{+, widenFull{R, s1{...s0}}, cv} def cr = eachx{+, widenFull{R, s1{...s0}}, cv}
cv = toLast{select{cr, -1}} cv = broadcast_last{select{cr, -1}}
assert{type{cv} == one_type{cr}} assert{type{cv} == one_type{cr}}
assert{vcount{type{cv}} * length{cr} == bulk} assert{vcount{type{cv}} * length{cr} == bulk}

View File

@ -18,18 +18,19 @@ def spread{a:[_]T, ...up} = {
} }
# Set all elements with the last element of the input # Set all elements with the last element of the input
def toLast{n:VT, up if has_simd and w128{VT}} = { def broadcast_last{n:VT, up if has_simd and w128{VT}} = {
def l{v, w} = l{zip{up,v}, 2*w} def l{v, w} = l{zip{up,v}, 2*w}
def l{v, w if has_sel8} = sel8{v, up*(16-w/8)+iota{16}%(w/8)} def l{v, w if has_sel8} = sel8{v, up*(16-w/8)+iota{16}%(w/8)}
def l{v, w==32} = shuf{[4]i32, v, 4**(up*3)} def l{v, w==32} = shuf{[4]i32, v, 4**(up*3)}
def l{v, w==64} = shuf{[2]i64, v, 2** up } def l{v, w==64} = shuf{[2]i64, v, 2** up }
l{n, elwidth{VT}} l{n, elwidth{VT}}
} }
def toLast{n:VT, up if hasarch{'AVX2'} and w256{VT}} = { def broadcast_last{n:VT, up if hasarch{'AVX2'} and w256{VT}} = {
if (elwidth{VT}<=32) sel{[8]i32, spread{n,up}, [8]i32**(up*7)} if (elwidth{VT}<=32) sel{[8]i32, spread{n,up}, [8]i32**(up*7)}
else shuf{[4]u64, n, 4**(up*3)} else shuf{[4]u64, n, 4**(up*3)}
} }
def toLast{n:VT} = toLast{n, 1} def broadcast_last{n:[k]_, up if hasarch{'AARCH64'}} = broadcast_sel{n, if (up) k-1 else 0}
def broadcast_last{n:VT} = broadcast_last{n, 1}
# Make prefix scan from op and shifter by applying the operation # Make prefix scan from op and shifter by applying the operation
# at increasing power-of-two shifts # at increasing power-of-two shifts