From c85ca66daee864ab3871f19d5896613e36139fab Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Sat, 26 Aug 2023 08:20:30 -0400 Subject: [PATCH] Unify Scan and Bins min-scan code --- src/singeli/src/bins.singeli | 36 +++++---------- src/singeli/src/scan.singeli | 58 +----------------------- src/singeli/src/scan_common.singeli | 69 +++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 82 deletions(-) create mode 100644 src/singeli/src/scan_common.singeli diff --git a/src/singeli/src/bins.singeli b/src/singeli/src/bins.singeli index 450a044b..b4af2a70 100644 --- a/src/singeli/src/bins.singeli +++ b/src/singeli/src/bins.singeli @@ -4,6 +4,8 @@ if (hasarch{'AVX2'}) { include './sse' include './avx' include './avx2' +} else if (hasarch{'X86_64'}) { + include './sse' } include './mask' include 'util/tup' @@ -27,36 +29,18 @@ def shr16{v:V, n} = V~~(re_el{u16, v} >> n) # Forward or backwards in-place max-scan # Assumes a whole number of vectors and minimum 0 +include './scan_common' fn max_scan{T, up}(x:*T, len:u64) : void = { - def w = width{T} - if (hasarch{'AVX2'} and T!=u64) { + if (hasarch{'X86_64'}) { def op = max - # TODO unify with scan.singeli avx2_scan_idem - def rev{a} = if (up) a else (tuplen{a}-1)-reverse{a} - def maker{T, l} = make{T, rev{l}} - def sel8{v, t} = sel{[16]u8, v, maker{[32]i8, t}} - def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}} - def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,rev{n}}} - def spread{a:VT} = { - def w = elwidth{VT} - def b = w/8 - if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a - } - def shift{k,l} = merge{iota{k},iota{l-k}} - def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}} - def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a) - def pre{a} = { - b:= c8{2, c8{1, c32{2, c32{1, a}}}} - op{b, sel{[8]i32, spread{b}, maker{[8]i32, 3*(3 m) m = x; x = m } } diff --git a/src/singeli/src/scan.singeli b/src/singeli/src/scan.singeli index 78da11ca..773e8675 100644 --- a/src/singeli/src/scan.singeli +++ b/src/singeli/src/scan.singeli @@ -5,36 +5,13 @@ include './avx' include './avx2' include './mask' include './f64' +include './scan_common' # Initialized scan, generic implementation fn scan_scal{T, op}(x:*T, r:*T, len:u64, m:T) : void = { @for (x, r over len) r = m = op{m, x} } -def sel8{v:V, t} = sel{[16]u8, v, make{re_el{i8,V}, t}} -def sel8{v:V, t & w256{V} & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}} - -def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,n}} - -# Fill last 4 bytes with last element, in each lane -def spread{a:VT} = { - def w = elwidth{VT} - def b = w/8 - if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a -} - -# Set all elements with the last element of the input -def toLast{n:VT & hasarch{'X86_64'} & w128{VT}} = { - def l{v, w} = l{zipHi{v,v}, 2*w} - def l{v, w & hasarch{'SSSE3'}} = sel8{v, (16-w/8)+iota{16}%(w/8)} - def l{v, w & w>=32} = shuf{[4]i32, v, 4**3} - l{n, elwidth{VT}} -} -def toLast{n:VT & hasarch{'AVX2'} & w256{VT}} = { - if (elwidth{VT}<=32) sel{[8]i32, spread{n}, [8]i32**7} - else shuf{[4]u64, n, 4**3} -} - def scan_loop{T, init, x:*T, r:*T, len:u64, scan, scan_last} = { def step = arch_defvw/width{T} def V = [step]T @@ -56,41 +33,10 @@ def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = { scan_loop{T, init, x, r, len, scan, last} } -# Make prefix scan from op and shifter by applying the operation -# at increasing power-of-two shifts -def prefix_byshift{op, sh} = { - def pre{v:V, k} = if (k < width{V}) pre{op{v, sh{v,k}}, 2*k} else v - {v:T} => pre{v, if (isvec{T}) elwidth{T} else 1} -} - -def get_id{op,T} = (match (op) { {_==min}=>maxvalue; {_==max}=>minvalue }){T} - # Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈ def scan_idem = scan_scal fn scan_idem{T, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = { - # Within each lane, scan using shifts by powers of 2. First k elements - # when shifting by k don't need to change, so leave them alone. - def shift{k,l} = merge{iota{k},iota{l-k}} - def shb{v:V, k} = { - def w=width{T}; def c = k/w - def id = make{V, merge{c**get_id{op,T}, (width{V}/w-c)**0}} - shl{[16]u8, v, k/8} | id - } - def shb{v, k & hasarch{'SSSE3'}} = sel8{v, shift{k/8,16}} - def shb{v, k & k>=32} = shuf{[4]u32, v, shift{k/32,4}} - def shb{v, k & k==128 & hasarch{'AVX2'}} = { - # After lanewise scan, broadcast end of lane 0 to entire lane 1 - sel{[8]i32, spread{v}, make{[8]i32, 3*(3=32} = shuf{[4]i32, v, 4**(up*3)} + l{n, elwidth{VT}} +} +def toLast{n:VT, up & hasarch{'AVX2'} & w256{VT}} = { + if (elwidth{VT}<=32) sel{[8]i32, spread{n,up}, [8]i32**(up*7)} + else shuf{[4]u64, n, 4**(up*3)} +} +def toLast{n:VT} = toLast{n, 1} + +# Make prefix scan from op and shifter by applying the operation +# at increasing power-of-two shifts +def prefix_byshift{op, sh} = { + def pre{v:V, k} = if (k < width{V}) pre{op{v, sh{v,k}}, 2*k} else v + {v:T} => pre{v, if (isvec{T}) elwidth{T} else 1} +} + +def get_id{op,T} = (match (op) { {_==min}=>maxvalue; {_==max}=>minvalue }){T} + +def make_scan_idem{T, op, up} = { + # Within each lane, scan using shifts by powers of 2. First k elements + # when shifting by k don't need to change, so leave them alone. + def shift{k,l} = rev{up, merge{iota{k},iota{l-k}}} + def shb{v:V, k} = { + def w=width{T}; def c = k/w + def merger{a,b} = if (up) merge{a,b} else merge{b,a} + def id = make{V, merger{c**get_id{op,T}, (width{V}/w-c)**0}} + (if (up) shl else shr){[16]u8, v, k/8} | id + } + def shb{v, k & hasarch{'SSSE3'}} = sel8{v, shift{k/8,16}} + def shb{v, k & k>=32} = shuf{[4]u32, v, shift{k/32,4}} + def shb{v, k & k==128 & hasarch{'AVX2'}} = { + # After lanewise scan, broadcast end of lane 0 to entire lane 1 + sel{[8]i32, spread{v,up}, make{[8]i32, rev{up,3*(3