Extend 1-byte SSE2 Where to 2-byte and 4-byte with unpacked writes

This commit is contained in:
Marshall Lochbaum 2023-07-17 11:28:04 -04:00
parent 8b297ae2dc
commit fc187afdf2

View File

@ -69,7 +69,7 @@ def storeu{p:T, i, v:eltype{T} & *u64==T} = emit{void, 'storeu_u64', p+i, v}
def loadu{p:T & *u64==T} = emit{eltype{T}, 'loadu_u64', p}
# Assumes w is trimmed, so the last 1 appears at index l-1
def thresh2{T} = 2
def thresh{c, T} = 2
fn slash{c, T}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def bitp_get{arr, n} = (load{arr,n>>6} >> (n&63)) & 1
@for (i to l) {
@ -85,18 +85,21 @@ def getter{c, V, x} = {
} else {
def k = vcount{V}
i := make{V, iota{k}}
if (eltype{V}==i32) i += V**x
if (isreg{x}) i += V**cast_i{eltype{V},x}
ii := V**k
{} => { v:=i; i+=ii; v }
}
}
def thresh2{T==i8 & hasarch{'X86_64'}} = 4
fn slash{c, T==i8 & hasarch{'X86_64'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def thresh{c, T & hasarch{'X86_64'} & T<=(if (c) i8 else i32)} = 4
fn slash{c, T & hasarch{'X86_64'} & T<=(if (c) i8 else i32)}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def U = [16]u8
k1 := U**1
def X = getter{c, U, x}
@for_special_buffered{r,16} (w in *u16~~w to sum) {
def make_top{S} = to_el{S,U}**(if (T<i32) 0 else cast_i{S, x>>width{S}})
top := each{make_top, replicate{{S}=>S<T, tup{i8,i16}}}
def i_off = if (T<i32) 0 else { assert{x%16==0}; cast_i{u64, x/16} }
@for_special_buffered{r,16} (w in *u16~~w over i to sum) {
x := X{}
bm := make{U, 1<<(iota{16}%8)}
rb := make{U, replicate{8,each{bind{cast_i,u8},tup{w,w>>8}}}}
@ -116,11 +119,22 @@ fn slash{c, T==i8 & hasarch{'X86_64'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64)
dif = max{dif, shr{U, dif&m, 1<<k}}
b += b
}
each{
{ins,c} => { emit{void, ins, *[8]u8~~r, x}; r+=c },
tup{'_mm_storel_pi','_mm_storeh_pi'},
pc
def each_pc{gen, ...par} = each{{...p,c} => { gen{...p}; r+=c }, ...par, pc}
if (T==i8) { # 0==tuplen{top}
def st{ins} = emit{void, ins, *[8]u8~~r, x}
each_pc{st, tup{'_mm_storel_pi','_mm_storeh_pi'}}
} else {
def st{k, v:V} = store{*V~~r, k, v}
def st{v} = if (T==i16) st{0, v}
else each{st, iota{2}, unpack{v, tupsel{1,top}}}
each_pc{st, unpack{[16]i8~~x, tupsel{0,top}}}
}
# Increment top vector when i*16 passes width of bottom vector
def inc{{}} = {}
def inc{{t:V, ...ts}} = {
if ((i+1+i_off)%(1<<(elwidth{V}-4)) == 0) { t += V**1; inc{ts} }
}
inc{top}
}
}
@ -132,8 +146,8 @@ def tab{n,l} = if (n==0) tup{0} else {
c16lut:*u64 = tab{4,16}
def vgLoad{p:T, i & T == *u64} = emit{eltype{T}, 'vg_loadLUT64', p, i}
def thresh2{T==i8 & hasarch{'BMI2'}} = 32
def thresh2{T==i16 & hasarch{'BMI2'}} = 16
def thresh{c, T==i8 & hasarch{'BMI2'}} = 32
def thresh{c, T==i16 & hasarch{'BMI2'}} = 16
fn slash{c, T & hasarch{'BMI2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def wt = width{T}
def b = bind{base, 1<<wt}
@ -161,7 +175,7 @@ fn slash{c, T & hasarch{'BMI2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : voi
}
}
def thresh2{T==i8 & hasarch{'AVX2'}} = 32
def thresh{c, T==i8 & hasarch{'AVX2'}} = 32
fn slash{c, T==i8 & hasarch{'AVX2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def I = [32]i8
def S = [8]u32
@ -216,8 +230,8 @@ fn slash{c, T==i8 & hasarch{'AVX2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) :
itab :*u64 = fold{{t,k} => join{each{tup,t,k+(t<<8)%(1<<64)}}, tup{0x8080808080808080}, reverse{iota{8}}}
i64tab:*u32 = fold{{t,k} => join{each{tup,t,k+(t<<8)%(1<<32)}}, tup{0x80808080}, reverse{2*iota{4}}}
def thresh2{T==i32 & hasarch{'AVX2'}} = 32
def thresh2{T==i64 & hasarch{'AVX2'}} = 8
def thresh{c, T==i32 & hasarch{'AVX2'}} = 32
def thresh{c, T==i64 & hasarch{'AVX2'}} = 8
fn slash{c, T & hasarch{'AVX2'} & width{T}>=32}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def tw = width{T}
def V = [8]u32
@ -250,10 +264,10 @@ fn slash{c, T & hasarch{'AVX2'} & width{T}>=32}(wp:*u64, x:arg{c,T}, r:*T, l:u64
}
}
def thresh2{T==i8 & hasarch{'AVX512VBMI2'}} = 256
def thresh2{T==i16 & hasarch{'AVX512VBMI2'}} = 128
def thresh2{T==i32 & hasarch{'AVX512F'}} = 64
def thresh2{T==i64 & hasarch{'AVX512F'}} = 16
def thresh{c, T==i8 & hasarch{'AVX512VBMI2'}} = 256
def thresh{c, T==i16 & hasarch{'AVX512VBMI2'}} = 128
def thresh{c, T==i32 & hasarch{'AVX512F'}} = 64
def thresh{c, T==i64 & hasarch{'AVX512F'}} = 16
fn slash{c, T & hasarch{if (width{T}>=32) 'AVX512F' else 'AVX512VBMI2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def f = fmtnat
def wt = width{T}
@ -280,12 +294,12 @@ fn slash{c, T & hasarch{if (width{T}>=32) 'AVX512F' else 'AVX512VBMI2'}}(w:*u64,
}
export{'si_1slash8' , slash{0, i8 }}
export{'si_1slash16', slash{0, i16}}; export{'si_thresh_1slash16', u64~~thresh2{i16}}
export{'si_1slash32', slash{0, i32}}; export{'si_thresh_1slash32', u64~~thresh2{i32}}
export{'si_2slash8' , slash{1, i8 }}; export{'si_thresh_2slash8' , u64~~thresh2{i8 }}
export{'si_2slash16', slash{1, i16}}; export{'si_thresh_2slash16', u64~~thresh2{i16}}
export{'si_2slash32', slash{1, i32}}; export{'si_thresh_2slash32', u64~~thresh2{i32}}
export{'si_2slash64', slash{1, i64}}; export{'si_thresh_2slash64', u64~~thresh2{i64}}
export{'si_1slash16', slash{0, i16}}; export{'si_thresh_1slash16', u64~~thresh{0, i16}}
export{'si_1slash32', slash{0, i32}}; export{'si_thresh_1slash32', u64~~thresh{0, i32}}
export{'si_2slash8' , slash{1, i8 }}; export{'si_thresh_2slash8' , u64~~thresh{1, i8 }}
export{'si_2slash16', slash{1, i16}}; export{'si_thresh_2slash16', u64~~thresh{1, i16}}
export{'si_2slash32', slash{1, i32}}; export{'si_thresh_2slash32', u64~~thresh{1, i32}}
export{'si_2slash64', slash{1, i64}}; export{'si_thresh_2slash64', u64~~thresh{1, i64}}
# pext, or boolean compress
fn pext{T}(x:T, m:T) {