SSE min/max and pluswrap scans

This commit is contained in:
Marshall Lochbaum 2023-08-09 20:51:46 -04:00
parent f9a4a5b68c
commit e261e80168
3 changed files with 30 additions and 16 deletions

View File

@ -11,7 +11,7 @@
#define RDX_SUM_2(T) GRADE_UD(c1[0]=0;,) T s0=0, s1=0; for(usz j=0;j<256;j++) { RDX_PRE(0); RDX_PRE(1); }
#define RDX_SUM_4(T) GRADE_UD(c1[0]=c2[0]=c3[0]=0;,) T s0=0, s1=0, s2=0, s3=0; for(usz j=0;j<256;j++) { RDX_PRE(0); RDX_PRE(1); RDX_PRE(2); RDX_PRE(3); }
#if SINGELI_AVX2
#if SINGELI_X86_64
extern void (*const si_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
extern void (*const si_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
#define RADIX_SUM_1_u8 si_scan_pluswrap_u8 (c0,c0, 256,0);
@ -29,7 +29,7 @@ extern void (*const si_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,
#define RADIX_SUM_4_u32 RDX_SUM_4(u32)
#endif
#if SINGELI_AVX2 && !USZ_64
#if SINGELI_X86_64 && !USZ_64
#define RADIX_SUM_1_usz si_scan_pluswrap_u32(c0,c0, 256,0);
#define RADIX_SUM_2_usz si_scan_pluswrap_u32(c0,c0,2*256,0);
#define RADIX_SUM_4_usz si_scan_pluswrap_u32(c0,c0,4*256,0);

View File

@ -10,8 +10,8 @@ fn scan_scal{T, op}(x:*T, r:*T, len:u64, m:T) : void = {
@for (x, r over len) r = m = op{m, x}
}
def sel8{v, t} = sel{[16]u8, v, make{[32]i8, t}}
def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
def sel8{v:V, t} = sel{[16]u8, v, make{re_el{i8,V}, t}}
def sel8{v:V, t & w256{V} & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,n}}
@ -23,13 +23,18 @@ def spread{a:VT} = {
}
# Set all elements with the last element of the input
def toLast{n:VT} = {
def toLast{n:VT & hasarch{'X86_64'} & w128{VT}} = {
def l{v, w} = l{zipHi{v,v}, 2*w}
def l{v, w & w>=32} = shuf{[4]i32, v, 4**3}
l{n, elwidth{VT}}
}
def toLast{n:VT & hasarch{'AVX2'} & w256{VT}} = {
if (elwidth{VT}<=32) sel{[8]i32, spread{n}, [8]i32**7}
else shuf{[4]u64, n, 4b3333}
else shuf{[4]u64, n, 4**3}
}
def scan_loop{T, init, x:*T, r:*T, len:u64, scan, scan_last} = {
def step = 256/width{T}
def step = arch_defvw/width{T}
def V = [step]T
p:= V**init
xv:= *V ~~ x
@ -51,7 +56,7 @@ def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
# Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈
def scan_idem = scan_scal
fn scan_idem{T, op & hasarch{'AVX2'}}(x:*T, r:*T, len:u64, init:T) : void = {
fn scan_idem{T, op & hasarch{'SSE4.1'}}(x:*T, r:*T, len:u64, init:T) : void = {
# Within each lane, scan using shifts by powers of 2. First k elements
# when shifting by k don't need to change, so leave them alone.
def w = width{T}
@ -62,14 +67,19 @@ fn scan_idem{T, op & hasarch{'AVX2'}}(x:*T, r:*T, len:u64, init:T) : void = {
def pre{a} = {
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
# After lanewise scan, broadcast end of lane 0 to entire lane 1
op{b, sel{[8]i32, spread{b}, make{[8]i32, 3*(3<iota{8})}}}
if (not hasarch{'AVX2'}) b
else op{b, sel{[8]i32, spread{b}, make{[8]i32, 3*(3<iota{8})}}}
}
scan_post{T, init, x, r, len, op, pre}
}
fn scan_idem{T==f64, op & hasarch{'AVX2'}}(x:*T, r:*T, len:u64, init:T) : void = {
def sh{s, a} = op{a, shuf{[4]u64, a, s}}
scan_post{T, init, x, r, len, op, {a}=>sh{4b1110,sh{4b2200,a}}}
fn scan_idem{T==f64, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = {
def sc{a} = op{a, zipLo{a,a}}
def sc{a & hasarch{'AVX2'}} = {
def sh{s, a} = op{a, shuf{[4]u64, a, s}}
sh{4b1110,sh{4b2200,a}}
}
scan_post{T, init, x, r, len, op, sc}
}
export{'si_scan_min_init_i8', scan_idem{i8 , min}}; export{'si_scan_max_init_i8', scan_idem{i8 , max}}
@ -90,15 +100,17 @@ def scan_assoc{op, a:T} = {
def w = elwidth{T}
def c32{k, a} = (if (w<=8*k) op{a, shl{[16]u8, a, k}}; else a)
b:= c32{8, c32{4, c32{2, c32{1, a}}}}
# After lanewise scan, broadcast end of lane 0 to entire lane 1
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
if (not hasarch{'AVX2'}) b else {
# After lanewise scan, broadcast end of lane 0 to entire lane 1
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
}
}
def scan_plus = scan_assoc{+, .}
# Associative scan
def scan_assoc_0 = scan_scal
fn scan_assoc_0{T, op & hasarch{'AVX2'}}(x:*T, r:*T, len:u64, init:T) : void = {
fn scan_assoc_0{T, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = {
# Prefix op on entire AVX register
scan_post{T, init, x, r, len, op, scan_plus}
}

View File

@ -189,6 +189,8 @@ def packQ{a:T,b:T & w128i{T}} = packs{a,b}
def zipLo{a:T, b:T & w128i{T}} = emit{T, merge{'_mm_unpacklo_epi',fmtnat{elwidth{T}}}, a, b}
def zipHi{a:T, b:T & w128i{T}} = emit{T, merge{'_mm_unpackhi_epi',fmtnat{elwidth{T}}}, a, b}
def zipLo{a:T, b:T & w128f{T}} = emit{T, merge{'_mm_unpacklo_p',if (elwidth{T}==32) 's' else 'd'}, a, b}
def zipHi{a:T, b:T & w128f{T}} = emit{T, merge{'_mm_unpackhi_p',if (elwidth{T}==32) 's' else 'd'}, a, b}
def zip{a:T, b:T & w128i{T}} = tup{zipLo{a,b}, zipHi{a,b}}