diff --git a/src/singeli/src/bins.singeli b/src/singeli/src/bins.singeli index 505d91ad..0a47afe9 100644 --- a/src/singeli/src/bins.singeli +++ b/src/singeli/src/bins.singeli @@ -90,6 +90,15 @@ if (hasarch{'AVX2'}) { } } +# Move evens to half 0 and odds to half 1 +def uninterleave{x:V & hasarch{'AVX2'}} = { + def vl = vcount{V}; def bytes = width{eltype{V}}/8 + def i = 2*iota{vl/4} + def i2= join{table{+, bytes*merge{i,i+1}, iota{bytes}}} + t := V~~sel{[16]u8, to_el{u8,x}, make{[32]u8, merge{i2,i2}}} + shuf{[4]u64, t, 4b3120} +} + def rtypes = tup{i8, i16, i32, f64} # Return index of smallest possible result type given max result value def get_rtype{len} = { @@ -239,18 +248,35 @@ def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = { off := [vl]u8**(gap - 1) # Fill with minimum value at the beginning def pre = (if (up) minvalue else maxvalue){T} - wv := homBlend{load{*H~~(w-gap)}, H**pre, maskOf{H,gap}} - def selw = getsel{[16]u8~~wv} + wv := homBlend{load{*V~~(w-gap)}, V**pre, maskOf{V,gap}} + # Separate even/odd elements if double width + def maxstep = lb{maxwn} + def lstep = lb{vl/2} + def has_w1 = maxstep > lstep + if (has_w1 and wn >= vl/2) wv = uninterleave{wv} + def ms{h} = getsel{[16]u8~~half{wv,h}} + def selw = ms{0}; def selw1 = if (has_w1) ms{1} else 'undef' # Midpoint bits for each step def lowbits = bb{copy{bytes,bytes}} - bits := each{{j} => U**(lowbits << j), iota{maxwn-1}} - @unroll (klog from 2 to lb{maxwn}+1) { - if (log==klog) @for_vec_overlap{vl} (j to xn) { + bits := each{{j} => U**(lowbits << j), iota{lstep}} + # Unroll sizes up to a full lane, handling extra lanes conditionally + # in the largest one + @unroll (klog from 2 to min{maxstep,lstep}+1) { + def last = klog==lstep + def this = if (not last) log==klog else log>=klog + if (this) @for_vec_overlap{vl} (j to xn) { + def as_u8{T,op, ...par} = T~~op{...each{bind{to_el,u8},par}} xv:= load{*V~~(x+j), 0} s := U**bb{iota{bytes}} # Select sequential bytes within each U @unroll (j to klog) { - m := s | tupsel{klog-1-j, bits} - s = homBlend{m, s, lt{xv, V~~selw{to_el{u8,m}}}} + m := s | tupsel{klog-1-j,bits} + s = homBlend{m, s, lt{xv, as_u8{V,selw, m}}} + } + # Extra selection lanes + if (last and has_w1 and log>klog) { + c := lt{xv, as_u8{V,selw1, s}} + assert{T==i16} # Otherwise position of c is different + s = as_u8{U,+, s,c}; s += s } r := if (T==i8) s else half{narrow{u8, s>>(lb{bytes}+wd-8)}, 0} @@ -298,8 +324,10 @@ fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, rp:*void, rty:u8) : void = { if (rty == k) bins_lookup{tupsel{k,rtypes}, T, ...param} else if (k+1 < tuplen{rtypes}) lookup{k+1} } - if (hasarch{'AVX2'} and T<=i16 and wn < 8 and xn >= 256/width{T}) { - bin_search_vec{T, ...param, 8} + # For >=8 i8 values, vector bit-table is as good as binary search + def wn_vec = if (T==i8) 8 else 16 + if (hasarch{'AVX2'} and T<=i16 and wn < wn_vec and xn >= 256/width{T}) { + bin_search_vec{T, ...param, wn_vec} # Lookup table threshold has to account for cost of # populating the table (proportional to wn until it's large), and # initializing the table (constant, much higher for i16)