Vector binary search for integer Index-of

This commit is contained in:
Marshall Lochbaum 2023-11-06 15:43:26 -05:00
parent 52bca6a55c
commit c042fe6ca3
2 changed files with 26 additions and 4 deletions

View File

@ -48,17 +48,22 @@ RangeFn getRange_fns[el_f64+1];
#endif
#if SINGELI_AVX2
extern void (**const avx2_member_sort)(uint64_t*,void*,uint64_t,void*,uint64_t);
extern void (**const avx2_indexOf_sort)(int8_t*,void*,uint64_t,void*,uint64_t);
#endif
#define C2i(F, W, X) C2(F, m_i32(W), X)
extern B and_c1(B,B);
extern B gradeDown_c1(B,B);
extern B reverse_c1(B,B);
extern B eq_c2(B,B,B);
extern B ne_c2(B,B,B);
extern B or_c2(B,B,B);
extern B add_c2(B,B,B);
extern B sub_c2(B,B,B);
extern B mul_c2(B,B,B);
extern B join_c2(B,B,B);
extern B select_c2(B,B,B);
static u64 elRange(u8 eltype) { return 1ull<<(1<<elwBitLog(eltype)); }
@ -258,6 +263,17 @@ B indexOf_c2(B t, B w, B x) {
return i==wia? r : C2(sub, r, C2i(mul, wia-i, C2i(eq, !w0, x)));
}
#if SINGELI_AVX2
if (xia>=32 && wia>1 && el_i8<=xe && xe<=el_i32 && wia<64>>(xe-el_i8) && we<=xe && !elChr(TI(x,elType))) {
B g = C1(reverse, C1(gradeDown, incG(w)));
w = C2(select, incG(g), w);
switch (xe) { default:UD; case el_i8:w=toI8Any(w);break; case el_i16:w=toI16Any(w);break; case el_i32:w=toI32Any(w);break; }
i8* rp; B r = m_i8arrc(&rp, x);
avx2_indexOf_sort[xe-el_i8](rp, tyany_ptr(w), wia, tyany_ptr(x), xia);
r = C2(select, r, C2(join, g, m_i8(wia)));
decG(w); decG(x); return r;
}
#endif
if (wia<=(we<=el_i16?4:16) && xia>16) {
SGetU(w);
#define XEQ(I) C2(ne, GetU(w,I), incG(x))

View File

@ -202,6 +202,7 @@ def bins_vectab_i8{up, w, wn, x, xn, rp, t0, t, done & hasarch{'AVX2'}} = {
# Binary search within vector registers
def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
def up = prim != '⍒'
def search = (prim=='∊') | (prim=='⊐')
assert{wn > 1}; assert{wn < maxwn}
def wd = width{T}
def I = if (wd<32) u8 else u32; def wi = width{I}
@ -215,7 +216,7 @@ def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
log := ceil_log2{wn+1}
gap := 1<<log - wn
# Fill with minimum value at the beginning
def pre = if (prim=='∊') load{w} else (if (up) minvalue else maxvalue){T}
def pre = if (search) load{w} else (if (up) minvalue else maxvalue){T}
wg := *V~~(w-gap)
wv0:= homBlend{load{wg}, V**pre, maskOf{V,gap}}
# For multiple lanes, interleave like transpose
@ -241,7 +242,7 @@ def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
def selw = ms{wv}{0}; def selw1 = if (ex>=1) ms{wv}{1} else 'undef'
def selw2 = if (ex>=2) each{ms{wv2}, iota{2}} else 'undef'
# Offset at end
off := U~~V**cast_i{i8, gap-1}
off := U~~V**cast_i{i8, gap-(1-search)}
# Midpoint bits for each step
def lowbits = bb{copy{isub,isub}}
bits := each{{j} => U**(lowbits << j), iota{lstep}}
@ -263,7 +264,7 @@ def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
# b records if xv was found; c is added to the index
def r_out = prim!='∊'
def get_up{var,cmpx,use}{set,...a} = if (use) set{var, cmpx{...a}}
b := undefined{U}; def up_b = get_up{b, eqx, prim=='∊'}
b := undefined{U}; def up_b = get_up{b, eqx, search}
c := undefined{U}; def up_c = get_up{c, ltx, r_out}
up_b{=, selw, s}
# Extra selection lanes
@ -279,6 +280,7 @@ def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
}
if (r_out) {
r -= off
if (prim=='⊐') r = homBlend{U**cast_i{u8,wn}, r, b}
rn := if (T==i8) r
else if (T==i16) half{narrow{u8, r}, 0}
else extract{re_el{i64, narrow{u8, r}}, 0}
@ -294,13 +296,17 @@ def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
}
if (hasarch{'AVX2'}) {
fn avx2_search_bin{prim, T, maxwn}(rp:*u64, w:*void, wn:u64, x:*void, xn:u64) : void = {
fn avx2_search_bin{prim, T, maxwn}(rp:*(if (prim=='∊') u64 else i8), w:*void, wn:u64, x:*void, xn:u64) : void = {
bin_search_vec{prim, T, *T~~w, wn, *T~~x, xn, rp, maxwn}
}
exportT{
'avx2_member_sort',
each{avx2_search_bin{'∊',.,.}, tup{i16,i32}, tup{32,16}}
}
exportT{
'avx2_indexOf_sort',
each{avx2_search_bin{'⊐',.,.}, tup{i8,i16,i32}, tup{64,32,16}}
}
}
def unroll_sizes = tup{4,1}