This commit is contained in:
Marshall Lochbaum 2023-07-06 11:21:45 -04:00
parent 2c9e07f33d
commit 724f685a57

View File

@ -14,6 +14,7 @@ def for_backwards{vars,begin,end,iter} = {
iter{i, vars}
}
}
def for_dir{up} = if (up) for else for_backwards
def for_vec_overlap{vl}{vars,begin==0,n,iter} = {
assert{n >= vl}
def end = makelabel{}
@ -31,6 +32,20 @@ def ceil_log2{n:u64} = 64 - clz{n-1}
# Shift as u16, since x86 is missing 8-bit shifts
def shr16{v, n} = type{v}~~(([width{type{v}}/16]u16~~v) >> n)
def getsel{...x} = assert{'shuffling not supported', show{...x}}
if (hasarch{'AVX2'}) {
def getsel{h:H & lvec{H, 16, 8}} = {
v := pair{h,h}
{i} => sel{H, v, i}
}
def getsel{v:V & lvec{V, 32, 8}} = {
def H = v_half{V}
vtop := V**(vcount{V}/2)
hs := each{bind{shuf, [4]u64, v}, tup{4b3232, 4b1010}}
{i} => homBlend{...each{{h}=>sel{H,h,i}, hs}, V~~i<vtop}
}
}
def rtypes = tup{i8, i16, i32, f64}
# Return index of smallest possible result type given max result value
def get_rtype{len} = {
@ -78,7 +93,7 @@ fn write_indices{I,T}(t:*I, w:*T, n:u64) : void = {
fn lookup_indices{I,T,R}(tab:*I, x:*T, rp:*void, xn:u64) : void = {
def m = 1<<width{T}
t := *R~~tab; ts := t+m/2
@for (tab, t over m) t = cast_i{R, tab}
if (I!=R) @for (tab, t over m) t = cast_i{R, tab}
@for (r in *R~~rp, x in *T~~x over xn) r = load{ts, x}
}
def lookup_i8_arr = rtype_arr{bind{lookup_indices,u64,i8}}
@ -91,10 +106,6 @@ def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8} = {
def vl = 32
def V = [vl]T; def H = [vl/2]T
def U = [vl]I
def getsel{h:(H)} = {
v := pair{h,h}
{i} => sel{H, v, i}
}
if (hasarch{'AVX2'} and wn < 8) {
log := ceil_log2{wn+1}
l := 1<<log
@ -121,8 +132,7 @@ def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8} = {
t:*i8 = t0 + 128
@for (w over wn) store{t, w, 1+load{t, w}}
def plus_scan{tab, len} = {
def for_dir = if (up) for else for_backwards
s:i8=0; @for_dir (tab over len) { s += tab; tab = s }
s:i8=0; @for_dir{up} (tab over len) { s += tab; tab = s }
}
def no_bittab = makelabel{}; def done = makelabel{}
if (hasarch{'AVX2'}) {
@ -152,14 +162,8 @@ def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8} = {
def s{b} = sum4{b&bot4}
s{shr16{v,4}} + s{v}
}
# 32-byte select
vtop := V**(vl/2)
def getsel{v & width{type{v}}==256} = {
hs := each{bind{shuf, [4]u64, v}, tup{4b3232, 4b1010}}
{i} => homBlend{...each{{h}=>sel{H,h,i}, hs}, V~~i<vtop}
}
def swap{v} = shuf{[4]u64, v, 4b1032} # For signedness
# Bit table
def swap{v} = shuf{[4]u64, v, 4b1032} # For signedness
def sel_b = getsel{swap{vb}}
# Masks for filtering bit table
def ms = if (up) 256-(1<<(1+iota{8})) else (1<<iota{8})-1
@ -178,16 +182,16 @@ def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8} = {
if (dup) {
i0 := V~~ind # Can contain -1
ind = sel{H, ui, i0}
if (nu > 16) ind = homBlend{sel{H,ui1,i0}, ind, i0<vtop}
if (nu > 16) ind = homBlend{sel{H,ui1,i0}, ind, i0<V**(vl/2)}
}
store{*U~~(res+j), 0, ind}
}
goto{done}
setlabel{no_bittab}
}
setlabel{no_bittab}
plus_scan{t0, 256}
@for (res, x over n) res = load{t, x}
setlabel{done}
lookup_indices{T,T,T}(t0, x, *void~~res, n)
if (hasarch{'AVX2'}) setlabel{done}
}
}
@ -228,12 +232,11 @@ fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, rp:*void, rty:u8) : void = {
if (T==i8 and wn<128 and xn>=32) {
bin_search_vec{up, *T~~w, wn, *T~~x, xn, *i8~~rp}
} else if (T==i8 and xn>=64 and (xn>=1024 or (xn-32) >= cast_i{u64,1}<<(ceil_log2{wn}/2+1))) {
def T = i8; def I = u64
def I = u64
t0:*I = copy{256,0}
t:*I = t0 + 128
write_indices{I,T}(t, *T~~w, wn)
def for_dir = if (up) for else for_backwards
s:I=0; @for_dir (t0 over 256) { if (t0 > s) s = t0; t0 = s }
s:I=0; @for_dir{up} (t0 over 256) { if (t0 > s) s = t0; t0 = s }
load{lookup_i8_arr,rty}(t0, *T~~x, rp, xn)
} else {
bin_search_branchless{up, *T~~w, wn, *T~~x, xn, rp, rty}