Extend 1-byte SSE2 Where to 2-byte and 4-byte with unpacked writes

This commit is contained in:
Marshall Lochbaum 2023-07-17 11:28:04 -04:00
parent 8b297ae2dc
commit fc187afdf2

View File

@ -69,7 +69,7 @@ def storeu{p:T, i, v:eltype{T} & *u64==T} = emit{void, 'storeu_u64', p+i, v}
def loadu{p:T & *u64==T} = emit{eltype{T}, 'loadu_u64', p} def loadu{p:T & *u64==T} = emit{eltype{T}, 'loadu_u64', p}
# Assumes w is trimmed, so the last 1 appears at index l-1 # Assumes w is trimmed, so the last 1 appears at index l-1
def thresh2{T} = 2 def thresh{c, T} = 2
fn slash{c, T}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { fn slash{c, T}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def bitp_get{arr, n} = (load{arr,n>>6} >> (n&63)) & 1 def bitp_get{arr, n} = (load{arr,n>>6} >> (n&63)) & 1
@for (i to l) { @for (i to l) {
@ -85,18 +85,21 @@ def getter{c, V, x} = {
} else { } else {
def k = vcount{V} def k = vcount{V}
i := make{V, iota{k}} i := make{V, iota{k}}
if (eltype{V}==i32) i += V**x if (isreg{x}) i += V**cast_i{eltype{V},x}
ii := V**k ii := V**k
{} => { v:=i; i+=ii; v } {} => { v:=i; i+=ii; v }
} }
} }
def thresh2{T==i8 & hasarch{'X86_64'}} = 4 def thresh{c, T & hasarch{'X86_64'} & T<=(if (c) i8 else i32)} = 4
fn slash{c, T==i8 & hasarch{'X86_64'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { fn slash{c, T & hasarch{'X86_64'} & T<=(if (c) i8 else i32)}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def U = [16]u8 def U = [16]u8
k1 := U**1 k1 := U**1
def X = getter{c, U, x} def X = getter{c, U, x}
@for_special_buffered{r,16} (w in *u16~~w to sum) { def make_top{S} = to_el{S,U}**(if (T<i32) 0 else cast_i{S, x>>width{S}})
top := each{make_top, replicate{{S}=>S<T, tup{i8,i16}}}
def i_off = if (T<i32) 0 else { assert{x%16==0}; cast_i{u64, x/16} }
@for_special_buffered{r,16} (w in *u16~~w over i to sum) {
x := X{} x := X{}
bm := make{U, 1<<(iota{16}%8)} bm := make{U, 1<<(iota{16}%8)}
rb := make{U, replicate{8,each{bind{cast_i,u8},tup{w,w>>8}}}} rb := make{U, replicate{8,each{bind{cast_i,u8},tup{w,w>>8}}}}
@ -116,11 +119,22 @@ fn slash{c, T==i8 & hasarch{'X86_64'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64)
dif = max{dif, shr{U, dif&m, 1<<k}} dif = max{dif, shr{U, dif&m, 1<<k}}
b += b b += b
} }
each{ def each_pc{gen, ...par} = each{{...p,c} => { gen{...p}; r+=c }, ...par, pc}
{ins,c} => { emit{void, ins, *[8]u8~~r, x}; r+=c }, if (T==i8) { # 0==tuplen{top}
tup{'_mm_storel_pi','_mm_storeh_pi'}, def st{ins} = emit{void, ins, *[8]u8~~r, x}
pc each_pc{st, tup{'_mm_storel_pi','_mm_storeh_pi'}}
} else {
def st{k, v:V} = store{*V~~r, k, v}
def st{v} = if (T==i16) st{0, v}
else each{st, iota{2}, unpack{v, tupsel{1,top}}}
each_pc{st, unpack{[16]i8~~x, tupsel{0,top}}}
} }
# Increment top vector when i*16 passes width of bottom vector
def inc{{}} = {}
def inc{{t:V, ...ts}} = {
if ((i+1+i_off)%(1<<(elwidth{V}-4)) == 0) { t += V**1; inc{ts} }
}
inc{top}
} }
} }
@ -132,8 +146,8 @@ def tab{n,l} = if (n==0) tup{0} else {
c16lut:*u64 = tab{4,16} c16lut:*u64 = tab{4,16}
def vgLoad{p:T, i & T == *u64} = emit{eltype{T}, 'vg_loadLUT64', p, i} def vgLoad{p:T, i & T == *u64} = emit{eltype{T}, 'vg_loadLUT64', p, i}
def thresh2{T==i8 & hasarch{'BMI2'}} = 32 def thresh{c, T==i8 & hasarch{'BMI2'}} = 32
def thresh2{T==i16 & hasarch{'BMI2'}} = 16 def thresh{c, T==i16 & hasarch{'BMI2'}} = 16
fn slash{c, T & hasarch{'BMI2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { fn slash{c, T & hasarch{'BMI2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def wt = width{T} def wt = width{T}
def b = bind{base, 1<<wt} def b = bind{base, 1<<wt}
@ -161,7 +175,7 @@ fn slash{c, T & hasarch{'BMI2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : voi
} }
} }
def thresh2{T==i8 & hasarch{'AVX2'}} = 32 def thresh{c, T==i8 & hasarch{'AVX2'}} = 32
fn slash{c, T==i8 & hasarch{'AVX2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { fn slash{c, T==i8 & hasarch{'AVX2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def I = [32]i8 def I = [32]i8
def S = [8]u32 def S = [8]u32
@ -216,8 +230,8 @@ fn slash{c, T==i8 & hasarch{'AVX2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) :
itab :*u64 = fold{{t,k} => join{each{tup,t,k+(t<<8)%(1<<64)}}, tup{0x8080808080808080}, reverse{iota{8}}} itab :*u64 = fold{{t,k} => join{each{tup,t,k+(t<<8)%(1<<64)}}, tup{0x8080808080808080}, reverse{iota{8}}}
i64tab:*u32 = fold{{t,k} => join{each{tup,t,k+(t<<8)%(1<<32)}}, tup{0x80808080}, reverse{2*iota{4}}} i64tab:*u32 = fold{{t,k} => join{each{tup,t,k+(t<<8)%(1<<32)}}, tup{0x80808080}, reverse{2*iota{4}}}
def thresh2{T==i32 & hasarch{'AVX2'}} = 32 def thresh{c, T==i32 & hasarch{'AVX2'}} = 32
def thresh2{T==i64 & hasarch{'AVX2'}} = 8 def thresh{c, T==i64 & hasarch{'AVX2'}} = 8
fn slash{c, T & hasarch{'AVX2'} & width{T}>=32}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { fn slash{c, T & hasarch{'AVX2'} & width{T}>=32}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def tw = width{T} def tw = width{T}
def V = [8]u32 def V = [8]u32
@ -250,10 +264,10 @@ fn slash{c, T & hasarch{'AVX2'} & width{T}>=32}(wp:*u64, x:arg{c,T}, r:*T, l:u64
} }
} }
def thresh2{T==i8 & hasarch{'AVX512VBMI2'}} = 256 def thresh{c, T==i8 & hasarch{'AVX512VBMI2'}} = 256
def thresh2{T==i16 & hasarch{'AVX512VBMI2'}} = 128 def thresh{c, T==i16 & hasarch{'AVX512VBMI2'}} = 128
def thresh2{T==i32 & hasarch{'AVX512F'}} = 64 def thresh{c, T==i32 & hasarch{'AVX512F'}} = 64
def thresh2{T==i64 & hasarch{'AVX512F'}} = 16 def thresh{c, T==i64 & hasarch{'AVX512F'}} = 16
fn slash{c, T & hasarch{if (width{T}>=32) 'AVX512F' else 'AVX512VBMI2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { fn slash{c, T & hasarch{if (width{T}>=32) 'AVX512F' else 'AVX512VBMI2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def f = fmtnat def f = fmtnat
def wt = width{T} def wt = width{T}
@ -280,12 +294,12 @@ fn slash{c, T & hasarch{if (width{T}>=32) 'AVX512F' else 'AVX512VBMI2'}}(w:*u64,
} }
export{'si_1slash8' , slash{0, i8 }} export{'si_1slash8' , slash{0, i8 }}
export{'si_1slash16', slash{0, i16}}; export{'si_thresh_1slash16', u64~~thresh2{i16}} export{'si_1slash16', slash{0, i16}}; export{'si_thresh_1slash16', u64~~thresh{0, i16}}
export{'si_1slash32', slash{0, i32}}; export{'si_thresh_1slash32', u64~~thresh2{i32}} export{'si_1slash32', slash{0, i32}}; export{'si_thresh_1slash32', u64~~thresh{0, i32}}
export{'si_2slash8' , slash{1, i8 }}; export{'si_thresh_2slash8' , u64~~thresh2{i8 }} export{'si_2slash8' , slash{1, i8 }}; export{'si_thresh_2slash8' , u64~~thresh{1, i8 }}
export{'si_2slash16', slash{1, i16}}; export{'si_thresh_2slash16', u64~~thresh2{i16}} export{'si_2slash16', slash{1, i16}}; export{'si_thresh_2slash16', u64~~thresh{1, i16}}
export{'si_2slash32', slash{1, i32}}; export{'si_thresh_2slash32', u64~~thresh2{i32}} export{'si_2slash32', slash{1, i32}}; export{'si_thresh_2slash32', u64~~thresh{1, i32}}
export{'si_2slash64', slash{1, i64}}; export{'si_thresh_2slash64', u64~~thresh2{i64}} export{'si_2slash64', slash{1, i64}}; export{'si_thresh_2slash64', u64~~thresh{1, i64}}
# pext, or boolean compress # pext, or boolean compress
fn pext{T}(x:T, m:T) { fn pext{T}(x:T, m:T) {