From 1080236433f8c66c1a1ce5769f838e65a200864b Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Fri, 7 Jul 2023 08:01:55 -0400 Subject: [PATCH] 2-byte vector binary searches --- src/singeli/src/bins.singeli | 197 +++++++++++++++++++---------------- 1 file changed, 105 insertions(+), 92 deletions(-) diff --git a/src/singeli/src/bins.singeli b/src/singeli/src/bins.singeli index c475e46d..7244f470 100644 --- a/src/singeli/src/bins.singeli +++ b/src/singeli/src/bins.singeli @@ -145,101 +145,113 @@ def bins_lookup{I, T, up, w:*T, wn:u64, x:*T, xn:u64, rp:*void} = { tfree{t0} } -def bin_search_vec{up, w:*i8, wn, x:*i8, n, res:*i8} = { - assert{wn > 0} - def T = i8; def I = u8 - def lt = if (up) <; else > - def pre = (if (up) minvalue else maxvalue){T} +def bins_lookup{I==i8, T==i8, up, w:*T, wn:u64, x:*T, xn:u64, rp:*void} = { + assert{wn < 128} # Total must fit in i8 + def T = i8 def vl = 32 - def V = [vl]T; def H = [vl/2]T - def U = [vl]I - if (hasarch{'AVX2'} and wn < 8) { - log := ceil_log2{wn+1} - l := 1< V**0}} } - } else { - assert{wn < 128} # Total must fit in i8 - t0:*i8 = copy{256,0} - t:*i8 = t0 + 128 - @for (w over wn) store{t, w, 1+load{t, w}} - def plus_scan{tab, len} = { - s:i8=0; @for_dir{up} (tab over len) { s += tab; tab = s } + dup := promote{u64,nu} < wn + # Unique index to w index conversion + ui := undefined{V}; ui1 := undefined{V} + if (dup) { + if (nu > vl) goto{no_bittab} + # We'll subtract 1 when indexing so the initial 0 isn't needed + tui:*i8 = copy{vl, 0}; i:T = 0 + @for (tui over promote{u64,nu}) { i += load{t, load{w, i}}; tui = i } + ui = load{*V~~tui, 0} + if (nu > 16) ui1 = shuf{[4]u64, ui, 4b3232} + ui = shuf{[4]u64, ui, 4b1010} } - def no_bittab = makelabel{}; def done = makelabel{} - if (hasarch{'AVX2'}) { - # Convert to bit table - def nb = 256/vl - nu:u8 = 0; def addu{b} = { nu+=popc{b}; b } # Number of uniques - vb := U~~make{[nb](ty_u{vl}), - @collect (t in *V~~t0 over nb) addu{homMask{t > V**0}} - } - dup := promote{u64,nu} < wn - # Unique index to w index conversion - ui := undefined{V}; ui1 := undefined{V} + # Popcount on 8-bit values + def sums{n} = if (n==1) tup{0} else { def s=sums{n/2}; merge{s,s+1} } + def sum4 = getsel{make{H, sums{vl/2}}} + bot4 := U**0x0f + def vpopc{v} = { + def s{b} = sum4{b&bot4} + s{shr16{v,4}} + s{v} + } + # Bit table + def swap{v} = shuf{[4]u64, v, 4b1032} # For signedness + def sel_b = getsel{swap{vb}} + # Masks for filtering bit table + def ms = if (up) 256-(1<<(1+iota{8})) else (1<127), 8**0}}} + # Exact values for multiples of 8 + store{*U~~t0, 0, vpopc{vb}} + plus_scan{t0, 256/8} + def sel_c = getsel{swap{load{*V~~t0, 0} - V**dup}} + # Top 5 bits select bytes from tables; bottom 3 select from mask + bot3 := U**0x07 + @for_vec_overlap{vl} (j to xn) { + xv := load{*U~~(x+j), 0} + xb := xv & bot3 + xt := shr16{xv &~ bot3, 3} + ind := sel_c{xt} - vpopc{sel_b{xt} & U~~sel_m{xb}} if (dup) { - if (nu > vl) goto{no_bittab} - # We'll subtract 1 when indexing so the initial 0 isn't needed - tui:*i8 = copy{vl, 0}; i:T = 0 - @for (tui over promote{u64,nu}) { i += load{t, load{w, i}}; tui = i } - ui = load{*V~~tui, 0} - if (nu > 16) ui1 = shuf{[4]u64, ui, 4b3232} - ui = shuf{[4]u64, ui, 4b1010} + i0 := V~~ind # Can contain -1 + ind = sel{H, ui, i0} + if (nu > 16) ind = homBlend{sel{H,ui1,i0}, ind, i0127), 8**0}}} - # Exact values for multiples of 8 - store{*U~~t0, 0, vpopc{vb}} - plus_scan{t0, 256/8} - def sel_c = getsel{swap{load{*V~~t0, 0} - V**dup}} - # Top 5 bits select bytes from tables; bottom 3 select from mask - bot3 := U**0x07 - @for_vec_overlap{vl} (j to n) { - xv := load{*U~~(x+j), 0} - xb := xv & bot3 - xt := shr16{xv &~ bot3, 3} - ind := sel_c{xt} - vpopc{sel_b{xt} & U~~sel_m{xb}} - if (dup) { - i0 := V~~ind # Can contain -1 - ind = sel{H, ui, i0} - if (nu > 16) ind = homBlend{sel{H,ui1,i0}, ind, i0 1} + def wd = width{T} + def bytes = wd/8; def bb = bind{base,256} + def vl = 256/wd + def V = [vl]T; def H = v_half{V} + def U = [vl](ty_u{T}) + def lt = if (up) <; else > + # Number of steps + log := ceil_log2{wn+1} + l := 1<>(lb{bytes}+wd-8)}, 0} + store{*[vl]i8~~(*i8~~rp+j), 0, r - off} } - plus_scan{t0, 256} - @for (res, x over n) res = load{t, x} - if (hasarch{'AVX2'}) setlabel{done} } } @@ -282,12 +294,13 @@ fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, rp:*void, rty:u8) : void = { if (rty == k) bins_lookup{tupsel{k,rtypes}, T, ...param} else if (k+1 < tuplen{rtypes}) lookup{k+1} } + if (hasarch{'AVX2'} and T<=i16 and wn < 8 and xn >= 256/width{T}) { + bin_search_vec{T, ...param, 8} # Lookup table threshold has to account for cost of # populating the table (proportional to wn until it's large), and # initializing the table (constant, much higher for i16) - if (T==i8 and xn>=32 and (xn>=512 or xn >= wn>>6 + 32)) { - if (rty==0) bin_search_vec{...slice{param,0,-1}, *i8~~rp} - else lookup{1} + } else if (T==i8 and xn>=32 and (xn>=512 or xn >= wn>>6 + 32)) { + lookup{0} } else if (T==i16 and xn>=512 and (xn>=1<<14 or xn >= wn>>6 + (u64~~3<<(12+rty))/promote{u64,ceil_log2{wn}+2})) { lookup{0} } else {