From 0a30fb309d18ac77135886bdbf6dee77689ccb24 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Sat, 12 Aug 2023 09:59:42 -0400 Subject: [PATCH] SSE2 min- and max-scans --- src/singeli/src/scan.singeli | 14 +++++++++++--- src/singeli/src/sse.singeli | 8 ++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/singeli/src/scan.singeli b/src/singeli/src/scan.singeli index 566eabca..78da11ca 100644 --- a/src/singeli/src/scan.singeli +++ b/src/singeli/src/scan.singeli @@ -26,6 +26,7 @@ def spread{a:VT} = { # Set all elements with the last element of the input def toLast{n:VT & hasarch{'X86_64'} & w128{VT}} = { def l{v, w} = l{zipHi{v,v}, 2*w} + def l{v, w & hasarch{'SSSE3'}} = sel8{v, (16-w/8)+iota{16}%(w/8)} def l{v, w & w>=32} = shuf{[4]i32, v, 4**3} l{n, elwidth{VT}} } @@ -62,13 +63,20 @@ def prefix_byshift{op, sh} = { {v:T} => pre{v, if (isvec{T}) elwidth{T} else 1} } +def get_id{op,T} = (match (op) { {_==min}=>maxvalue; {_==max}=>minvalue }){T} + # Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈ def scan_idem = scan_scal -fn scan_idem{T, op & hasarch{'SSE4.1'}}(x:*T, r:*T, len:u64, init:T) : void = { +fn scan_idem{T, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = { # Within each lane, scan using shifts by powers of 2. First k elements # when shifting by k don't need to change, so leave them alone. def shift{k,l} = merge{iota{k},iota{l-k}} - def shb{v, k} = sel8{v, shift{k/8,16}} + def shb{v:V, k} = { + def w=width{T}; def c = k/w + def id = make{V, merge{c**get_id{op,T}, (width{V}/w-c)**0}} + shl{[16]u8, v, k/8} | id + } + def shb{v, k & hasarch{'SSSE3'}} = sel8{v, shift{k/8,16}} def shb{v, k & k>=32} = shuf{[4]u32, v, shift{k/32,4}} def shb{v, k & k==128 & hasarch{'AVX2'}} = { # After lanewise scan, broadcast end of lane 0 to entire lane 1 @@ -91,7 +99,7 @@ export{'si_scan_min_init_i32', scan_idem{i32, min}}; export{'si_scan_max_init_i3 export{'si_scan_min_init_f64', scan_idem{f64, min}}; export{'si_scan_max_init_f64', scan_idem{f64, max}} fn scan_idem_id{T, op}(x:*T, r:*T, len:u64) : void = { - scan_idem{T, op}(x, r, len, (if (same{op,min}) maxvalue else minvalue){T}) + scan_idem{T, op}(x, r, len, get_id{op, T}) } export{'si_scan_min_i8', scan_idem_id{i8 , min}}; export{'si_scan_max_i8', scan_idem_id{i8 , max}} export{'si_scan_min_i16', scan_idem_id{i16, min}}; export{'si_scan_max_i16', scan_idem_id{i16, max}} diff --git a/src/singeli/src/sse.singeli b/src/singeli/src/sse.singeli index b04b7a43..145ae2e5 100644 --- a/src/singeli/src/sse.singeli +++ b/src/singeli/src/sse.singeli @@ -13,10 +13,10 @@ def extract{x:T, i & w128i{T,64} & knum{i}} = emit{eltype{T}, '_mm_extract_epi64 def andAllZero{x:T, y:T & w128i{T}} = emit{u1, '_mm_testz_si128', x, y} # arith -def min{a:T,b:T & T==[16]i8 } = emit{T, '_mm_min_epi8', a, b}; def max{a:T,b:T & T==[16]i8 } = emit{T, '_mm_max_epi8', a, b} -def min{a:T,b:T & T==[ 4]i32} = emit{T, '_mm_min_epi32', a, b}; def max{a:T,b:T & T==[ 4]i32} = emit{T, '_mm_max_epi32', a, b} -def min{a:T,b:T & T==[ 8]u16} = emit{T, '_mm_min_epu16', a, b}; def max{a:T,b:T & T==[ 8]u16} = emit{T, '_mm_max_epu16', a, b} -def min{a:T,b:T & T==[ 4]u32} = emit{T, '_mm_min_epu32', a, b}; def max{a:T,b:T & T==[ 4]u32} = emit{T, '_mm_max_epu32', a, b} +def min{a:T,b:T & T==[16]i8 & hasarch{'SSE4.1'}} = emit{T, '_mm_min_epi8', a, b}; def max{a:T,b:T & T==[16]i8 & hasarch{'SSE4.1'}} = emit{T, '_mm_max_epi8', a, b} +def min{a:T,b:T & T==[ 4]i32 & hasarch{'SSE4.1'}} = emit{T, '_mm_min_epi32', a, b}; def max{a:T,b:T & T==[ 4]i32 & hasarch{'SSE4.1'}} = emit{T, '_mm_max_epi32', a, b} +def min{a:T,b:T & T==[ 8]u16 & hasarch{'SSE4.1'}} = emit{T, '_mm_min_epu16', a, b}; def max{a:T,b:T & T==[ 8]u16 & hasarch{'SSE4.1'}} = emit{T, '_mm_max_epu16', a, b} +def min{a:T,b:T & T==[ 4]u32 & hasarch{'SSE4.1'}} = emit{T, '_mm_min_epu32', a, b}; def max{a:T,b:T & T==[ 4]u32 & hasarch{'SSE4.1'}} = emit{T, '_mm_max_epu32', a, b} def __le{a:T,b:T & w128u{T}} = a==min{a,b} def __ge{a:T,b:T & w128u{T}} = a==max{a,b}