Table-based 1-byte Bins implementations, including AVX2

This commit is contained in:
Marshall Lochbaum 2023-07-04 21:38:08 -04:00
parent 7161689196
commit ed9e8b4057

View File

@ -7,9 +7,31 @@ if (hasarch{'AVX2'}) {
include './mask'
include 'util/tup'
def for_backwards{vars,begin,end,iter} = {
i:u64 = end
while (i > begin) {
--i
iter{i, vars}
}
}
def for_vec_overlap{vl}{vars,begin==0,n,iter} = {
assert{n >= vl}
def end = makelabel{}
j:u64 = 0
while (1) {
iter{j, vars}
j += vl
if (j > n-vl) { if (j == n) goto{end}; j = n-vl }
}
setlabel{end}
}
def ceil_log2{n:u64} = 64 - clz{n+1}
def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8 & hasarch{'AVX2'}} = {
# Shift as u16, since x86 is missing 8-bit shifts
def shr16{v, n} = type{v}~~(([width{type{v}}/16]u16~~v) >> n)
def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8} = {
assert{wn > 0}
def T = i8; def I = u8
def lt = if (up) <; else >
@ -17,29 +39,82 @@ def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8 & hasarch{'AVX2'}} = {
def vl = 32
def V = [vl]T; def H = [vl/2]T
def U = [vl]I
log := ceil_log2{wn-1}
l := 1<<log
gap := l - cast_i{u8, wn}
off := U**(gap - 1)
def double{v} = pair{v,v}
wv := double{homBlend{load{*H~~(w-gap), 0}, H**pre, maskOf{H,gap}}}
h0 := U**(l/2)
j:u64 = 0
def end = makelabel{}
assert{n >= vl}
while (1) {
xv:= load{*V~~(x+j), 0}
s := U**0
h := h0
@for (promote{u64,log}) {
s |= h &~ lt{xv, sel{H, wv, s | h}}
h = U~~(([width{V}/16]u16~~h) >> 1) # Type doesn't matter but u8 would fail
}
store{*U~~(res+j), 0, s - off}
j += vl
if (j > n-vl) { if (j == n) goto{end}; j = n-vl }
def getsel{h:(H)} = {
v := pair{h,h}
{i} => sel{H, v, i}
}
if (hasarch{'AVX2'} and wn < 16) {
log := ceil_log2{wn-1}
l := 1<<log
gap := l - cast_i{u8, wn}
off := U**(gap - 1)
wv := homBlend{load{*H~~(w-gap), 0}, H**pre, maskOf{H,gap}}
def selw = getsel{wv}
h0 := U**(l/2)
@for_vec_overlap{vl} (j to n) {
xv:= load{*V~~(x+j), 0}
s := U**0
h := h0
@for (promote{u64,log}) {
s |= h &~ lt{xv, selw{s | h}}
h = shr16{h, 1}
}
store{*U~~(res+j), 0, s - off}
}
} else {
assert{wn < 128} # Total must fit in i8
t0:*i8 = copy{256,0}
t:*i8 = t0 + 128
@for (w over wn) store{t, w, 1+load{t, w}}
def plus_scan{tab, len} = {
def for_dir = if (up) for else for_backwards
s:i8=0; @for_dir (tab over len) { s += tab; tab = s }
}
def getm{} = { m:=V**0; @unroll (v in *V~~t0 over 256/vl) m = max{m,v}; m }
if (hasarch{'AVX2'} and homAll{getm{} <= V**1}) {
# Convert to bit table
def nb = 256/vl
vb := U~~make{[nb](ty_u{vl}),
@collect (t in *V~~t0 over nb) homMask{t > V**0}
}
# Popcount on 8-bit values
def sums{n} = if (n==1) tup{0} else { def s=sums{n/2}; merge{s,s+1} }
def sum4 = getsel{make{H, sums{vl/2}}}
bot4 := U**0x0f
def vpopc{v} = {
def s{b} = sum4{b&bot4}
s{shr16{v,4}} + s{v}
}
# 32-byte select
vtop := U**(vl/2)
def getsel{v & width{type{v}}==256} = {
hs := each{bind{shuf, [4]u64, v}, tup{4b3232, 4b1010}}
{i} => homBlend{...each{{h}=>sel{H,h,i}, hs}, i<vtop}
}
def swap{v} = shuf{[4]u64, v, 4b1032} # For signedness
# Bit table
def sel_b = getsel{swap{vb}}
# Masks for filtering bit table
def ms = if (up) 256-(1<<(1+iota{8})) else (1<<iota{8})-1
def sel_m = getsel{make{H, merge{ms - 256*(ms>127), 8**0}}}
# Exact values for multiples of 8
store{*U~~t0, 0, vpopc{vb}}
plus_scan{t0, 256/8}
def sel_c = getsel{swap{load{*V~~t0, 0}}}
# Top 5 bits select bytes from tables; bottom 3 select from mask
bot3 := U**0x07
@for_vec_overlap{vl} (j to n) {
xv := load{*U~~(x+j), 0}
xb := xv & bot3
xt := shr16{xv &~ bot3, 3}
ind := sel_c{xt} - vpopc{sel_b{xt} & U~~sel_m{xb}}
store{*U~~(res+j), 0, ind}
}
} else {
plus_scan{t0, 256}
@for (res, x over n) res = load{t, x}
}
}
setlabel{end}
}
def bin_search_branchless{up, w, wn, x, n, res} = {
@ -76,7 +151,7 @@ fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, xb:B) : B = {
tup{r, rp}
}
r := undefined{B}
if (hasarch{'AVX2'} and T==i8 and wn<16 and xn>=32) {
if (T==i8 and wn<128 and xn>=32) {
def {rt, rp} = alloc{i8, 'i8'}; r = rt
bin_search_vec{up, *T~~w, wn, *T~~x, xn, rp}
} else {