Permutevar instead of shuffle for 4-byte vector binary search

This commit is contained in:
Marshall Lochbaum 2023-07-08 14:23:59 -04:00
parent fc57e0012d
commit 46c6d47055

View File

@ -89,6 +89,7 @@ if (hasarch{'AVX2'}) {
hs := each{bind{shuf, [4]u64, v}, tup{4b3232, 4b1010}}
{i} => homBlend{...each{{h}=>sel{H,h,i}, hs}, V~~i<vtop}
}
def getsel{v:V & lvec{V, 8, 32}} = { {i} => sel{V, v, i} }
}
# Move evens to half 0 and odds to half 1
@ -238,9 +239,11 @@ def bins_vectab_i8{up, w, wn, x, xn, rp, t0, t, done & hasarch{'AVX2'}} = {
def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
assert{wn > 1}; assert{wn < maxwn}
def wd = width{T}
def bytes = wd/8; def bb = bind{base,256}
def vl = 256/wd
def V = [vl]T; def H = v_half{V}
def I = if (wd<32) u8 else u32; def wi = width{I}
def lanes = hasarch{'AVX2'} & (I==u8)
def isub = wd/wi; def bb = bind{base,1<<wi}
def vl = 256/wd; def svl = vl>>lanes
def V = [vl]T
def U = [vl](ty_u{T})
def lt = if (up) <; else >
# Number of steps
@ -249,29 +252,33 @@ def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
# Fill with minimum value at the beginning
def pre = (if (up) minvalue else maxvalue){T}
wg := *V~~(w-gap)
wv := homBlend{load{wg}, V**pre, maskOf{V,gap}}
wv0:= homBlend{load{wg}, V**pre, maskOf{V,gap}}
# For multiple lanes, interleave like transpose
def maxstep = lb{maxwn}
def lstep = lb{vl/2}
def lstep = lb{svl}
def ex = maxstep - lstep
wv := if (lanes) wv0 else tup{wv0,wv0}
wv2 := wv # Compiler complains if uninitialized
if (ex>=1 and wn >= vl/2) {
wv = uninterleave{wv}
if (ex>=1 and wn >= svl) {
--gap # Allows subtracting < instead of adding <=
if (ex>=2 and wn >= vl) {
t := uninterleave{load{wg, 1}}
wv2= uninterleave{shufHalves{wv, t, 16b31}}
wv = uninterleave{shufHalves{wv, t, 16b20}}
def un = uninterleave
def tr_half{a, b} = each{bind{shufHalves,a,b}, tup{16b20, 16b31}}
def un{{a,b}} = tr_half{un{a},un{b}}
if (not lanes) tupsel{1,wv} = load{wg, 1}
wv = un{wv}
if (ex>=2 and wn >= 2*svl) {
assert{lanes} # Different transpose pattern needed
gap -= 2
tup{wv, wv2} = each{un, tr_half{wv, un{load{wg, 1}}}}
}
}
def ms{v}{h} = getsel{to_el{u8,half{v,h}}}
def ms{v}{h} = getsel{to_el{I, if (lanes) half{v,h} else tupsel{h,v}}}
def selw = ms{wv}{0}; def selw1 = if (ex>=1) ms{wv}{1} else 'undef'
def selw2 = if (ex>=2) each{ms{wv2}, iota{2}} else 'undef'
# Offset at end
off := U~~V**i8~~(gap - 1)
# Midpoint bits for each step
def lowbits = bb{copy{bytes,bytes}}
def lowbits = bb{copy{isub,isub}}
bits := each{{j} => U**(lowbits << j), iota{lstep}}
# Unroll sizes up to a full lane, handling extra lanes conditionally
# in the largest one
@ -279,15 +286,14 @@ def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
def last = klog==lstep
def this = if (not last) log==klog else log>=klog
if (this) @for_vec_overlap{vl} (j to xn) {
def as_u8{T,op, ...par} = T~~op{...each{bind{to_el,u8},par}}
xv:= load{*V~~(x+j), 0}
s := U**bb{iota{bytes}} # Select sequential bytes within each U
def ltx{se,ind} = lt{xv, as_u8{V,se, ind}}
s := U**bb{iota{isub}} # Select sequential bytes within each U
def ltx{se,ind} = lt{xv, V~~se{to_el{I,ind}}}
@unroll (j to klog) {
m := s | tupsel{klog-1-j,bits}
s = homBlend{m, s, ltx{selw, m}}
}
r := if (T==i8) s else s>>(lb{bytes}+wd-8)
r := if (isub==1) s else s>>(lb{isub}+wd-wi)
# Extra selection lanes
if (last and ex>=1 and log>=klog+1) {
r += r