2-byte vector binary search on 2 lanes

This commit is contained in:
Marshall Lochbaum 2023-07-07 20:27:58 -04:00
parent d19df2693a
commit fe92f91ca1

View File

@ -90,6 +90,15 @@ if (hasarch{'AVX2'}) {
}
}
# Move evens to half 0 and odds to half 1
def uninterleave{x:V & hasarch{'AVX2'}} = {
def vl = vcount{V}; def bytes = width{eltype{V}}/8
def i = 2*iota{vl/4}
def i2= join{table{+, bytes*merge{i,i+1}, iota{bytes}}}
t := V~~sel{[16]u8, to_el{u8,x}, make{[32]u8, merge{i2,i2}}}
shuf{[4]u64, t, 4b3120}
}
def rtypes = tup{i8, i16, i32, f64}
# Return index of smallest possible result type given max result value
def get_rtype{len} = {
@ -239,18 +248,35 @@ def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
off := [vl]u8**(gap - 1)
# Fill with minimum value at the beginning
def pre = (if (up) minvalue else maxvalue){T}
wv := homBlend{load{*H~~(w-gap)}, H**pre, maskOf{H,gap}}
def selw = getsel{[16]u8~~wv}
wv := homBlend{load{*V~~(w-gap)}, V**pre, maskOf{V,gap}}
# Separate even/odd elements if double width
def maxstep = lb{maxwn}
def lstep = lb{vl/2}
def has_w1 = maxstep > lstep
if (has_w1 and wn >= vl/2) wv = uninterleave{wv}
def ms{h} = getsel{[16]u8~~half{wv,h}}
def selw = ms{0}; def selw1 = if (has_w1) ms{1} else 'undef'
# Midpoint bits for each step
def lowbits = bb{copy{bytes,bytes}}
bits := each{{j} => U**(lowbits << j), iota{maxwn-1}}
@unroll (klog from 2 to lb{maxwn}+1) {
if (log==klog) @for_vec_overlap{vl} (j to xn) {
bits := each{{j} => U**(lowbits << j), iota{lstep}}
# Unroll sizes up to a full lane, handling extra lanes conditionally
# in the largest one
@unroll (klog from 2 to min{maxstep,lstep}+1) {
def last = klog==lstep
def this = if (not last) log==klog else log>=klog
if (this) @for_vec_overlap{vl} (j to xn) {
def as_u8{T,op, ...par} = T~~op{...each{bind{to_el,u8},par}}
xv:= load{*V~~(x+j), 0}
s := U**bb{iota{bytes}} # Select sequential bytes within each U
@unroll (j to klog) {
m := s | tupsel{klog-1-j, bits}
s = homBlend{m, s, lt{xv, V~~selw{to_el{u8,m}}}}
m := s | tupsel{klog-1-j,bits}
s = homBlend{m, s, lt{xv, as_u8{V,selw, m}}}
}
# Extra selection lanes
if (last and has_w1 and log>klog) {
c := lt{xv, as_u8{V,selw1, s}}
assert{T==i16} # Otherwise position of c is different
s = as_u8{U,+, s,c}; s += s
}
r := if (T==i8) s
else half{narrow{u8, s>>(lb{bytes}+wd-8)}, 0}
@ -298,8 +324,10 @@ fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, rp:*void, rty:u8) : void = {
if (rty == k) bins_lookup{tupsel{k,rtypes}, T, ...param}
else if (k+1 < tuplen{rtypes}) lookup{k+1}
}
if (hasarch{'AVX2'} and T<=i16 and wn < 8 and xn >= 256/width{T}) {
bin_search_vec{T, ...param, 8}
# For >=8 i8 values, vector bit-table is as good as binary search
def wn_vec = if (T==i8) 8 else 16
if (hasarch{'AVX2'} and T<=i16 and wn < wn_vec and xn >= 256/width{T}) {
bin_search_vec{T, ...param, wn_vec}
# Lookup table threshold has to account for cost of
# populating the table (proportional to wn until it's large), and
# initializing the table (constant, much higher for i16)