2-byte vector binary searches

This commit is contained in:
Marshall Lochbaum 2023-07-07 08:01:55 -04:00
parent d665b90bbf
commit 1080236433

View File

@ -145,101 +145,113 @@ def bins_lookup{I, T, up, w:*T, wn:u64, x:*T, xn:u64, rp:*void} = {
tfree{t0} tfree{t0}
} }
def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8} = { def bins_lookup{I==i8, T==i8, up, w:*T, wn:u64, x:*T, xn:u64, rp:*void} = {
assert{wn > 0} assert{wn < 128} # Total must fit in i8
def T = i8; def I = u8 def T = i8
def lt = if (up) <; else >
def pre = (if (up) minvalue else maxvalue){T}
def vl = 32 def vl = 32
def V = [vl]T; def H = [vl/2]T def V = [vl]T; def H = v_half{V}
def U = [vl]I def U = [vl]u8
if (hasarch{'AVX2'} and wn < 8) { def res = *T~~rp
log := ceil_log2{wn+1}
l := 1<<log t0:*i8 = copy{256,0}
gap := l - cast_i{u8, wn} t:*i8 = t0 + 128
off := U**(gap - 1) @for (w over wn) store{t, w, 1+load{t, w}}
wv := homBlend{load{*H~~(w-gap), 0}, H**pre, maskOf{H,gap}} def plus_scan{tab, len} = {
def selw = getsel{wv} s:i8=0; @for_dir{up} (tab over len) { s += tab; tab = s }
h0 := U**(l/2) }
@unroll (klog from 2 to 4) { def no_bittab = makelabel{}; def done = makelabel{}
if (log==klog) @for_vec_overlap{vl} (j to n) { if (hasarch{'AVX2'}) {
xv:= load{*V~~(x+j), 0} # Convert to bit table
s := U**0 def nb = 256/vl
h := h0 nu:u8 = 0; def addu{b} = { nu+=popc{b}; b } # Number of uniques
@unroll (klog) { vb := U~~make{[nb](ty_u{vl}),
m := s | h @collect (t in *V~~t0 over nb) addu{homMask{t > V**0}}
s = homBlend{m, s, lt{xv, selw{m}}}
h = shr16{h, 1}
}
store{*U~~(res+j), 0, s - off}
}
} }
} else { dup := promote{u64,nu} < wn
assert{wn < 128} # Total must fit in i8 # Unique index to w index conversion
t0:*i8 = copy{256,0} ui := undefined{V}; ui1 := undefined{V}
t:*i8 = t0 + 128 if (dup) {
@for (w over wn) store{t, w, 1+load{t, w}} if (nu > vl) goto{no_bittab}
def plus_scan{tab, len} = { # We'll subtract 1 when indexing so the initial 0 isn't needed
s:i8=0; @for_dir{up} (tab over len) { s += tab; tab = s } tui:*i8 = copy{vl, 0}; i:T = 0
@for (tui over promote{u64,nu}) { i += load{t, load{w, i}}; tui = i }
ui = load{*V~~tui, 0}
if (nu > 16) ui1 = shuf{[4]u64, ui, 4b3232}
ui = shuf{[4]u64, ui, 4b1010}
} }
def no_bittab = makelabel{}; def done = makelabel{} # Popcount on 8-bit values
if (hasarch{'AVX2'}) { def sums{n} = if (n==1) tup{0} else { def s=sums{n/2}; merge{s,s+1} }
# Convert to bit table def sum4 = getsel{make{H, sums{vl/2}}}
def nb = 256/vl bot4 := U**0x0f
nu:u8 = 0; def addu{b} = { nu+=popc{b}; b } # Number of uniques def vpopc{v} = {
vb := U~~make{[nb](ty_u{vl}), def s{b} = sum4{b&bot4}
@collect (t in *V~~t0 over nb) addu{homMask{t > V**0}} s{shr16{v,4}} + s{v}
} }
dup := promote{u64,nu} < wn # Bit table
# Unique index to w index conversion def swap{v} = shuf{[4]u64, v, 4b1032} # For signedness
ui := undefined{V}; ui1 := undefined{V} def sel_b = getsel{swap{vb}}
# Masks for filtering bit table
def ms = if (up) 256-(1<<(1+iota{8})) else (1<<iota{8})-1
def sel_m = getsel{make{H, merge{ms - 256*(ms>127), 8**0}}}
# Exact values for multiples of 8
store{*U~~t0, 0, vpopc{vb}}
plus_scan{t0, 256/8}
def sel_c = getsel{swap{load{*V~~t0, 0} - V**dup}}
# Top 5 bits select bytes from tables; bottom 3 select from mask
bot3 := U**0x07
@for_vec_overlap{vl} (j to xn) {
xv := load{*U~~(x+j), 0}
xb := xv & bot3
xt := shr16{xv &~ bot3, 3}
ind := sel_c{xt} - vpopc{sel_b{xt} & U~~sel_m{xb}}
if (dup) { if (dup) {
if (nu > vl) goto{no_bittab} i0 := V~~ind # Can contain -1
# We'll subtract 1 when indexing so the initial 0 isn't needed ind = sel{H, ui, i0}
tui:*i8 = copy{vl, 0}; i:T = 0 if (nu > 16) ind = homBlend{sel{H,ui1,i0}, ind, i0<V**(vl/2)}
@for (tui over promote{u64,nu}) { i += load{t, load{w, i}}; tui = i }
ui = load{*V~~tui, 0}
if (nu > 16) ui1 = shuf{[4]u64, ui, 4b3232}
ui = shuf{[4]u64, ui, 4b1010}
} }
# Popcount on 8-bit values store{*U~~(res+j), 0, ind}
def sums{n} = if (n==1) tup{0} else { def s=sums{n/2}; merge{s,s+1} } }
def sum4 = getsel{make{H, sums{vl/2}}} goto{done}
bot4 := U**0x0f setlabel{no_bittab}
def vpopc{v} = { }
def s{b} = sum4{b&bot4} plus_scan{t0, 256}
s{shr16{v,4}} + s{v} @for (res, x over xn) res = load{t, x}
} if (hasarch{'AVX2'}) setlabel{done}
# Bit table }
def swap{v} = shuf{[4]u64, v, 4b1032} # For signedness
def sel_b = getsel{swap{vb}} def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
# Masks for filtering bit table assert{wn > 1}
def ms = if (up) 256-(1<<(1+iota{8})) else (1<<iota{8})-1 def wd = width{T}
def sel_m = getsel{make{H, merge{ms - 256*(ms>127), 8**0}}} def bytes = wd/8; def bb = bind{base,256}
# Exact values for multiples of 8 def vl = 256/wd
store{*U~~t0, 0, vpopc{vb}} def V = [vl]T; def H = v_half{V}
plus_scan{t0, 256/8} def U = [vl](ty_u{T})
def sel_c = getsel{swap{load{*V~~t0, 0} - V**dup}} def lt = if (up) <; else >
# Top 5 bits select bytes from tables; bottom 3 select from mask # Number of steps
bot3 := U**0x07 log := ceil_log2{wn+1}
@for_vec_overlap{vl} (j to n) { l := 1<<log
xv := load{*U~~(x+j), 0} gap := l - cast_i{u8, wn}
xb := xv & bot3 off := [vl]u8**(gap - 1)
xt := shr16{xv &~ bot3, 3} # Fill with minimum value at the beginning
ind := sel_c{xt} - vpopc{sel_b{xt} & U~~sel_m{xb}} def pre = (if (up) minvalue else maxvalue){T}
if (dup) { wv := homBlend{load{*H~~(w-gap), 0}, H**pre, maskOf{H,gap}}
i0 := V~~ind # Can contain -1 def selw = getsel{[16]u8~~wv}
ind = sel{H, ui, i0} # A bit in every byte
if (nu > 16) ind = homBlend{sel{H,ui1,i0}, ind, i0<V**(vl/2)} h0 := U**(bb{copy{bytes,bytes}} * (cast_i{ty_u{T},l}/2))
} @unroll (klog from 2 to lb{maxwn}+1) {
store{*U~~(res+j), 0, ind} if (log==klog) @for_vec_overlap{vl} (j to xn) {
} xv:= load{*V~~(x+j), 0}
goto{done} s := U**bb{iota{bytes}} # Select sequential bytes within each U
setlabel{no_bittab} h := h0
@unroll (klog) {
m := s | h
s = homBlend{m, s, lt{xv, V~~selw{to_el{u8,m}}}}
h = shr16{h, 1}
}
r := if (T==i8) s
else half{narrow{u8, s>>(lb{bytes}+wd-8)}, 0}
store{*[vl]i8~~(*i8~~rp+j), 0, r - off}
} }
plus_scan{t0, 256}
@for (res, x over n) res = load{t, x}
if (hasarch{'AVX2'}) setlabel{done}
} }
} }
@ -282,12 +294,13 @@ fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, rp:*void, rty:u8) : void = {
if (rty == k) bins_lookup{tupsel{k,rtypes}, T, ...param} if (rty == k) bins_lookup{tupsel{k,rtypes}, T, ...param}
else if (k+1 < tuplen{rtypes}) lookup{k+1} else if (k+1 < tuplen{rtypes}) lookup{k+1}
} }
if (hasarch{'AVX2'} and T<=i16 and wn < 8 and xn >= 256/width{T}) {
bin_search_vec{T, ...param, 8}
# Lookup table threshold has to account for cost of # Lookup table threshold has to account for cost of
# populating the table (proportional to wn until it's large), and # populating the table (proportional to wn until it's large), and
# initializing the table (constant, much higher for i16) # initializing the table (constant, much higher for i16)
if (T==i8 and xn>=32 and (xn>=512 or xn >= wn>>6 + 32)) { } else if (T==i8 and xn>=32 and (xn>=512 or xn >= wn>>6 + 32)) {
if (rty==0) bin_search_vec{...slice{param,0,-1}, *i8~~rp} lookup{0}
else lookup{1}
} else if (T==i16 and xn>=512 and (xn>=1<<14 or xn >= wn>>6 + (u64~~3<<(12+rty))/promote{u64,ceil_log2{wn}+2})) { } else if (T==i16 and xn>=512 and (xn>=1<<14 or xn >= wn>>6 + (u64~~3<<(12+rty))/promote{u64,ceil_log2{wn}+2})) {
lookup{0} lookup{0}
} else { } else {