diff --git a/src/singeli/src/bins.singeli b/src/singeli/src/bins.singeli index 73b7a96a..1f2ad32d 100644 --- a/src/singeli/src/bins.singeli +++ b/src/singeli/src/bins.singeli @@ -7,9 +7,31 @@ if (hasarch{'AVX2'}) { include './mask' include 'util/tup' +def for_backwards{vars,begin,end,iter} = { + i:u64 = end + while (i > begin) { + --i + iter{i, vars} + } +} +def for_vec_overlap{vl}{vars,begin==0,n,iter} = { + assert{n >= vl} + def end = makelabel{} + j:u64 = 0 + while (1) { + iter{j, vars} + j += vl + if (j > n-vl) { if (j == n) goto{end}; j = n-vl } + } + setlabel{end} +} + def ceil_log2{n:u64} = 64 - clz{n+1} -def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8 & hasarch{'AVX2'}} = { +# Shift as u16, since x86 is missing 8-bit shifts +def shr16{v, n} = type{v}~~(([width{type{v}}/16]u16~~v) >> n) + +def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8} = { assert{wn > 0} def T = i8; def I = u8 def lt = if (up) <; else > @@ -17,29 +39,82 @@ def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8 & hasarch{'AVX2'}} = { def vl = 32 def V = [vl]T; def H = [vl/2]T def U = [vl]I - log := ceil_log2{wn-1} - l := 1<= vl} - while (1) { - xv:= load{*V~~(x+j), 0} - s := U**0 - h := h0 - @for (promote{u64,log}) { - s |= h &~ lt{xv, sel{H, wv, s | h}} - h = U~~(([width{V}/16]u16~~h) >> 1) # Type doesn't matter but u8 would fail - } - store{*U~~(res+j), 0, s - off} - j += vl - if (j > n-vl) { if (j == n) goto{end}; j = n-vl } + def getsel{h:(H)} = { + v := pair{h,h} + {i} => sel{H, v, i} + } + if (hasarch{'AVX2'} and wn < 16) { + log := ceil_log2{wn-1} + l := 1< V**0} + } + # Popcount on 8-bit values + def sums{n} = if (n==1) tup{0} else { def s=sums{n/2}; merge{s,s+1} } + def sum4 = getsel{make{H, sums{vl/2}}} + bot4 := U**0x0f + def vpopc{v} = { + def s{b} = sum4{b&bot4} + s{shr16{v,4}} + s{v} + } + # 32-byte select + vtop := U**(vl/2) + def getsel{v & width{type{v}}==256} = { + hs := each{bind{shuf, [4]u64, v}, tup{4b3232, 4b1010}} + {i} => homBlend{...each{{h}=>sel{H,h,i}, hs}, i127), 8**0}}} + # Exact values for multiples of 8 + store{*U~~t0, 0, vpopc{vb}} + plus_scan{t0, 256/8} + def sel_c = getsel{swap{load{*V~~t0, 0}}} + # Top 5 bits select bytes from tables; bottom 3 select from mask + bot3 := U**0x07 + @for_vec_overlap{vl} (j to n) { + xv := load{*U~~(x+j), 0} + xb := xv & bot3 + xt := shr16{xv &~ bot3, 3} + ind := sel_c{xt} - vpopc{sel_b{xt} & U~~sel_m{xb}} + store{*U~~(res+j), 0, ind} + } + } else { + plus_scan{t0, 256} + @for (res, x over n) res = load{t, x} + } } - setlabel{end} } def bin_search_branchless{up, w, wn, x, n, res} = { @@ -76,7 +151,7 @@ fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, xb:B) : B = { tup{r, rp} } r := undefined{B} - if (hasarch{'AVX2'} and T==i8 and wn<16 and xn>=32) { + if (T==i8 and wn<128 and xn>=32) { def {rt, rp} = alloc{i8, 'i8'}; r = rt bin_search_vec{up, *T~~w, wn, *T~~x, xn, rp} } else {