Vector binary search for integer Index-of
This commit is contained in:
parent
52bca6a55c
commit
c042fe6ca3
@ -48,17 +48,22 @@ RangeFn getRange_fns[el_f64+1];
|
||||
#endif
|
||||
#if SINGELI_AVX2
|
||||
extern void (**const avx2_member_sort)(uint64_t*,void*,uint64_t,void*,uint64_t);
|
||||
extern void (**const avx2_indexOf_sort)(int8_t*,void*,uint64_t,void*,uint64_t);
|
||||
#endif
|
||||
|
||||
|
||||
#define C2i(F, W, X) C2(F, m_i32(W), X)
|
||||
extern B and_c1(B,B);
|
||||
extern B gradeDown_c1(B,B);
|
||||
extern B reverse_c1(B,B);
|
||||
extern B eq_c2(B,B,B);
|
||||
extern B ne_c2(B,B,B);
|
||||
extern B or_c2(B,B,B);
|
||||
extern B add_c2(B,B,B);
|
||||
extern B sub_c2(B,B,B);
|
||||
extern B mul_c2(B,B,B);
|
||||
extern B join_c2(B,B,B);
|
||||
extern B select_c2(B,B,B);
|
||||
|
||||
static u64 elRange(u8 eltype) { return 1ull<<(1<<elwBitLog(eltype)); }
|
||||
|
||||
@ -258,6 +263,17 @@ B indexOf_c2(B t, B w, B x) {
|
||||
return i==wia? r : C2(sub, r, C2i(mul, wia-i, C2i(eq, !w0, x)));
|
||||
}
|
||||
|
||||
#if SINGELI_AVX2
|
||||
if (xia>=32 && wia>1 && el_i8<=xe && xe<=el_i32 && wia<64>>(xe-el_i8) && we<=xe && !elChr(TI(x,elType))) {
|
||||
B g = C1(reverse, C1(gradeDown, incG(w)));
|
||||
w = C2(select, incG(g), w);
|
||||
switch (xe) { default:UD; case el_i8:w=toI8Any(w);break; case el_i16:w=toI16Any(w);break; case el_i32:w=toI32Any(w);break; }
|
||||
i8* rp; B r = m_i8arrc(&rp, x);
|
||||
avx2_indexOf_sort[xe-el_i8](rp, tyany_ptr(w), wia, tyany_ptr(x), xia);
|
||||
r = C2(select, r, C2(join, g, m_i8(wia)));
|
||||
decG(w); decG(x); return r;
|
||||
}
|
||||
#endif
|
||||
if (wia<=(we<=el_i16?4:16) && xia>16) {
|
||||
SGetU(w);
|
||||
#define XEQ(I) C2(ne, GetU(w,I), incG(x))
|
||||
|
||||
@ -202,6 +202,7 @@ def bins_vectab_i8{up, w, wn, x, xn, rp, t0, t, done & hasarch{'AVX2'}} = {
|
||||
# Binary search within vector registers
|
||||
def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
|
||||
def up = prim != '⍒'
|
||||
def search = (prim=='∊') | (prim=='⊐')
|
||||
assert{wn > 1}; assert{wn < maxwn}
|
||||
def wd = width{T}
|
||||
def I = if (wd<32) u8 else u32; def wi = width{I}
|
||||
@ -215,7 +216,7 @@ def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
|
||||
log := ceil_log2{wn+1}
|
||||
gap := 1<<log - wn
|
||||
# Fill with minimum value at the beginning
|
||||
def pre = if (prim=='∊') load{w} else (if (up) minvalue else maxvalue){T}
|
||||
def pre = if (search) load{w} else (if (up) minvalue else maxvalue){T}
|
||||
wg := *V~~(w-gap)
|
||||
wv0:= homBlend{load{wg}, V**pre, maskOf{V,gap}}
|
||||
# For multiple lanes, interleave like transpose
|
||||
@ -241,7 +242,7 @@ def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
|
||||
def selw = ms{wv}{0}; def selw1 = if (ex>=1) ms{wv}{1} else 'undef'
|
||||
def selw2 = if (ex>=2) each{ms{wv2}, iota{2}} else 'undef'
|
||||
# Offset at end
|
||||
off := U~~V**cast_i{i8, gap-1}
|
||||
off := U~~V**cast_i{i8, gap-(1-search)}
|
||||
# Midpoint bits for each step
|
||||
def lowbits = bb{copy{isub,isub}}
|
||||
bits := each{{j} => U**(lowbits << j), iota{lstep}}
|
||||
@ -263,7 +264,7 @@ def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
|
||||
# b records if xv was found; c is added to the index
|
||||
def r_out = prim!='∊'
|
||||
def get_up{var,cmpx,use}{set,...a} = if (use) set{var, cmpx{...a}}
|
||||
b := undefined{U}; def up_b = get_up{b, eqx, prim=='∊'}
|
||||
b := undefined{U}; def up_b = get_up{b, eqx, search}
|
||||
c := undefined{U}; def up_c = get_up{c, ltx, r_out}
|
||||
up_b{=, selw, s}
|
||||
# Extra selection lanes
|
||||
@ -279,6 +280,7 @@ def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
|
||||
}
|
||||
if (r_out) {
|
||||
r -= off
|
||||
if (prim=='⊐') r = homBlend{U**cast_i{u8,wn}, r, b}
|
||||
rn := if (T==i8) r
|
||||
else if (T==i16) half{narrow{u8, r}, 0}
|
||||
else extract{re_el{i64, narrow{u8, r}}, 0}
|
||||
@ -294,13 +296,17 @@ def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
|
||||
}
|
||||
|
||||
if (hasarch{'AVX2'}) {
|
||||
fn avx2_search_bin{prim, T, maxwn}(rp:*u64, w:*void, wn:u64, x:*void, xn:u64) : void = {
|
||||
fn avx2_search_bin{prim, T, maxwn}(rp:*(if (prim=='∊') u64 else i8), w:*void, wn:u64, x:*void, xn:u64) : void = {
|
||||
bin_search_vec{prim, T, *T~~w, wn, *T~~x, xn, rp, maxwn}
|
||||
}
|
||||
exportT{
|
||||
'avx2_member_sort',
|
||||
each{avx2_search_bin{'∊',.,.}, tup{i16,i32}, tup{32,16}}
|
||||
}
|
||||
exportT{
|
||||
'avx2_indexOf_sort',
|
||||
each{avx2_search_bin{'⊐',.,.}, tup{i8,i16,i32}, tup{64,32,16}}
|
||||
}
|
||||
}
|
||||
|
||||
def unroll_sizes = tup{4,1}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user