SSSE3 support in bit-table code

This commit is contained in:
Marshall Lochbaum 2023-04-30 21:58:48 -04:00
parent 464dd27a37
commit 53fb8db06f

View File

@ -3,6 +3,8 @@ if (hasarch{'AVX2'}) {
include './sse'
include './avx'
include './avx2'
} else if (hasarch{'SSSE3'}) {
include './sse'
} else if (hasarch{'X86_64'}) {
include './sse2'
} else if (hasarch{'AARCH64'}) {
@ -80,24 +82,27 @@ export{'simd_copy_ordered', copyOrdered{}}
# In-register bit table
def arch_vec{T} = [arch_defvw/width{T}]T
def TI = i8 # Table values
def VI = [32]TI
def VI = arch_vec{TI}
def simd_bittab = hasarch{'SSSE3'}
def bittab_init{tab, z} = {
@for (t in *TI~~tab over 256) t = z
}
def bittab_init{tab, z & hasarch{'AVX2'}} = {
def bittab_init{tab, z & simd_bittab} = {
init:= VI**z
@unroll (t in *VI~~tab over 256/vcount{VI}) t = init
}
def bittab_selector{loadtab} = {
def nv = vcount{VI}
{t0, t1}:= loadtab{}
low:= VI**7
hi4:= VI**(-(1<<4))
b := VI~~make{[32]u8, 1 << (iota{32} & 7)}
b := VI~~make{[nv]u8, 1 << (iota{nv} & 7)}
def selector{x} = {
top := hi4 + VI~~(([8]u32~~(x&~low))>>3)
top := hi4 + VI~~((arch_vec{u32}~~(x&~low))>>3)
byte:= sel{[16]i8, t0, hi4^top} | sel{[16]i8, t1, top}
mask:= sel{[16]i8, b, x & low}
homMask{(mask & byte) == mask}
@ -107,11 +112,13 @@ def bittab_selector{loadtab} = {
}
def readbytes{vtab}{} = {
def k = vcount{VI}; def l = 128/vcount{VI}
def side{i} = {
def m = @collect (vtab over _ from i to i+4) homMask{vtab}
VI~~make{[8]u32, merge{m,m}}
def U = arch_vec{ty_u{k}}
def m = @collect (vtab over _ from i to i+l) homMask{vtab}
VI~~make{U, if (vcount{U}>l) merge{m,m} else m}
}
each{side, 4*iota{2}}
each{side, l*iota{2}}
}
# Look up bits from table
@ -128,9 +135,10 @@ def bittab_lookup{x0:*void, n:u64, r0:*void, tab:*void} = {
x+=k; rem-=k; ++r
}
}
def bittab_lookup{x0:*void, n:u64, r0:*void, tab:*void & hasarch{'AVX2'}} = {
def bittab_lookup{x0:*void, n:u64, r0:*void, tab:*void & simd_bittab} = {
def {bitsel, _} = bittab_selector{readbytes{*VI~~tab}}
@for (x in *VI~~x0, r in *u32~~r0 over cdiv{n,32}) r = bitsel{x}
def k = vcount{VI}
@for (x in *VI~~x0, r in *ty_u{k}~~r0 over cdiv{n,k}) r = bitsel{x}
}
# Fill table with t (0 or -1) at all bytes in x0
@ -162,7 +170,7 @@ def do_bittab{x0:*void, n:u64, tab:*void, u:u8, t, mode, r0} = {
def settab{T, x, i} = T~~promote{ty_s{T}, settab{x, i}}
x:= *u8~~x0
if (not hasarch{'AVX2'}) {
if (not simd_bittab) {
rem:= n
@for (i to cdiv{n,64}) {
k:= rem; if (k>64) k=64
@ -178,7 +186,8 @@ def do_bittab{x0:*void, n:u64, tab:*void, u:u8, t, mode, r0} = {
# Do first few values with a scalar loop
# Avoids the cost of ever loading the table into vectors for n<=48
i:u64 = 32; if (n<=48) i=n
{rw,rv} := undef{tup{u64,u32}} # Bit results, used if rbit
def k = vcount{VI}; def uk = ty_u{k}; def ik = ty_s{k}
{rw,rv} := undef{tup{u64,uk}} # Bit results, used if rbit
if (rbit) rw = 0
@for (x over j to i) {
new:= settab{u64, x, j}
@ -191,12 +200,12 @@ def do_bittab{x0:*void, n:u64, tab:*void, u:u8, t, mode, r0} = {
def {bitsel, reload_tab} = bittab_selector{readbytes{*VI~~tab}}
xv:= *VI~~x0
while (i < n) {
i0:= i; iw:= i/32
i0:= i; iw:= i/k
v:= load{xv, iw}
m:= bitsel{v} # Mask of possibly-new values
if (not match{t,0}) m^= u32~~promote{i32, t}
i+= 32
if (i > n) m&= (~u32~~0)>>((-n)%32)
if (not match{t,0}) m^= uk~~promote{ik, t}
i+= k
if (i > n) m&= (~uk~~0)>>((-n)%k)
# Any new values?
if (m == 0) {
storebit{iw, m}
@ -208,7 +217,7 @@ def do_bittab{x0:*void, n:u64, tab:*void, u:u8, t, mode, r0} = {
settab1{xi, im}
if ((m&(m-1)) != 0) { # More bits than one
# Filter out values equal to the previous, or first new
def pind = (iota{32}&15) - 1
def pind = (iota{k}&15) - 1
prev:= make{VI, each{bind{max,0}, pind}}
e:= ~homMask{v == VI**TI~~xi}
e&= base{2,pind<0} | ~homMask{v == sel{[16]i8, v, prev}}
@ -216,14 +225,14 @@ def do_bittab{x0:*void, n:u64, tab:*void, u:u8, t, mode, r0} = {
m&= e
while (m != 0) {
im:= i0 + ctzi{m}
new:= settab{u32, load{x, im}, im}
new:= settab{uk, load{x, im}, im}
m1:= m-1; m&= m1 # Clear low bit
if (rbit) rv&= m1 | new # Clear if not new
}
}
storebit{iw, rv}
if (u == 0) { # All bytes seen
if (rbit) @for (r in *u32~~r0 over _ from iw+1 to cdiv{n,32}) r = 0
if (rbit) @for (r in *uk~~r0 over _ from iw+1 to cdiv{n,k}) r = 0
goto{done}
}
reload_tab{}