# uCBQN/src/singeli/src/bins.singeli
# Snapshot of 2023-07-10 07:13:21 -04:00 (80 lines, 2.2 KiB).

include './base'
if (hasarch{'AVX2'}) {
include './sse'
include './avx'
include './avx2'
}
include './mask'
include 'util/tup'
# ceil(log2(n)): the number of bits needed to index n slots. Valid for n >= 2;
# n == 1 would evaluate clz{0}, which may be undefined.
def ceil_log2{n:u64} = 64 - clz{n-1}

# AVX2 binary search of 32 i8 values at a time in a short i8 list w.
# For each j < n, res[j] receives the number of elements of w (length wn) that
# are <= x[j] when up, or >= x[j] otherwise. Being a binary search, this
# assumes w is sorted to match `up` (NOTE(review): ordering assumed from caller).
# Requirements visible from the code below:
# - 1 <= wn <= 15: w plus at least one sentinel must fit one 16-lane table.
# - n >= 32: the final overlapping step starts at n-vl.
# - 16 bytes are loaded starting at w-gap, i.e. up to gap bytes before w and
#   16-l bytes past w+wn. NOTE(review): this assumes that memory is readable.
def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8 & hasarch{'AVX2'}} = {
  def T = i8; def I = u8
  def lt = if (up) <; else >
  # Sentinel filling the table slots before w; no x is lt it.
  def pre = (if (up) minvalue else maxvalue){T}
  def vl = 32                      # values handled per step
  def V = [vl]T; def H = [vl/2]T
  def U = [vl]I
  # Table size l: smallest power of two with room for wn elements plus a sentinel.
  log := ceil_log2{wn+1}
  l := 1<<log
  # Table layout: gap sentinel slots followed by w[0..wn-1], l slots in total.
  gap := l - cast_i{u8, wn}
  # Final index gap-1+k means k elements of w were passed; subtract off to get k.
  off := U**(gap - 1)
  def double{v} = pair{v,v}
  # Same 16-byte table in both halves, presumably because the byte shuffle
  # behind sel only looks up within each 128-bit half.
  wv := double{homBlend{load{*H~~(w-gap), 0}, H**pre, maskOf{H,gap}}}
  h0 := U**(l/2)
  j:u64 = 0
  def tail = setlabel{}
  # Process whole vectors only, so no load or store goes past x+n or res+n.
  while (j + vl <= n) {
    xv:= load{*V~~(x+j), 0}
    s := U**0
    h := h0
    # Branchless search: set bit h of s unless xv is lt the table entry at s|h.
    @for (promote{u64,log}) {
      s |= h &~ lt{xv, sel{H, wv, s | h}}
      # Type doesn't matter but u8 would fail (AVX2 has no 8-bit shift). A bit
      # carried into the neighbouring byte appears only after the last h is used.
      h = U~~(([width{V}/16]u16~~h) >> 1)
    }
    store{*U~~(res+j), 0, s - off}
    j += vl
  }
  # Leftover values: redo one full vector ending exactly at n. Overlapping
  # results are recomputed identically, so rewriting them is harmless.
  if (j != n) { j = n-vl; goto{tail} }
}
# Portable branchless binary search, usable for any element type.
# For each i < n, writes to res[i] (as i32) the number of elements of w
# (length wn) that are <= x[i] when up, or >= x[i] otherwise. This assumes w is
# sorted to match `up` (NOTE(review): ordering assumed from caller).
def bin_search_branchless{up, w, wn, x, n, res} = {
def lt = if (up) <; else >
# Search starts one slot before w. Every load is at s+h with h >= 1 and stays
# within w[0..wn-1], so ws itself is never dereferenced.
ws := w - 1
# Interval length includes that virtual slot, hence wn+1.
l0 := wn + 1
# Take a list of indices in x/res to allow unrolling
def search{inds} = {
xs:= each{bind{load,x}, inds} # Values
ss:= each{{_}=>ws, inds} # Initial lower bound
l := l0; h := undefined{u64} # Interval size l, same for all values
# Halve the interval until it has length 1.
while ((h=l/2) > 0) {
# Branchless update
# Move the lower bound s to m = s+h unless x is lt w-element at m. The
# assignment updates the corresponding variable in ss.
def bin1{s, x, m} = { if (not lt{x, load{m}}) s = m }
each{bin1, ss, xs, each{bind{+,h}, ss}}
l -= h
}
# Offset from the virtual slot = number of elements passed.
# NOTE(review): assumes wn fits in i32.
each{{r,s} => store{res, r, cast_i{i32, s - ws}}, inds, ss}
}
# Unroll by 4 then 1
# search{i, k}: search the k consecutive indices i, i+1, ..., i+k-1.
def search{i, k} = search{each{bind{+,i}, iota{k}}}
j:u64 = 0
# Consume as many full groups of k as remain, advancing the shared cursor j.
def searches{k} = { while (j+k <= n) { search{j, k}; j+=k } }
each{searches, tup{4, 1}}
}
# Exported kernel: for each of the xn values in x, write to r (as i32) the
# result of binary-searching it in the wn-element list w of type T, with
# direction `up`.
fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, r:*i32) : void = {
  # Vector path constraints:
  # - It indexes one 16-lane table, so it needs wn <= 15.
  # - wn == 0 would evaluate clz{0} when sizing that table (possibly undefined
  #   behaviour). The branchless path gives the correct all-zero result, so it
  #   handles that case.
  # - It ends with an overlapping 32-value step, so it needs xn >= 32.
  if (hasarch{'AVX2'} and T==i8 and wn>0 and wn<16 and xn>=32) {
    bin_search_vec{up, *T~~w, wn, *T~~x, xn, *i8~~r}
    # Slow and useless: need to allocate i8 result
    # Widen the i8 results to i32 in place, back to front. Writing r[j] touches
    # bytes 4j..4j+3, none of which hold an i8 still to be read (indices < j).
    j:=xn; while (j > 0) { --j; store{r, j, cast_i{i32, load{*i8~~r, j}}} }
  } else {
    bin_search_branchless{up, *T~~w, wn, *T~~x, xn, r}
  }
}
# Export the bins instantiations under the name si_bins, presumably as an array
# of function pointers. They come from the type x direction table, flattened,
# presumably in the order (i8,1), (i8,0), (i16,1), (i16,0), (i32,1), (i32,0),
# (f64,1), (f64,0). NOTE(review): confirm against the C-side indexing.
exportT{
'si_bins',
join{table{bins, tup{i8,i16,i32,f64}, tup{1,0}}}
}