Vector binary search Member-of
This commit is contained in:
parent
75aed91f32
commit
11245d385e
@ -46,9 +46,13 @@ RangeFn getRange_fns[el_f64+1];
|
||||
GETRANGE(i32,)
|
||||
GETRANGE(f64, if (!q_fi64(c)) return 0)
|
||||
#endif
|
||||
#if SINGELI_AVX2
|
||||
extern void (*const avx2_member_sort_i32)(uint64_t*,int32_t*,uint64_t,int32_t*,uint64_t);
|
||||
#endif
|
||||
|
||||
|
||||
#define C2i(F, W, X) C2(F, m_i32(W), X)
|
||||
extern B and_c1(B,B);
|
||||
extern B eq_c2(B,B,B);
|
||||
extern B ne_c2(B,B,B);
|
||||
extern B or_c2(B,B,B);
|
||||
@ -335,10 +339,19 @@ B memberOf_c2(B t, B w, B x) {
|
||||
}
|
||||
|
||||
u8 me = we>xe?we:xe;
|
||||
if (xia<=(me==el_i8?1:me==el_i16?4:16) && wia>16) {
|
||||
SGetU(x);
|
||||
r = WEQ(GetU(x,0));
|
||||
for (usz i=1; i<xia; i++) r = C2(or, r, WEQ(GetU(x,i)));
|
||||
if (xia<=(me==el_i8?1:me==el_i16?4:15) && wia>16) {
|
||||
#if SINGELI_AVX2
|
||||
if (we==xe && we==el_i32 && xia>1) {
|
||||
x = C1(and, x); // sort
|
||||
u64* rp; r = m_bitarrc(&rp, w);
|
||||
avx2_member_sort_i32(rp, tyany_ptr(x), xia, tyany_ptr(w), wia);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
SGetU(x);
|
||||
r = WEQ(GetU(x,0));
|
||||
for (usz i=1; i<xia; i++) r = C2(or, r, WEQ(GetU(x,i)));
|
||||
}
|
||||
decG(w); goto dec_x;
|
||||
}
|
||||
#undef WEQ
|
||||
|
||||
@ -200,7 +200,8 @@ def bins_vectab_i8{up, w, wn, x, xn, rp, t0, t, done & hasarch{'AVX2'}} = {
|
||||
}
|
||||
|
||||
# Binary search within vector registers
|
||||
def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
|
||||
def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
|
||||
def up = prim != '⍒'
|
||||
assert{wn > 1}; assert{wn < maxwn}
|
||||
def wd = width{T}
|
||||
def I = if (wd<32) u8 else u32; def wi = width{I}
|
||||
@ -214,7 +215,7 @@ def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
|
||||
log := ceil_log2{wn+1}
|
||||
gap := 1<<log - wn
|
||||
# Fill with minimum value at the beginning
|
||||
def pre = (if (up) minvalue else maxvalue){T}
|
||||
def pre = if (prim=='∊') load{w} else (if (up) minvalue else maxvalue){T}
|
||||
wg := *V~~(w-gap)
|
||||
wv0:= homBlend{load{wg}, V**pre, maskOf{V,gap}}
|
||||
# For multiple lanes, interleave like transpose
|
||||
@ -252,33 +253,53 @@ def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
|
||||
if (this) @for_vec_overlap{vl} (j to xn) {
|
||||
xv:= load{*V~~(x+j), 0}
|
||||
s := U**bb{iota{isub}} # Select sequential bytes within each U
|
||||
def ltx{se,ind} = lt{xv, V~~se{re_el{I,ind}}}
|
||||
def cmpx{cmp}{se,ind} = cmp{xv, V~~se{re_el{I,ind}}}
|
||||
def ltx = cmpx{lt}; def eqx = cmpx{==}
|
||||
@unroll (j to klog) {
|
||||
m := s | tupsel{klog-1-j,bits}
|
||||
s = homBlend{m, s, ltx{selw, m}}
|
||||
}
|
||||
r := if (isub==1) s else s>>(lb{isub}+wd-wi)
|
||||
# b records if xv was found; c is added to the index
|
||||
def r_out = prim!='∊'
|
||||
def get_up{var,cmpx,use}{set,...a} = if (use) set{var, cmpx{...a}}
|
||||
b := undefined{U}; def up_b = get_up{b, eqx, prim=='∊'}
|
||||
c := undefined{U}; def up_c = get_up{c, ltx, r_out}
|
||||
up_b{=, selw, s}
|
||||
# Extra selection lanes
|
||||
if (last and ex>=1 and log>=klog+1) {
|
||||
r += r
|
||||
c := ltx{selw1,s}
|
||||
def up_bc{set,se} = { up_b{|=,se,s}; up_c{set,se,s} }
|
||||
up_bc{=,selw1}
|
||||
if (ex>=2 and log>=klog+2) {
|
||||
r += r
|
||||
each{{se} => c += ltx{se,s}, selw2}
|
||||
each{up_bc{+=, .}, selw2}
|
||||
}
|
||||
r += c
|
||||
if (r_out) r += c
|
||||
}
|
||||
if (r_out) {
|
||||
r -= off
|
||||
rn := if (T==i8) r
|
||||
else if (T==i16) half{narrow{u8, r}, 0}
|
||||
else extract{re_el{i64, narrow{u8, r}}, 0}
|
||||
rnp := *type{rn}~~(*i8~~rp+j)
|
||||
if (isvec{type{rn}}) store{rnp, 0, rn}
|
||||
else storeu{rnp, rn}
|
||||
} else {
|
||||
out := homMask{b}; def B = type{out}; def wb = width{B}
|
||||
store{*B~~rp, cdiv{j,wb}, out>>((-j)%wb)}
|
||||
}
|
||||
r -= off
|
||||
rn := if (T==i8) r
|
||||
else if (T==i16) half{narrow{u8, r}, 0}
|
||||
else extract{re_el{i64, narrow{u8, r}}, 0}
|
||||
rnp := *type{rn}~~(*i8~~rp+j)
|
||||
if (isvec{type{rn}}) store{rnp, 0, rn}
|
||||
else storeu{rnp, rn}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (hasarch{'AVX2'}) {
|
||||
fn avx2_search_bin{prim, T}(rp:*u64, w:*T, wn:u64, x:*T, xn:u64) : void = {
|
||||
bin_search_vec{prim, T, w, wn, x, xn, rp, 16}
|
||||
}
|
||||
export{'avx2_member_sort_i32', avx2_search_bin{'∊',i32}}
|
||||
}
|
||||
|
||||
def unroll_sizes = tup{4,1}
|
||||
fn write{T,k}(r:*void, i:u64, ...vs:k**u64) : void = {
|
||||
each{{j,v} => store{*T~~r, i+j, cast_i{T,v}}, iota{k}, vs}
|
||||
@ -321,7 +342,7 @@ fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, rp:*void, rty:u8) : void = {
|
||||
# For >=8 i8 values, vector bit-table is as good as binary search
|
||||
def wn_vec = if (T==i8) 8 else 2*256/width{T}
|
||||
if (hasarch{'AVX2'} and T<=i32 and wn < wn_vec and xn >= 256/width{T}) {
|
||||
bin_search_vec{T, ...param, wn_vec}
|
||||
bin_search_vec{if (up) '⍋' else '⍒', T, ...slice{param,1}, wn_vec}
|
||||
# Lookup table threshold has to account for cost of
|
||||
# populating the table (proportional to wn until it's large), and
|
||||
# initializing the table (constant, much higher for i16)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user