Vector binary search Member-of

This commit is contained in:
Marshall Lochbaum 2023-11-05 22:10:39 -05:00
parent 75aed91f32
commit 11245d385e
2 changed files with 52 additions and 18 deletions

View File

@ -46,9 +46,13 @@ RangeFn getRange_fns[el_f64+1];
GETRANGE(i32,)
GETRANGE(f64, if (!q_fi64(c)) return 0)
#endif
#if SINGELI_AVX2
extern void (*const avx2_member_sort_i32)(uint64_t*,int32_t*,uint64_t,int32_t*,uint64_t);
#endif
#define C2i(F, W, X) C2(F, m_i32(W), X)
extern B and_c1(B,B);
extern B eq_c2(B,B,B);
extern B ne_c2(B,B,B);
extern B or_c2(B,B,B);
@ -335,10 +339,19 @@ B memberOf_c2(B t, B w, B x) {
}
u8 me = we>xe?we:xe;
if (xia<=(me==el_i8?1:me==el_i16?4:16) && wia>16) {
SGetU(x);
r = WEQ(GetU(x,0));
for (usz i=1; i<xia; i++) r = C2(or, r, WEQ(GetU(x,i)));
if (xia<=(me==el_i8?1:me==el_i16?4:15) && wia>16) {
#if SINGELI_AVX2
if (we==xe && we==el_i32 && xia>1) {
x = C1(and, x); // sort
u64* rp; r = m_bitarrc(&rp, w);
avx2_member_sort_i32(rp, tyany_ptr(x), xia, tyany_ptr(w), wia);
} else
#endif
{
SGetU(x);
r = WEQ(GetU(x,0));
for (usz i=1; i<xia; i++) r = C2(or, r, WEQ(GetU(x,i)));
}
decG(w); goto dec_x;
}
#undef WEQ

View File

@ -200,7 +200,8 @@ def bins_vectab_i8{up, w, wn, x, xn, rp, t0, t, done & hasarch{'AVX2'}} = {
}
# Binary search within vector registers
def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
def up = prim != '⍒'
assert{wn > 1}; assert{wn < maxwn}
def wd = width{T}
def I = if (wd<32) u8 else u32; def wi = width{I}
@ -214,7 +215,7 @@ def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
log := ceil_log2{wn+1}
gap := 1<<log - wn
# Fill with minimum value at the beginning
def pre = (if (up) minvalue else maxvalue){T}
def pre = if (prim=='∊') load{w} else (if (up) minvalue else maxvalue){T}
wg := *V~~(w-gap)
wv0:= homBlend{load{wg}, V**pre, maskOf{V,gap}}
# For multiple lanes, interleave like transpose
@ -252,33 +253,53 @@ def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
if (this) @for_vec_overlap{vl} (j to xn) {
xv:= load{*V~~(x+j), 0}
s := U**bb{iota{isub}} # Select sequential bytes within each U
def ltx{se,ind} = lt{xv, V~~se{re_el{I,ind}}}
def cmpx{cmp}{se,ind} = cmp{xv, V~~se{re_el{I,ind}}}
def ltx = cmpx{lt}; def eqx = cmpx{==}
@unroll (j to klog) {
m := s | tupsel{klog-1-j,bits}
s = homBlend{m, s, ltx{selw, m}}
}
r := if (isub==1) s else s>>(lb{isub}+wd-wi)
# b records if xv was found; c is added to the index
def r_out = prim!='∊'
def get_up{var,cmpx,use}{set,...a} = if (use) set{var, cmpx{...a}}
b := undefined{U}; def up_b = get_up{b, eqx, prim=='∊'}
c := undefined{U}; def up_c = get_up{c, ltx, r_out}
up_b{=, selw, s}
# Extra selection lanes
if (last and ex>=1 and log>=klog+1) {
r += r
c := ltx{selw1,s}
def up_bc{set,se} = { up_b{|=,se,s}; up_c{set,se,s} }
up_bc{=,selw1}
if (ex>=2 and log>=klog+2) {
r += r
each{{se} => c += ltx{se,s}, selw2}
each{up_bc{+=, .}, selw2}
}
r += c
if (r_out) r += c
}
if (r_out) {
r -= off
rn := if (T==i8) r
else if (T==i16) half{narrow{u8, r}, 0}
else extract{re_el{i64, narrow{u8, r}}, 0}
rnp := *type{rn}~~(*i8~~rp+j)
if (isvec{type{rn}}) store{rnp, 0, rn}
else storeu{rnp, rn}
} else {
out := homMask{b}; def B = type{out}; def wb = width{B}
store{*B~~rp, cdiv{j,wb}, out>>((-j)%wb)}
}
r -= off
rn := if (T==i8) r
else if (T==i16) half{narrow{u8, r}, 0}
else extract{re_el{i64, narrow{u8, r}}, 0}
rnp := *type{rn}~~(*i8~~rp+j)
if (isvec{type{rn}}) store{rnp, 0, rn}
else storeu{rnp, rn}
}
}
}
if (hasarch{'AVX2'}) {
fn avx2_search_bin{prim, T}(rp:*u64, w:*T, wn:u64, x:*T, xn:u64) : void = {
bin_search_vec{prim, T, w, wn, x, xn, rp, 16}
}
export{'avx2_member_sort_i32', avx2_search_bin{'∊',i32}}
}
def unroll_sizes = tup{4,1}
fn write{T,k}(r:*void, i:u64, ...vs:k**u64) : void = {
each{{j,v} => store{*T~~r, i+j, cast_i{T,v}}, iota{k}, vs}
@ -321,7 +342,7 @@ fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, rp:*void, rty:u8) : void = {
# For >=8 i8 values, vector bit-table is as good as binary search
def wn_vec = if (T==i8) 8 else 2*256/width{T}
if (hasarch{'AVX2'} and T<=i32 and wn < wn_vec and xn >= 256/width{T}) {
bin_search_vec{T, ...param, wn_vec}
bin_search_vec{if (up) '⍋' else '⍒', T, ...slice{param,1}, wn_vec}
# Lookup table threshold has to account for cost of
# populating the table (proportional to wn until it's large), and
# initializing the table (constant, much higher for i16)