2- and 4-byte Where with unpack and table

This commit is contained in:
Marshall Lochbaum 2023-07-18 21:10:39 -04:00
parent 93e1262864
commit b0e0f210c6

View File

@ -46,15 +46,13 @@ def maketab{l,w} = {
def top = (fold{bind{flat_table,+}, l**iota{2}} - 1)%(1<<w)
top<<(l*w-w) | bot # Overlaps for all-1 value only
}
# 16-element tables
i64tab :*u32 = (maketab{4,8}*2)%(1<<32)
tab_4_16:*u64 = maketab{4,16}
if (1) {
def use_table = 1
itab :*u64 = maketab{8,8}
itab:*u64 = maketab{8,8}
} else {
def use_table = 0
}
# Recover popcount, for when POPCNT isn't there
def has_popc = hasarch{'POPCNT'}
def tab_popc{i, w} = (i>>(64-w) + 1) & (1<<w - 1)
@ -119,14 +117,28 @@ def getter{c, V, x} = {
}
}
# Top bits to convert 1-byte indices to 2 or 4
# These can only change between loop iterations, provided the
# given x for i32 is a multiple of the loop step
def topper{T, U, k, x} = {
def make_top{S} = to_el{S,U}**(if (T<i32) 0 else cast_i{S, x>>width{S}})
top := each{make_top, replicate{{S}=>S<T, tup{i8,i16}}}
def i_off = if (T<i32) 0 else { assert{x%k==0}; cast_i{u64, x/k} }
# Increment top vector when i*k passes width of bottom vector
def vb = lb{k}
def inc{i, {}} = {}
def inc{i, {t:V, ...ts}} = {
if ((i+1+i_off)%(1<<(elwidth{V}-vb)) == 0) { t += V**1; inc{i,ts} }
}
tup{top, inc}
}
def thresh{c, T & hasarch{'X86_64'} & T<=(if (c) i8 else i32)} = 4
fn slash{c, T & hasarch{'X86_64'} & T<=(if (c) i8 else i32)}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def U = [16]u8
k1 := U**1
def X = getter{c, U, x}
def make_top{S} = to_el{S,U}**(if (T<i32) 0 else cast_i{S, x>>width{S}})
top := each{make_top, replicate{{S}=>S<T, tup{i8,i16}}}
def i_off = if (T<i32) 0 else { assert{x%16==0}; cast_i{u64, x/16} }
def {top, inctop} = topper{T, U, 16, x}
@for_special_buffered{r,16} (w in *u16~~w over i to sum) {
x := X{}
bm := make{U, 1<<(iota{16}%8)}
@ -157,12 +169,7 @@ fn slash{c, T & hasarch{'X86_64'} & T<=(if (c) i8 else i32)}(w:*u64, x:arg{c,T},
else each{st, iota{2}, unpack{v, tupsel{1,top}}}
each_pc{st, unpack{[16]i8~~x, tupsel{0,top}}}
}
# Increment top vector when i*16 passes width of bottom vector
def inc{{}} = {}
def inc{{t:V, ...ts}} = {
if ((i+1+i_off)%(1<<(elwidth{V}-4)) == 0) { t += V**1; inc{ts} }
}
inc{top}
inctop{i, top}
}
}
@ -218,12 +225,13 @@ fn slash{c, T==i8 & hasarch{'AVX2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) :
}
}
itab_4_16:*u64 = maketab{4,16}
def thresh{c==0, T==i8 & use_table} = 32
def thresh{c==0, T==i16} = 16
fn slash{c==0, T & (if (T==i8) use_table else T==i16)}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def tw = width{T}
def n = 64/tw
def tab = if (tw==8) itab else tab_4_16
def tab = if (tw==8) itab else itab_4_16
j:u64 = 0
def inc = base{1<<tw, n**n}
@for_special_buffered{r,8} (w in *u8~~w over sum) {
@ -240,6 +248,25 @@ fn slash{c==0, T & (if (T==i8) use_table else T==i16)}(w:*u64, x:arg{c,T}, r:*T,
}
}
def thresh{c==0, T==i16 & hasarch{'X86_64'}} = 32
def thresh{c==0, T==i32 & hasarch{'X86_64'}} = 16
fn slash{c==0, T & hasarch{'X86_64'} & i16<=T & T<=i32}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def I = [16]i8
j := I**(if (T==i16) 0 else cast_i{i8,x})
def {top, inctop} = topper{T, I, 8, x}
@for_special_buffered{r,8} (w in *u8~~w over i to sum) {
ind := load{itab, w}
pc := popc_alt{w, ind, 8}
v := unpackLo{j + I~~make{[2]u64, ind, 0}, tupsel{0,top}}
def st{k, v:V} = store{*V~~r, k, v}
if (T==i16) st{0, v}
else each{st, iota{2}, unpack{v, tupsel{1,top}}}
r += pc
j += I**8
inctop{i, top}
}
}
def thresh{c==1, T==i8 & hasarch{'SSSE3'} & use_table} = 64
def thresh{c==1, T==i16 & hasarch{'SSSE3'} & use_table} = 32
fn slash{c==1, T & T<=i16 & hasarch{'SSSE3'} & use_table}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
@ -257,6 +284,7 @@ fn slash{c==1, T & T<=i16 & hasarch{'SSSE3'} & use_table}(wp:*u64, x:arg{c,T}, r
}
}
i64tab:*u32 = (maketab{4,8}*2)%(1<<32)
def thresh{c, T==i32 & hasarch{'AVX2'} & use_table} = 32
def thresh{c, T==i64 & hasarch{'AVX2'} } = 8
fn slash{c, T & hasarch{'AVX2'} & (if (T==i32) use_table else T==i64)}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {