Table-based 2-byte Bins, using max-scan

This commit is contained in:
Marshall Lochbaum 2023-07-06 17:54:43 -04:00
parent 724f685a57
commit a711eb72eb

View File

@ -32,6 +32,50 @@ def ceil_log2{n:u64} = 64 - clz{n-1}
# Shift as u16, since x86 is missing 8-bit shifts
def shr16{v, n} = type{v}~~(([width{type{v}}/16]u16~~v) >> n)
# Forward or backwards in-place max-scan
# Assumes a whole number of vectors and minimum 0
fn max_scan{T, up}(x:*T, len:u64) : void = {
def w = width{T}
if (hasarch{'AVX2'} and T!=u64) {
def op = max
# TODO unify with scan.singeli avx2_scan_idem
def rev{a} = if (up) a else (tuplen{a}-1)-reverse{a}
def maker{T, l} = make{T, rev{l}}
def sel8{v, t} = sel{[16]u8, v, maker{[32]i8, t}}
def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,rev{n}}}
def spread{a:VT} = {
def w = elwidth{VT}
def b = w/8
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
}
def shift{k,l} = merge{iota{k},iota{l-k}}
def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}}
def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a)
def pre{a} = {
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
op{b, sel{[8]i32, spread{b}, maker{[8]i32, 3*(3<iota{8})}}}
}
def toLast{n:VT} = {
if (elwidth{VT}<=32) sel{[8]i32, spread{n}, [8]i32**(up*7)}
else shuf{[4]u64, n, up*4b3333}
}
def vl = 256/w
def V = [vl]T
p := V**0
@for_dir{up} (v in *V~~x over len/vl) { v = op{pre{v}, p}; p = toLast{v} }
} else {
m:T=0; @for_dir{up} (x over len) { if (x > m) m = x; x = m }
}
}
def fmt_type{T} = {
def w = width{T}
merge{quality{T}, if (w==8) '8' else if (w==16) '16' else if (w==32) '32' else '64'}
}
def talloc{T, len} = emit{*T, 'TALLOCP', fmt_type{T}, len}
def tfree{ptr} = emit{void, 'TFREE', ptr}
def getsel{...x} = assert{'shuffling not supported', show{...x}}
if (hasarch{'AVX2'}) {
def getsel{h:H & lvec{H, 16, 8}} = {
@ -90,13 +134,16 @@ fn write_indices{I,T}(t:*I, w:*T, n:u64) : void = {
}
setlabel{break}
}
fn lookup_indices{I,T,R}(tab:*I, x:*T, rp:*void, xn:u64) : void = {
def m = 1<<width{T}
t := *R~~tab; ts := t+m/2
if (I!=R) @for (tab, t over m) t = cast_i{R, tab}
@for (r in *R~~rp, x in *T~~x over xn) r = load{ts, x}
def bins_lookup{I, T, up, w:*T, wn:u64, x:*T, xn:u64, rp:*void} = {
def tc = 1<<width{T}
t0:*I = talloc{I, tc}
@for (t0 over tc) t0 = 0
t:*I = t0 + tc/2
write_indices{I,T}(t, *T~~w, wn)
max_scan{I, up}(t0, tc)
@for (r in *I~~rp, x over xn) r = load{t, x}
tfree{t0}
}
def lookup_i8_arr = rtype_arr{bind{lookup_indices,u64,i8}}
def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8} = {
assert{wn > 0}
@ -190,7 +237,7 @@ def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8} = {
setlabel{no_bittab}
}
plus_scan{t0, 256}
lookup_indices{T,T,T}(t0, x, *void~~res, n)
@for (res, x over n) res = load{t, x}
if (hasarch{'AVX2'}) setlabel{done}
}
}
@ -229,17 +276,21 @@ def bin_search_branchless{up, w, wn, x, n, res, rtype} = {
}
fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, rp:*void, rty:u8) : void = {
if (T==i8 and wn<128 and xn>=32) {
bin_search_vec{up, *T~~w, wn, *T~~x, xn, *i8~~rp}
} else if (T==i8 and xn>=64 and (xn>=1024 or (xn-32) >= cast_i{u64,1}<<(ceil_log2{wn}/2+1))) {
def I = u64
t0:*I = copy{256,0}
t:*I = t0 + 128
write_indices{I,T}(t, *T~~w, wn)
s:I=0; @for_dir{up} (t0 over 256) { if (t0 > s) s = t0; t0 = s }
load{lookup_i8_arr,rty}(t0, *T~~x, rp, xn)
def param = tup{up, *T~~w, wn, *T~~x, xn, rp}
def lookup{k} = {
if (rty == k) bins_lookup{tupsel{k,rtypes}, T, ...param}
else if (k+1 < tuplen{rtypes}) lookup{k+1}
}
# Lookup table threshold has to account for cost of
# populating the table (proportional to wn until it's large), and
# initializing the table (constant, much higher for i16)
if (T==i8 and xn>=32 and (xn>=512 or xn >= wn>>6 + 32)) {
if (rty==0) bin_search_vec{...slice{param,0,-1}, *i8~~rp}
else lookup{1}
} else if (T==i16 and xn>=512 and (xn>=1<<14 or xn >= wn>>6 + (u64~~3<<(12+rty))/promote{u64,ceil_log2{wn}+2})) {
lookup{0}
} else {
bin_search_branchless{up, *T~~w, wn, *T~~x, xn, rp, rty}
bin_search_branchless{...param, rty}
}
}