Unify Scan and Bins min-scan code

This commit is contained in:
Marshall Lochbaum 2023-08-26 08:20:30 -04:00
parent 0a30fb309d
commit c85ca66dae
3 changed files with 81 additions and 82 deletions

View File

@ -4,6 +4,8 @@ if (hasarch{'AVX2'}) {
include './sse'
include './avx'
include './avx2'
} else if (hasarch{'X86_64'}) {
include './sse'
}
include './mask'
include 'util/tup'
@ -27,36 +29,18 @@ def shr16{v:V, n} = V~~(re_el{u16, v} >> n)
# Forward or backwards in-place max-scan
# Assumes a whole number of vectors and minimum 0
include './scan_common'
fn max_scan{T, up}(x:*T, len:u64) : void = {
def w = width{T}
if (hasarch{'AVX2'} and T!=u64) {
if (hasarch{'X86_64'}) {
def op = max
# TODO unify with scan.singeli avx2_scan_idem
def rev{a} = if (up) a else (tuplen{a}-1)-reverse{a}
def maker{T, l} = make{T, rev{l}}
def sel8{v, t} = sel{[16]u8, v, maker{[32]i8, t}}
def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,rev{n}}}
def spread{a:VT} = {
def w = elwidth{VT}
def b = w/8
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
}
def shift{k,l} = merge{iota{k},iota{l-k}}
def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}}
def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a)
def pre{a} = {
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
op{b, sel{[8]i32, spread{b}, maker{[8]i32, 3*(3<iota{8})}}}
}
def toLast{n:VT} = {
if (elwidth{VT}<=32) sel{[8]i32, spread{n}, [8]i32**(up*7)}
else shuf{[4]u64, n, up*4b3333}
}
def vl = 256/w
def pre = make_scan_idem{T, op, up}
def vl = (if (hasarch{'AVX2'}) 256 else 128)/width{T}
def V = [vl]T
p := V**0
@for_dir{up} (v in *V~~x over len/vl) { v = op{pre{v}, p}; p = toLast{v} }
@for_dir{up} (v in *V~~x over len/vl) {
v = op{pre{v}, p}
p = toLast{v, up}
}
} else {
m:T=0; @for_dir{up} (x over len) { if (x > m) m = x; x = m }
}

View File

@ -5,36 +5,13 @@ include './avx'
include './avx2'
include './mask'
include './f64'
include './scan_common'
# Initialized scan, generic implementation
fn scan_scal{T, op}(x:*T, r:*T, len:u64, m:T) : void = {
@for (x, r over len) r = m = op{m, x}
}
def sel8{v:V, t} = sel{[16]u8, v, make{re_el{i8,V}, t}}
def sel8{v:V, t & w256{V} & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,n}}
# Fill last 4 bytes with last element, in each lane
def spread{a:VT} = {
def w = elwidth{VT}
def b = w/8
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
}
# Set all elements with the last element of the input
def toLast{n:VT & hasarch{'X86_64'} & w128{VT}} = {
def l{v, w} = l{zipHi{v,v}, 2*w}
def l{v, w & hasarch{'SSSE3'}} = sel8{v, (16-w/8)+iota{16}%(w/8)}
def l{v, w & w>=32} = shuf{[4]i32, v, 4**3}
l{n, elwidth{VT}}
}
def toLast{n:VT & hasarch{'AVX2'} & w256{VT}} = {
if (elwidth{VT}<=32) sel{[8]i32, spread{n}, [8]i32**7}
else shuf{[4]u64, n, 4**3}
}
def scan_loop{T, init, x:*T, r:*T, len:u64, scan, scan_last} = {
def step = arch_defvw/width{T}
def V = [step]T
@ -56,41 +33,10 @@ def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
scan_loop{T, init, x, r, len, scan, last}
}
# Make prefix scan from op and shifter by applying the operation
# at increasing power-of-two shifts
def prefix_byshift{op, sh} = {
def pre{v:V, k} = if (k < width{V}) pre{op{v, sh{v,k}}, 2*k} else v
{v:T} => pre{v, if (isvec{T}) elwidth{T} else 1}
}
def get_id{op,T} = (match (op) { {_==min}=>maxvalue; {_==max}=>minvalue }){T}
# Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈
def scan_idem = scan_scal
fn scan_idem{T, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = {
# Within each lane, scan using shifts by powers of 2. First k elements
# when shifting by k don't need to change, so leave them alone.
def shift{k,l} = merge{iota{k},iota{l-k}}
def shb{v:V, k} = {
def w=width{T}; def c = k/w
def id = make{V, merge{c**get_id{op,T}, (width{V}/w-c)**0}}
shl{[16]u8, v, k/8} | id
}
def shb{v, k & hasarch{'SSSE3'}} = sel8{v, shift{k/8,16}}
def shb{v, k & k>=32} = shuf{[4]u32, v, shift{k/32,4}}
def shb{v, k & k==128 & hasarch{'AVX2'}} = {
# After lanewise scan, broadcast end of lane 0 to entire lane 1
sel{[8]i32, spread{v}, make{[8]i32, 3*(3<iota{8})}}
}
scan_post{T, init, x, r, len, op, prefix_byshift{op, shb}}
}
fn scan_idem{T==f64, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = {
def sc{a} = op{a, zipLo{a,a}}
def sc{a & hasarch{'AVX2'}} = {
def sh{s, a} = op{a, shuf{[4]u64, a, s}}
sh{4b1110,sh{4b2200,a}}
}
scan_post{T, init, x, r, len, op, sc}
scan_post{T, init, x, r, len, op, make_scan_idem{T, op}}
}
export{'si_scan_min_init_i8', scan_idem{i8 , min}}; export{'si_scan_max_init_i8', scan_idem{i8 , max}}

View File

@ -0,0 +1,69 @@
# Used by scan.singeli and bins.singeli
def sel8{v:V, t} = sel{[16]u8, v, make{re_el{i8,V}, t}}
def sel8{v:V, t & w256{V} & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,n}}
local def rev{t} = { def l=tuplen{t}; def j=l-1; tupsel{j-range{l}, j-t} }
local def rev{up,t} = if (up) t else rev{t}
def sel8{v, t, up} = sel8{v, rev{up,t}}
def zip{up, x} = (if (up) zipHi else zipLo){x,x}
# Fill last 4 bytes with last element, in each lane
def spread{a:VT, ...up} = {
def w = elwidth{VT}
def b = w/8
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}, ...up}; else a
}
# Set all elements with the last element of the input
def toLast{n:VT, up & hasarch{'X86_64'} & w128{VT}} = {
def l{v, w} = l{zip{up,v}, 2*w}
def l{v, w & hasarch{'SSSE3'}} = sel8{v, up*(16-w/8)+iota{16}%(w/8)}
def l{v, w & w>=32} = shuf{[4]i32, v, 4**(up*3)}
l{n, elwidth{VT}}
}
def toLast{n:VT, up & hasarch{'AVX2'} & w256{VT}} = {
if (elwidth{VT}<=32) sel{[8]i32, spread{n,up}, [8]i32**(up*7)}
else shuf{[4]u64, n, 4**(up*3)}
}
def toLast{n:VT} = toLast{n, 1}
# Make prefix scan from op and shifter by applying the operation
# at increasing power-of-two shifts
def prefix_byshift{op, sh} = {
def pre{v:V, k} = if (k < width{V}) pre{op{v, sh{v,k}}, 2*k} else v
{v:T} => pre{v, if (isvec{T}) elwidth{T} else 1}
}
def get_id{op,T} = (match (op) { {_==min}=>maxvalue; {_==max}=>minvalue }){T}
def make_scan_idem{T, op, up} = {
# Within each lane, scan using shifts by powers of 2. First k elements
# when shifting by k don't need to change, so leave them alone.
def shift{k,l} = rev{up, merge{iota{k},iota{l-k}}}
def shb{v:V, k} = {
def w=width{T}; def c = k/w
def merger{a,b} = if (up) merge{a,b} else merge{b,a}
def id = make{V, merger{c**get_id{op,T}, (width{V}/w-c)**0}}
(if (up) shl else shr){[16]u8, v, k/8} | id
}
def shb{v, k & hasarch{'SSSE3'}} = sel8{v, shift{k/8,16}}
def shb{v, k & k>=32} = shuf{[4]u32, v, shift{k/32,4}}
def shb{v, k & k==128 & hasarch{'AVX2'}} = {
# After lanewise scan, broadcast end of lane 0 to entire lane 1
sel{[8]i32, spread{v,up}, make{[8]i32, rev{up,3*(3<iota{8})}}}
}
prefix_byshift{op, shb}
}
def make_scan_idem{T==f64, op, up} = {
def sc{a} = op{a, zip{up,a}}
def sc{a & hasarch{'AVX2'}} = {
def sh{s, a} = op{a, shuf{[4]u64, a, rev{up,s}}}
sh{tup{0,1,1,1},sh{tup{0,0,2,2},a}}
}
sc
}
def make_scan_idem{T, op} = make_scan_idem{T, op, 1}