Table-based 1-byte Bins implementations, including AVX2
This commit is contained in:
parent
7161689196
commit
ed9e8b4057
@ -7,9 +7,31 @@ if (hasarch{'AVX2'}) {
|
||||
include './mask'
|
||||
include 'util/tup'
|
||||
|
||||
def for_backwards{vars,begin,end,iter} = {
|
||||
i:u64 = end
|
||||
while (i > begin) {
|
||||
--i
|
||||
iter{i, vars}
|
||||
}
|
||||
}
|
||||
def for_vec_overlap{vl}{vars,begin==0,n,iter} = {
|
||||
assert{n >= vl}
|
||||
def end = makelabel{}
|
||||
j:u64 = 0
|
||||
while (1) {
|
||||
iter{j, vars}
|
||||
j += vl
|
||||
if (j > n-vl) { if (j == n) goto{end}; j = n-vl }
|
||||
}
|
||||
setlabel{end}
|
||||
}
|
||||
|
||||
def ceil_log2{n:u64} = 64 - clz{n+1}
|
||||
|
||||
def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8 & hasarch{'AVX2'}} = {
|
||||
# Shift as u16, since x86 is missing 8-bit shifts
|
||||
def shr16{v, n} = type{v}~~(([width{type{v}}/16]u16~~v) >> n)
|
||||
|
||||
def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8} = {
|
||||
assert{wn > 0}
|
||||
def T = i8; def I = u8
|
||||
def lt = if (up) <; else >
|
||||
@ -17,29 +39,82 @@ def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8 & hasarch{'AVX2'}} = {
|
||||
def vl = 32
|
||||
def V = [vl]T; def H = [vl/2]T
|
||||
def U = [vl]I
|
||||
log := ceil_log2{wn-1}
|
||||
l := 1<<log
|
||||
gap := l - cast_i{u8, wn}
|
||||
off := U**(gap - 1)
|
||||
def double{v} = pair{v,v}
|
||||
wv := double{homBlend{load{*H~~(w-gap), 0}, H**pre, maskOf{H,gap}}}
|
||||
h0 := U**(l/2)
|
||||
j:u64 = 0
|
||||
def end = makelabel{}
|
||||
assert{n >= vl}
|
||||
while (1) {
|
||||
xv:= load{*V~~(x+j), 0}
|
||||
s := U**0
|
||||
h := h0
|
||||
@for (promote{u64,log}) {
|
||||
s |= h &~ lt{xv, sel{H, wv, s | h}}
|
||||
h = U~~(([width{V}/16]u16~~h) >> 1) # Type doesn't matter but u8 would fail
|
||||
}
|
||||
store{*U~~(res+j), 0, s - off}
|
||||
j += vl
|
||||
if (j > n-vl) { if (j == n) goto{end}; j = n-vl }
|
||||
def getsel{h:(H)} = {
|
||||
v := pair{h,h}
|
||||
{i} => sel{H, v, i}
|
||||
}
|
||||
if (hasarch{'AVX2'} and wn < 16) {
|
||||
log := ceil_log2{wn-1}
|
||||
l := 1<<log
|
||||
gap := l - cast_i{u8, wn}
|
||||
off := U**(gap - 1)
|
||||
wv := homBlend{load{*H~~(w-gap), 0}, H**pre, maskOf{H,gap}}
|
||||
def selw = getsel{wv}
|
||||
h0 := U**(l/2)
|
||||
@for_vec_overlap{vl} (j to n) {
|
||||
xv:= load{*V~~(x+j), 0}
|
||||
s := U**0
|
||||
h := h0
|
||||
@for (promote{u64,log}) {
|
||||
s |= h &~ lt{xv, selw{s | h}}
|
||||
h = shr16{h, 1}
|
||||
}
|
||||
store{*U~~(res+j), 0, s - off}
|
||||
}
|
||||
} else {
|
||||
assert{wn < 128} # Total must fit in i8
|
||||
t0:*i8 = copy{256,0}
|
||||
t:*i8 = t0 + 128
|
||||
@for (w over wn) store{t, w, 1+load{t, w}}
|
||||
def plus_scan{tab, len} = {
|
||||
def for_dir = if (up) for else for_backwards
|
||||
s:i8=0; @for_dir (tab over len) { s += tab; tab = s }
|
||||
}
|
||||
def getm{} = { m:=V**0; @unroll (v in *V~~t0 over 256/vl) m = max{m,v}; m }
|
||||
if (hasarch{'AVX2'} and homAll{getm{} <= V**1}) {
|
||||
# Convert to bit table
|
||||
def nb = 256/vl
|
||||
vb := U~~make{[nb](ty_u{vl}),
|
||||
@collect (t in *V~~t0 over nb) homMask{t > V**0}
|
||||
}
|
||||
# Popcount on 8-bit values
|
||||
def sums{n} = if (n==1) tup{0} else { def s=sums{n/2}; merge{s,s+1} }
|
||||
def sum4 = getsel{make{H, sums{vl/2}}}
|
||||
bot4 := U**0x0f
|
||||
def vpopc{v} = {
|
||||
def s{b} = sum4{b&bot4}
|
||||
s{shr16{v,4}} + s{v}
|
||||
}
|
||||
# 32-byte select
|
||||
vtop := U**(vl/2)
|
||||
def getsel{v & width{type{v}}==256} = {
|
||||
hs := each{bind{shuf, [4]u64, v}, tup{4b3232, 4b1010}}
|
||||
{i} => homBlend{...each{{h}=>sel{H,h,i}, hs}, i<vtop}
|
||||
}
|
||||
def swap{v} = shuf{[4]u64, v, 4b1032} # For signedness
|
||||
# Bit table
|
||||
def sel_b = getsel{swap{vb}}
|
||||
# Masks for filtering bit table
|
||||
def ms = if (up) 256-(1<<(1+iota{8})) else (1<<iota{8})-1
|
||||
def sel_m = getsel{make{H, merge{ms - 256*(ms>127), 8**0}}}
|
||||
# Exact values for multiples of 8
|
||||
store{*U~~t0, 0, vpopc{vb}}
|
||||
plus_scan{t0, 256/8}
|
||||
def sel_c = getsel{swap{load{*V~~t0, 0}}}
|
||||
# Top 5 bits select bytes from tables; bottom 3 select from mask
|
||||
bot3 := U**0x07
|
||||
@for_vec_overlap{vl} (j to n) {
|
||||
xv := load{*U~~(x+j), 0}
|
||||
xb := xv & bot3
|
||||
xt := shr16{xv &~ bot3, 3}
|
||||
ind := sel_c{xt} - vpopc{sel_b{xt} & U~~sel_m{xb}}
|
||||
store{*U~~(res+j), 0, ind}
|
||||
}
|
||||
} else {
|
||||
plus_scan{t0, 256}
|
||||
@for (res, x over n) res = load{t, x}
|
||||
}
|
||||
}
|
||||
setlabel{end}
|
||||
}
|
||||
|
||||
def bin_search_branchless{up, w, wn, x, n, res} = {
|
||||
@ -76,7 +151,7 @@ fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, xb:B) : B = {
|
||||
tup{r, rp}
|
||||
}
|
||||
r := undefined{B}
|
||||
if (hasarch{'AVX2'} and T==i8 and wn<16 and xn>=32) {
|
||||
if (T==i8 and wn<128 and xn>=32) {
|
||||
def {rt, rp} = alloc{i8, 'i8'}; r = rt
|
||||
bin_search_vec{up, *T~~w, wn, *T~~x, xn, rp}
|
||||
} else {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user