Merge pull request #86 from mlochbaum/bins

Bins
This commit is contained in:
dzaima 2023-07-10 16:26:50 +03:00 committed by GitHub
commit 368e6b6001
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 593 additions and 22 deletions

View File

@ -634,6 +634,7 @@ cachedBin‿linkerCache ← {
"xa.""src/builtins/squeeze.c""squeeze", "xa.""src/utils/mut.c""copy",
"xa.""src/utils/bits.c""bits", "xag""src/builtins/transpose.c""transpose",
"xag""src/builtins/search.c""search", "xa.""src/builtins/fold.c""fold",
"xag""src/builtins/sort.c""bins"
"2..""src/builtins/select.c""select", "2..""src/builtins/scan.c""scan",
"2..""src/builtins/slash.c""constrep", "2..""src/builtins/scan.c""neq",

View File

@ -19,15 +19,31 @@
// SHOULD widen odd cell sizes under 8 bytes in sort and grade
// Bins
// Length 0 or 1 𝕨: trivial, or comparison
// Stand-alone 𝕨 sortedness check
// SHOULD vectorize sortedness check on lists of numbers
// Mixed integer and character arguments gives all 0 or ≠𝕨
// Integers and characters: 4-byte branchless binary search
// Non-Singeli, integers and characters:
// 4-byte branchless binary search, 4-byte output
// SHOULD support fast character searches
// Boolean 𝕨 or 𝕩: lookup table (single binary search on boolean 𝕨)
// Different widths: generally widen narrower argument
// Narrow wider-type 𝕩 instead if it isn't much shorter
// SHOULD trim wider-type 𝕨 and possibly narrow
// Same-width numbers:
// Output type based on ≠𝕨
// Short 𝕨: vector binary search (then linear on extra lanes)
// 1- or 2-byte type, long enough 𝕩: lookup table from ⌈`
// Binary gallops to skip long repeated elements of 𝕨
// 1-byte, no duplicates or few uniques: vector bit-table lookup
// General: interleaved branchless binary search
// COULD start interleaved search with a vector binary round
// General case: branching binary search
// SHOULD implement f64 branchless binary search
// SHOULD interleave multiple branchless binary searches
// SHOULD specialize bins on equal types at least
// SHOULD implement table-based ⍋⍒ for small-range 𝕨
// SHOULD special-case short 𝕨
// COULD trim 𝕨 based on range of 𝕩
// COULD optimize small-range 𝕨 with small-type methods
// SHOULD partition 𝕩 when 𝕨 is large
// COULD interpolation search for large 𝕩 and short 𝕨
// COULD use linear search and galloping for sorted 𝕩
#define GRADE_CAT(N) CAT(GRADE_UD(gradeUp,gradeDown),N)
#define GRADE_NEG GRADE_UD(,-)
@ -326,6 +342,62 @@ B GRADE_CAT(c1)(B t, B x) {
}
bool CAT(isSorted,GRADE_UD(Up,Down))(B x) {
assert(isArr(x) && RNK(x)==1); // TODO extend to >=1
usz xia = IA(x);
if (xia <= 1) return 1;
#define CMP(TEST) \
for (usz i=1; i<xia; i++) if (TEST) return 0; \
return 1;
#define CASE(T) case el_##T: { \
T* xp = T##any_ptr(x); CMP(xp[i-1] GRADE_UD(>,<) xp[i]) }
switch (TI(x,elType)) { default: UD;
CASE(i8) CASE(i16) CASE(i32) CASE(f64)
CASE(c8) CASE(c16) CASE(c32)
case el_bit: {
#define HI GRADE_UD(1,0)
u64* xp = bitarr_ptr(x);
u64 i = bit_find(xp, xia, HI);
usz iw = i/64;
u64 m = ~(u64)0;
u64 d = GRADE_UD(,~)xp[iw] ^ (m<<(i%64));
if (iw == xia/64) return (d &~ (m<<(xia%64))) == 0;
if (d) return 0;
usz o = iw + 1;
usz l = xia - 64*o;
return (bit_find(xp+o, l, !HI) == l);
#undef HI
}
case el_B: {
B* xp = TO_BPTR(x);
CMP(compare(xp[i-1], xp[i]) GRADE_UD(>,<) 0)
}
}
#undef CASE
#undef CMP
}
// Location of first 1 (ascending) or 0 (descending), by binary search
static u64 CAT(bit_boundary,GRADE_UD(up,dn))(u64* x, u64 n) {
u64 c = GRADE_UD(,~)(u64)0;
u64 *s = x-1;
for (usz l = BIT_N(n)+1, h; (h=l/2)>0; l-=h) {
u64* m = s+h; if (!(c LT *m)) s = m;
}
++s; // Word containing boundary
u64 b = 64*(s-x);
if (b >= n) return n;
u64 v = GRADE_UD(~,) *s;
if (b+63 >= n) v &= ~(u64)0 >> ((-n)%64);
return b + POPC(v);
}
#define LE_C2 CAT(GRADE_UD(le,ge),c2)
extern B LE_C2(B,B,B);
extern B select_c2(B t, B w, B x);
extern B mul_c2(B, B, B);
B GRADE_CAT(c2)(B t, B w, B x) {
if (isAtm(w) || RNK(w)==0) thrM(GRADE_CHR": 𝕨 must have rank≥1");
if (isAtm(x)) x = m_unit(x);
@ -342,52 +414,138 @@ B GRADE_CAT(c2)(B t, B w, B x) {
u8 we = TI(w,elType); usz wia = IA(w);
u8 xe = TI(x,elType); usz xia = IA(x);
B r;
if (wia==0 | xia==0) {
Arr* ra=allZeroes(xia); arr_shCopy(ra, x);
r=taga(ra); goto done;
}
if (wia==1) {
B c = IGet(w, 0);
if (LIKELY(we<el_B & xe<el_B)) {
decG(w);
return LE_C2(m_f64(0), c, x);
} else {
SLOW2("𝕨"GRADE_CHR"𝕩", w, x); // Could narrow for mixed types
u64* rp; r = m_bitarrc(&rp, x);
B* xp = TO_BPTR(x);
u64 b = 0;
for (usz i = xia; ; ) {
i--;
b = 2*b + !(compare(xp[i], c) LT 0);
if (i%64 == 0) { rp[i/64]=b; if (!i) break; }
}
dec(c);
}
goto done;
}
if (wia>I32_MAX-10) thrM(GRADE_CHR": 𝕨 too big");
i32* rp; B r = m_i32arrc(&rp, x);
u8 fl = GRADE_UD(fl_asc,fl_dsc);
if (CHECK_VALID && !FL_HAS(w,fl)) {
if (!CAT(isSorted,GRADE_UD(Up,Down))(w)) thrM(GRADE_CHR": 𝕨 must be sorted"GRADE_UD(," in descending order"));
FL_SET(w, fl);
}
if (LIKELY(we<el_B & xe<el_B)) {
if (elNum(we)) { // number
#if SINGELI
B mult = bi_N;
#endif
if (elNum(we)) {
if (elNum(xe)) {
if (RARE(we==el_bit | xe==el_bit)) {
if (we==el_bit) {
usz c1 = CAT(bit_boundary,GRADE_UD(up,dn))(bitarr_ptr(w), wia);
decG(w); // c1 and wia contain all information in w
if (xe==el_bit) {
r = bit_sel(x, m_f64(GRADE_UD(c1,wia)), m_f64(GRADE_UD(wia,c1)));
} else {
i8* bp; B b01 = m_i8arrv(&bp, 2);
GRADE_UD(bp[0]=0; bp[1]=1;, bp[0]=1; bp[1]=0;)
B i = GRADE_CAT(c2)(m_f64(0), b01, x);
f64* c; B rw = m_f64arrv(&c, 3); c[0]=0; c[1]=c1; c[2]=wia;
r = select_c2(m_f64(0), i, num_squeeze(rw));
}
} else { // xe==el_bit: 2-element lookup table
i8* bp; B b01 = m_i8arrv(&bp, 2); bp[0]=0; bp[1]=1;
B i = GRADE_CAT(c2)(m_f64(0), w, b01);
SGetU(i)
r = bit_sel(x, GetU(i,0), GetU(i,1));
decG(i);
}
return r;
}
#if SINGELI
#define WIDEN(E, X) switch (E) { default:UD; case el_i16:X=toI16Any(X);break; case el_i32:X=toI32Any(X);break; case el_f64:X=toF64Any(X);break; }
if (xe > we) {
if (xia/4 < wia) { // Narrow x
assert(el_i8 <=we && we<=el_i32);
assert(el_i16<=xe && xe<=el_f64);
i32 pre = -1; pre<<=(8<<(we-el_i8))-1;
pre = GRADE_UD(pre,-1-pre); // Smallest value of w's type
i32 w0 = o2iG(IGetU(w,0));
// Saturation is correct except it can move low values past
// pre. Post-adjust with mult×r
if (w0 == pre) mult = LE_C2(m_f64(0), m_i32(pre), incG(x));
// Narrow x with saturating conversion
B xn; void *xp = m_tyarrc(&xn, elWidth(we), x, el2t(we));
u8 ind = xe<el_f64 ? (we-el_i8)+(xe-el_i16)
: 3 + 2*(we-el_i8) + GRADE_UD(0,1);
si_saturate[ind](xp, tyany_ptr(x), xia);
decG(x); x = xn;
} else {
WIDEN(xe, w)
we = xe;
}
}
if (we > xe) WIDEN(we, x)
#undef WIDEN
#else
if (!elInt(we) | !elInt(xe)) goto gen;
w=toI32Any(w); x=toI32Any(x);
#endif
} else {
// TODO pull copy-scalar part out of Reshape
i32* rp; r = m_i32arrc(&rp, x);
for (u64 i=0; i<xia; i++) rp[i]=wia;
goto done;
}
} else { // character
} else { // w is character
if (elNum(xe)) {
Arr* ra=allZeroes(xia); arr_shVec(ra);
decG(r); r=taga(ra); goto done;
Arr* ra=allZeroes(xia); arr_shCopy(ra, x);
r=taga(ra); goto done;
} else {
we = el_c32;
w=toC32Any(w); x=toC32Any(x);
}
}
#if SINGELI
u8 k = elWidthLogBits(we) - 3;
u8 rl = wia<128 ? 0 : wia<(1<<15) ? 1 : wia<(1<<31) ? 2 : 3;
void *rp = m_tyarrc(&r, 1<<rl, x, el2t(el_i8+rl));
si_bins[k*2 + GRADE_UD(0,1)](tyany_ptr(w), wia, tyany_ptr(x), xia, rp, rl);
if (!q_N(mult)) r = mul_c2(m_f64(0), mult, r);
#else
i32* rp; r = m_i32arrc(&rp, x);
i32* wi = tyany_ptr(w);
i32* xi = tyany_ptr(x);
if (CHECK_VALID && !FL_HAS(w,fl)) {
for (i64 i = 0; i < (i64)wia-1; i++) if (wi[i] GRADE_UD(>,<) wi[i+1]) thrM(GRADE_CHR": 𝕨 must be sorted"GRADE_UD(," in descending order"));
FL_SET(w, fl);
}
for (usz i = 0; i < xia; i++) {
i32 c = xi[i];
i32 *s = wi-1;
for (usz l = wia+1, h; (h=l/2)>0; l-=h) { i32* m = s+h; if (!(c LT *m)) s = m; }
rp[i] = s - (wi-1);
}
#endif
} else {
#if !SINGELI
gen:;
#endif
i32* rp; r = m_i32arrc(&rp, x);
SGetU(x)
SLOW2("𝕨"GRADE_CHR"𝕩", w, x);
B* wp = TO_BPTR(w);
if (CHECK_VALID && !FL_HAS(w,fl)) {
for (i64 i = 0; i < (i64)wia-1; i++) if (compare(wp[i], wp[i+1]) GRADE_UD(>,<) 0) thrM(GRADE_CHR": 𝕨 must be sorted"GRADE_UD(," in descending order"));
FL_SET(w, fl);
}
for (usz i = 0; i < xia; i++) {
B c = GetU(x,i);
usz s = 0, e = wia+1;
@ -404,6 +562,7 @@ done:
return r;
}
#undef GRADE_CHR
#undef LE_C2
#undef LT
#undef FOR

View File

@ -3,6 +3,11 @@
// Defines Sort, Grade, and Bins
#if SINGELI
#define SINGELI_FILE bins
#include "../utils/includeSingeli.h"
#endif
#define CAT0(A,B) A##_##B
#define CAT(A,B) CAT0(A,B)
typedef struct BI32p { B k; i32 v; } BI32p;

View File

@ -24,6 +24,11 @@ def ctz{x:T & isint{T} & width{T}==64} = emit{u8, '__builtin_ctzll', x}
def ctz{x:T & isint{T} & width{T}<=32} = emit{u8, '__builtin_ctz', x}
def clz{x:T & isint{T} & width{T}==64} = emit{u8, '__builtin_clzll', x}
def clz{x:T & isint{T} & width{T}<=32} = emit{u8, '__builtin_clz', x}
# count-leading-zeros complement, less type-dependent
def clzc{x:T & isint{T} & width{T}==64} = 64-clz{x}
def clzc{x:T & isint{T} & width{T}<=32} = 32-clz{x}
def ceil_log2{n} = clzc{n-1}
def truncBits{n, v & n<=8} = cast_i{u8, v}
def truncBits{n, v & n==16} = cast_i{u16, v}
@ -175,6 +180,7 @@ def cvt{...x} = assert{'cvt not supported', show{...x}}
def shuf{...x} = assert{'shuf not supported', show{...x}}
def shuf16Lo{...x} = assert{'shuf16Lo not supported', show{...x}}
def shuf16Hi{...x} = assert{'shuf16Hi not supported', show{...x}}
def shufHalves{...x} = assert{'shufHalves not supported', show{...x}}
def homAll{...x} = assert{'homAll not supported', show{...x}}
def homAny{...x} = assert{'homAny not supported', show{...x}}
def homMask{...x} = assert{'homMask not supported', show{...x}}
@ -281,6 +287,13 @@ def forNZ{vars,begin,end,iter} = {
++i
}
}
def for_backwards{vars,begin,end,iter} = {
i:u64 = end
while (i > begin) {
--i
iter{i, vars}
}
}
def forUnroll{exp,unr}{vars,begin,end,iter} = {
i:u64 = begin
while ((i+unr) <= end) {

View File

@ -0,0 +1,385 @@
include './base'
include './cbqnDefs'
if (hasarch{'AVX2'}) {
include './sse'
include './avx'
include './avx2'
}
include './mask'
include 'util/tup'
def for_dir{up} = if (up) for else for_backwards
def for_vec_overlap{vl}{vars,begin==0,n,iter} = {
assert{n >= vl}
def end = makelabel{}
j:u64 = 0
while (1) {
iter{j, vars}
j += vl
if (j > n-vl) { if (j == n) goto{end}; j = n-vl }
}
setlabel{end}
}
# Shift as u16, since x86 is missing 8-bit shifts
def shr16{v:V, n} = V~~(to_el{u16, v} >> n)
# Forward or backwards in-place max-scan
# Assumes a whole number of vectors and minimum 0
fn max_scan{T, up}(x:*T, len:u64) : void = {
def w = width{T}
if (hasarch{'AVX2'} and T!=u64) {
def op = max
# TODO unify with scan.singeli avx2_scan_idem
def rev{a} = if (up) a else (tuplen{a}-1)-reverse{a}
def maker{T, l} = make{T, rev{l}}
def sel8{v, t} = sel{[16]u8, v, maker{[32]i8, t}}
def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,rev{n}}}
def spread{a:VT} = {
def w = elwidth{VT}
def b = w/8
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
}
def shift{k,l} = merge{iota{k},iota{l-k}}
def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}}
def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a)
def pre{a} = {
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
op{b, sel{[8]i32, spread{b}, maker{[8]i32, 3*(3<iota{8})}}}
}
def toLast{n:VT} = {
if (elwidth{VT}<=32) sel{[8]i32, spread{n}, [8]i32**(up*7)}
else shuf{[4]u64, n, up*4b3333}
}
def vl = 256/w
def V = [vl]T
p := V**0
@for_dir{up} (v in *V~~x over len/vl) { v = op{pre{v}, p}; p = toLast{v} }
} else {
m:T=0; @for_dir{up} (x over len) { if (x > m) m = x; x = m }
}
}
def getsel{...x} = assert{'shuffling not supported', show{...x}}
if (hasarch{'AVX2'}) {
def getsel{h:H & lvec{H, 16, 8}} = {
v := pair{h,h}
{i} => sel{H, v, i}
}
def getsel{v:V & lvec{V, 32, 8}} = {
def H = v_half{V}
vtop := V**(vcount{V}/2)
hs := each{bind{shuf, [4]u64, v}, tup{4b3232, 4b1010}}
{i} => homBlend{...each{{h}=>sel{H,h,i}, hs}, V~~i<vtop}
}
def getsel{v:V & lvec{V, 8, 32}} = { {i} => sel{V, v, i} }
}
# Move evens to half 0 and odds to half 1
def uninterleave{x:V & hasarch{'AVX2'}} = {
def vl = vcount{V}; def bytes = width{eltype{V}}/8
def i = 2*iota{vl/4}
def i2= join{table{+, bytes*merge{i,i+1}, iota{bytes}}}
t := V~~sel{[16]u8, to_el{u8,x}, make{[32]u8, merge{i2,i2}}}
shuf{[4]u64, t, 4b3120}
}
def rtypes = tup{i8, i16, i32, f64}
# Return index of smallest possible result type given max result value
# (Unused; done in C for now)
def get_rtype{len} = {
t:u8 = 0
def c{T, ...ts} = if (len>maxvalue{T}) { ++t; c{...ts} }
def c{T==f64} = {}
c{...rtypes}
t
}
def rtype_arr{gen} = {
def t = each{gen, rtypes}
a:*(type{tupsel{0,t}}) = t
}
# Write the last index of v at t+v, for each unique v in w
fn write_indices{I,T}(t:*I, w:*T, n:u64) : void = {
def break = makelabel{}
i:u64 = 0; while (1) {
d:u64 = 16
id := i+d
wi := undefined{T}
if (id >= n) {
@for (w over j from i to n) store{t, w, j+1}
goto{break}
} else if ((wi = load{w, i}) == load{w, id}) {
# Gallop
md := n - i
d2 := undefined{u64}
while ((d2=d+d) < md and wi == load{w, i + d2}) d = d2
i += d
l := n - i; if (l > d) l = d
# Last instance of wi in [i,i+l); shrink l
while (l > 8) {
h := l/2
m := i + h
if (wi == load{w, m}) i = m
l -= h
}
} else {
@unroll (j to 8) store{t, load{w, i+j}, i+j+1}
i += 8
}
}
setlabel{break}
}
fn write_indices{I,T & width{I}==8}(t:*I, w:*T, n:u64) : void = {
@for (w over j to n) store{t, w, j+1}
}
def bins_lookup{I, T, up, w:*T, wn:u64, x:*T, xn:u64, rp:*void} = {
# Build table
def tc = 1<<width{T}
t0:*I = talloc{I, tc}
@for (t0 over tc) t0 = 0
t:*I = t0 + tc/2
write_indices{I,T}(t, *T~~w, wn)
# Vector bit-table
def use_vectab = if (hasarch{'AVX2'} and I==i8 and T==i8) 1 else 0
def done = makelabel{}
if (use_vectab) bins_vectab_i8{up, w, wn, x, xn, rp, t0, t, done}
# Main scalar table
max_scan{I, up}(t0, tc)
@for (r in *I~~rp, x over xn) r = load{t, x}
if (use_vectab) setlabel{done}
tfree{t0}
}
def bins_vectab_i8{up, w, wn, x, xn, rp, t0, t, done & hasarch{'AVX2'}} = {
assert{wn < 128} # Total must fit in i8
def vl = 32
def T = i8
def V = [vl]T; def H = v_half{V}
def U = [vl]u8
# Convert to bit table
def no_bittab = makelabel{}
def nb = 256/vl
nu:u8 = 0; def addu{b} = { nu+=popc{b}; b } # Number of uniques
vb := U~~make{[nb](ty_u{vl}),
@collect (t in *V~~t0 over nb) addu{homMask{t > V**0}}
}
dup := promote{u64,nu} < wn
# Unique index to w index conversion
ui := undefined{V}; ui1 := undefined{V}; ui2 := each{undefined,tup{V,V}}
if (dup) {
def maxu = 2*vl
if (nu > maxu) goto{no_bittab}
# We'll subtract 1 when indexing so the initial 0 isn't needed
tui:*i8 = copy{maxu, 0}; i:T = 0
@for (tui over promote{u64,nu}) { i = load{t, load{w, i}}; tui = i }
def tv = bind{load, *V~~tui}
ui = tv{0}
if (nu > 16) ui1 = shuf{[4]u64, ui, 4b3232}
ui = shuf{[4]u64, ui, 4b1010}
if (nu > vl) ui2 = each{bind{shuf, [4]u64, tv{1}}, tup{4b1010, 4b3232}}
}
# Popcount on 8-bit values
def sums{n} = if (n==1) tup{0} else { def s=sums{n/2}; merge{s,s+1} }
def sum4 = getsel{make{H, sums{vl/2}}}
bot4 := U**0x0f
def vpopc{v} = {
def s{b} = sum4{b&bot4}
s{shr16{v,4}} + s{v}
}
# Bit table
def swap{v} = shuf{[4]u64, v, 4b1032} # For signedness
def sel_b = getsel{swap{vb}}
# Masks for filtering bit table
def ms = if (up) 256-(1<<(1+iota{8})) else (1<<iota{8})-1
def sel_m = getsel{make{H, merge{ms - 256*(ms>127), 8**0}}}
# Exact values for multiples of 8
store{*U~~t0, 0, vpopc{vb}}
st:i8=0; @for_dir{up} (t0 over 256/8) { st += t0; t0 = st }
def sel_c = getsel{swap{load{*V~~t0, 0} - V**dup}}
# Top 5 bits select bytes from tables; bottom 3 select from mask
bot3 := U**0x07
@for_vec_overlap{vl} (j to xn) {
xv := load{*U~~(x+j), 0}
xb := xv & bot3
xt := shr16{xv &~ bot3, 3}
ind := sel_c{xt} - vpopc{sel_b{xt} & U~~sel_m{xb}}
if (dup) {
i0 := V~~ind # Can contain -1
def isel{u} = sel{H, u, i0}
ind = isel{ui}
if (nu > 16) {
b := V~~(to_el{u16, i0} << (7 - lb{vl/2}))
ind = topBlend{ind, isel{ui1}, b}
if (nu > 32) ind = homBlend{topBlend{...each{isel,ui2}, b}, ind, i0 < V**vl}
}
}
store{*U~~(*T~~rp+j), 0, ind}
}
goto{done}
setlabel{no_bittab}
}
# Binary search within vector registers
def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
assert{wn > 1}; assert{wn < maxwn}
def wd = width{T}
def I = if (wd<32) u8 else u32; def wi = width{I}
def lanes = hasarch{'AVX2'} & (I==u8)
def isub = wd/wi; def bb = bind{base,1<<wi}
def vl = 256/wd; def svl = vl>>lanes
def V = [vl]T
def U = [vl](ty_u{T})
def lt = if (up) <; else >
# Number of steps
log := ceil_log2{wn+1}
gap := 1<<log - cast_i{u8, wn}
# Fill with minimum value at the beginning
def pre = (if (up) minvalue else maxvalue){T}
wg := *V~~(w-gap)
wv0:= homBlend{load{wg}, V**pre, maskOf{V,gap}}
# For multiple lanes, interleave like transpose
def maxstep = lb{maxwn}
def lstep = lb{svl}
def ex = maxstep - lstep
wv := if (lanes) wv0 else tup{wv0,wv0}
wv2 := wv # Compiler complains if uninitialized
if (ex>=1 and wn >= svl) {
--gap # Allows subtracting < instead of adding <=
def un = uninterleave
def tr_half{a, b} = each{bind{shufHalves,a,b}, tup{16b20, 16b31}}
def un{{a,b}} = tr_half{un{a},un{b}}
if (not lanes) tupsel{1,wv} = load{wg, 1}
wv = un{wv}
if (ex>=2 and wn >= 2*svl) {
assert{lanes} # Different transpose pattern needed
gap -= 2
tup{wv, wv2} = each{un, tr_half{wv, un{load{wg, 1}}}}
}
}
def ms{v}{h} = getsel{to_el{I, if (lanes) half{v,h} else tupsel{h,v}}}
def selw = ms{wv}{0}; def selw1 = if (ex>=1) ms{wv}{1} else 'undef'
def selw2 = if (ex>=2) each{ms{wv2}, iota{2}} else 'undef'
# Offset at end
off := U~~V**i8~~(gap - 1)
# Midpoint bits for each step
def lowbits = bb{copy{isub,isub}}
bits := each{{j} => U**(lowbits << j), iota{lstep}}
# Unroll sizes up to a full lane, handling extra lanes conditionally
# in the largest one
@unroll (klog from 2 to min{maxstep,lstep}+1) {
def last = klog==lstep
def this = if (not last) log==klog else log>=klog
if (this) @for_vec_overlap{vl} (j to xn) {
xv:= load{*V~~(x+j), 0}
s := U**bb{iota{isub}} # Select sequential bytes within each U
def ltx{se,ind} = lt{xv, V~~se{to_el{I,ind}}}
@unroll (j to klog) {
m := s | tupsel{klog-1-j,bits}
s = homBlend{m, s, ltx{selw, m}}
}
r := if (isub==1) s else s>>(lb{isub}+wd-wi)
# Extra selection lanes
if (last and ex>=1 and log>=klog+1) {
r += r
c := ltx{selw1,s}
if (ex>=2 and log>=klog+2) {
r += r
each{{se} => c += ltx{se,s}, selw2}
}
r += c
}
r -= off
rn := if (T==i8) r
else if (T==i16) half{narrow{u8, r}, 0}
else extract{to_el{i64, narrow{u8, r}}, 0}
store{*type{rn}~~(*i8~~rp+j), 0, rn}
}
}
}
def unroll_sizes = tup{4,1}
fn write{T,k}(r:*void, i:u64, ...vs:k**u64) : void = {
each{{j,v} => store{*T~~r, i+j, cast_i{T,v}}, iota{k}, vs}
}
def wr_arrs = each{{k} => rtype_arr{{T} => write{T,k}}, unroll_sizes}
def bin_search_branchless{up, w, wn, x, n, res, rtype} = {
def lt = if (up) <; else >
ws := w - 1
l0 := wn + 1
# Take a list of indices in x/res to allow unrolling
def search{inds} = {
xs:= each{bind{load,x}, inds} # Values
ss:= each{{_}=>ws, inds} # Initial lower bound
l := l0; h := undefined{u64} # Interval size l, same for all values
while ((h=l/2) > 0) {
# Branchless update
def bin1{s, x, m} = { if (not lt{x, load{m}}) s = m }
each{bin1, ss, xs, each{bind{+,h}, ss}}
l -= h
}
each{{s} => u64~~(s - ws), ss}
}
# Unroll by 4 then 1
def search{i, k} = search{each{bind{+,i}, iota{k}}}
j:u64 = 0
def searches{k, wr_arr} = {
wr := load{wr_arr, rtype}
while (j+k <= n) { wr(res, j, ...search{j, k}); j+=k }
}
each{searches, unroll_sizes, wr_arrs}
}
fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, rp:*void, rty:u8) : void = {
def param = tup{up, *T~~w, wn, *T~~x, xn, rp}
def lookup{k} = {
if (rty == k) bins_lookup{tupsel{k,rtypes}, T, ...param}
else if (k+1 < tuplen{rtypes}) lookup{k+1}
}
# For >=8 i8 values, vector bit-table is as good as binary search
def wn_vec = if (T==i8) 8 else 2*256/width{T}
if (hasarch{'AVX2'} and T<=i32 and wn < wn_vec and xn >= 256/width{T}) {
bin_search_vec{T, ...param, wn_vec}
# Lookup table threshold has to account for cost of
# populating the table (proportional to wn until it's large), and
# initializing the table (constant, much higher for i16)
} else if (T==i8 and xn>=32 and (xn>=512 or xn >= wn>>6 + 32)) {
lookup{0}
} else if (T==i16 and xn>=512 and (xn>=1<<14 or xn >= wn>>6 + (u64~~3<<(12+rty))/promote{u64,ceil_log2{wn}+2})) {
lookup{0}
} else {
bin_search_branchless{...param, rty}
}
}
exportT{
'si_bins',
join{table{bins, tup{i8,i16,i32,f64}, tup{1,0}}}
}
# Utility for narrowing binary search right argument
include './f64'
require{'math.h'}
fn saturate{F,T,...up}(dst:*void, src:*void, n:u64) : void = {
# Auto-vectorizes, although not that well for f64
def a = minvalue{T}; af := cast_i{F,a}
def b = maxvalue{T}; bf := cast_i{F,b}
@for (d in *T~~dst, xf in *F~~src over n) {
x := if (F==f64) (if (tupsel{0,up}) floor else ceil){xf} else xf
d = cast_i{T, x}
if (x<af) d = a
if (x>bf) d = b
}
}
exportT{
'si_saturate',
each{{a}=>saturate{...a}, merge{
tup{tup{i16,i8}, tup{i32,i8}, tup{i32,i16}},
join{table{bind{tup,f64}, tup{i8,i16,i32}, tup{1,0}}}
}}
}

View File

@ -31,3 +31,11 @@ def cbqn_elType{T & T==u16} = 6
def cbqn_elType{T & T==u32} = 7
def cbqn_tyArrOffset{} = emit{u64, 'offsetof', 'TyArr', 'a'}
def talloc{T, len} = emit{*T, 'TALLOCP', fmt_type{T}, len}
def tfree{ptr} = emit{void, 'TFREE', ptr}
def fmt_type{T} = {
def w = match (width{T}) { {_==8}=>'8'; {_==16}=>'16'; {_==32}=>'32'; {_==64}=>'64' }
merge{quality{T}, w}
}
def fmt_type{T & isptr{T}} = merge{'*',fmt_type{eltype{T}}}