commit
368e6b6001
@ -634,6 +634,7 @@ cachedBin‿linkerCache ← {
|
||||
"xa."‿"src/builtins/squeeze.c"‿"squeeze", "xa."‿"src/utils/mut.c"‿"copy",
|
||||
"xa."‿"src/utils/bits.c"‿"bits", "xag"‿"src/builtins/transpose.c"‿"transpose",
|
||||
"xag"‿"src/builtins/search.c"‿"search", "xa."‿"src/builtins/fold.c"‿"fold",
|
||||
"xag"‿"src/builtins/sort.c"‿"bins"
|
||||
|
||||
"2.."‿"src/builtins/select.c"‿"select", "2.."‿"src/builtins/scan.c"‿"scan",
|
||||
"2.."‿"src/builtins/slash.c"‿"constrep", "2.."‿"src/builtins/scan.c"‿"neq",
|
||||
|
||||
@ -19,15 +19,31 @@
|
||||
// SHOULD widen odd cell sizes under 8 bytes in sort and grade
|
||||
|
||||
// Bins
|
||||
// Length 0 or 1 𝕨: trivial, or comparison
|
||||
// Stand-alone 𝕨 sortedness check
|
||||
// SHOULD vectorize sortedness check on lists of numbers
|
||||
// Mixed integer and character arguments gives all 0 or ≠𝕨
|
||||
// Integers and characters: 4-byte branchless binary search
|
||||
// Non-Singeli, integers and characters:
|
||||
// 4-byte branchless binary search, 4-byte output
|
||||
// SHOULD support fast character searches
|
||||
// Boolean 𝕨 or 𝕩: lookup table (single binary search on boolean 𝕨)
|
||||
// Different widths: generally widen narrower argument
|
||||
// Narrow wider-type 𝕩 instead if it isn't much shorter
|
||||
// SHOULD trim wider-type 𝕨 and possibly narrow
|
||||
// Same-width numbers:
|
||||
// Output type based on ≠𝕨
|
||||
// Short 𝕨: vector binary search (then linear on extra lanes)
|
||||
// 1- or 2-byte type, long enough 𝕩: lookup table from ⌈`
|
||||
// Binary gallops to skip long repeated elements of 𝕨
|
||||
// 1-byte, no duplicates or few uniques: vector bit-table lookup
|
||||
// General: interleaved branchless binary search
|
||||
// COULD start interleaved search with a vector binary round
|
||||
// General case: branching binary search
|
||||
// SHOULD implement f64 branchless binary search
|
||||
// SHOULD interleave multiple branchless binary searches
|
||||
// SHOULD specialize bins on equal types at least
|
||||
// SHOULD implement table-based ⍋⍒ for small-range 𝕨
|
||||
// SHOULD special-case short 𝕨
|
||||
// COULD trim 𝕨 based on range of 𝕩
|
||||
// COULD optimize small-range 𝕨 with small-type methods
|
||||
// SHOULD partition 𝕩 when 𝕨 is large
|
||||
// COULD interpolation search for large 𝕩 and short 𝕨
|
||||
// COULD use linear search and galloping for sorted 𝕩
|
||||
|
||||
#define GRADE_CAT(N) CAT(GRADE_UD(gradeUp,gradeDown),N)
|
||||
#define GRADE_NEG GRADE_UD(,-)
|
||||
@ -326,6 +342,62 @@ B GRADE_CAT(c1)(B t, B x) {
|
||||
}
|
||||
|
||||
|
||||
bool CAT(isSorted,GRADE_UD(Up,Down))(B x) {
|
||||
assert(isArr(x) && RNK(x)==1); // TODO extend to >=1
|
||||
usz xia = IA(x);
|
||||
if (xia <= 1) return 1;
|
||||
|
||||
#define CMP(TEST) \
|
||||
for (usz i=1; i<xia; i++) if (TEST) return 0; \
|
||||
return 1;
|
||||
#define CASE(T) case el_##T: { \
|
||||
T* xp = T##any_ptr(x); CMP(xp[i-1] GRADE_UD(>,<) xp[i]) }
|
||||
switch (TI(x,elType)) { default: UD;
|
||||
CASE(i8) CASE(i16) CASE(i32) CASE(f64)
|
||||
CASE(c8) CASE(c16) CASE(c32)
|
||||
case el_bit: {
|
||||
#define HI GRADE_UD(1,0)
|
||||
u64* xp = bitarr_ptr(x);
|
||||
u64 i = bit_find(xp, xia, HI);
|
||||
usz iw = i/64;
|
||||
u64 m = ~(u64)0;
|
||||
u64 d = GRADE_UD(,~)xp[iw] ^ (m<<(i%64));
|
||||
if (iw == xia/64) return (d &~ (m<<(xia%64))) == 0;
|
||||
if (d) return 0;
|
||||
usz o = iw + 1;
|
||||
usz l = xia - 64*o;
|
||||
return (bit_find(xp+o, l, !HI) == l);
|
||||
#undef HI
|
||||
}
|
||||
case el_B: {
|
||||
B* xp = TO_BPTR(x);
|
||||
CMP(compare(xp[i-1], xp[i]) GRADE_UD(>,<) 0)
|
||||
}
|
||||
}
|
||||
#undef CASE
|
||||
#undef CMP
|
||||
}
|
||||
|
||||
// Location of first 1 (ascending) or 0 (descending), by binary search
|
||||
static u64 CAT(bit_boundary,GRADE_UD(up,dn))(u64* x, u64 n) {
|
||||
u64 c = GRADE_UD(,~)(u64)0;
|
||||
u64 *s = x-1;
|
||||
for (usz l = BIT_N(n)+1, h; (h=l/2)>0; l-=h) {
|
||||
u64* m = s+h; if (!(c LT *m)) s = m;
|
||||
}
|
||||
++s; // Word containing boundary
|
||||
u64 b = 64*(s-x);
|
||||
if (b >= n) return n;
|
||||
u64 v = GRADE_UD(~,) *s;
|
||||
if (b+63 >= n) v &= ~(u64)0 >> ((-n)%64);
|
||||
return b + POPC(v);
|
||||
}
|
||||
|
||||
#define LE_C2 CAT(GRADE_UD(le,ge),c2)
|
||||
extern B LE_C2(B,B,B);
|
||||
extern B select_c2(B t, B w, B x);
|
||||
extern B mul_c2(B, B, B);
|
||||
|
||||
B GRADE_CAT(c2)(B t, B w, B x) {
|
||||
if (isAtm(w) || RNK(w)==0) thrM(GRADE_CHR": 𝕨 must have rank≥1");
|
||||
if (isAtm(x)) x = m_unit(x);
|
||||
@ -342,52 +414,138 @@ B GRADE_CAT(c2)(B t, B w, B x) {
|
||||
u8 we = TI(w,elType); usz wia = IA(w);
|
||||
u8 xe = TI(x,elType); usz xia = IA(x);
|
||||
|
||||
B r;
|
||||
if (wia==0 | xia==0) {
|
||||
Arr* ra=allZeroes(xia); arr_shCopy(ra, x);
|
||||
r=taga(ra); goto done;
|
||||
}
|
||||
if (wia==1) {
|
||||
B c = IGet(w, 0);
|
||||
if (LIKELY(we<el_B & xe<el_B)) {
|
||||
decG(w);
|
||||
return LE_C2(m_f64(0), c, x);
|
||||
} else {
|
||||
SLOW2("𝕨"GRADE_CHR"𝕩", w, x); // Could narrow for mixed types
|
||||
u64* rp; r = m_bitarrc(&rp, x);
|
||||
B* xp = TO_BPTR(x);
|
||||
u64 b = 0;
|
||||
for (usz i = xia; ; ) {
|
||||
i--;
|
||||
b = 2*b + !(compare(xp[i], c) LT 0);
|
||||
if (i%64 == 0) { rp[i/64]=b; if (!i) break; }
|
||||
}
|
||||
dec(c);
|
||||
}
|
||||
goto done;
|
||||
}
|
||||
if (wia>I32_MAX-10) thrM(GRADE_CHR": 𝕨 too big");
|
||||
i32* rp; B r = m_i32arrc(&rp, x);
|
||||
|
||||
u8 fl = GRADE_UD(fl_asc,fl_dsc);
|
||||
|
||||
if (CHECK_VALID && !FL_HAS(w,fl)) {
|
||||
if (!CAT(isSorted,GRADE_UD(Up,Down))(w)) thrM(GRADE_CHR": 𝕨 must be sorted"GRADE_UD(," in descending order"));
|
||||
FL_SET(w, fl);
|
||||
}
|
||||
|
||||
if (LIKELY(we<el_B & xe<el_B)) {
|
||||
if (elNum(we)) { // number
|
||||
#if SINGELI
|
||||
B mult = bi_N;
|
||||
#endif
|
||||
if (elNum(we)) {
|
||||
if (elNum(xe)) {
|
||||
if (RARE(we==el_bit | xe==el_bit)) {
|
||||
if (we==el_bit) {
|
||||
usz c1 = CAT(bit_boundary,GRADE_UD(up,dn))(bitarr_ptr(w), wia);
|
||||
decG(w); // c1 and wia contain all information in w
|
||||
if (xe==el_bit) {
|
||||
r = bit_sel(x, m_f64(GRADE_UD(c1,wia)), m_f64(GRADE_UD(wia,c1)));
|
||||
} else {
|
||||
i8* bp; B b01 = m_i8arrv(&bp, 2);
|
||||
GRADE_UD(bp[0]=0; bp[1]=1;, bp[0]=1; bp[1]=0;)
|
||||
B i = GRADE_CAT(c2)(m_f64(0), b01, x);
|
||||
f64* c; B rw = m_f64arrv(&c, 3); c[0]=0; c[1]=c1; c[2]=wia;
|
||||
r = select_c2(m_f64(0), i, num_squeeze(rw));
|
||||
}
|
||||
} else { // xe==el_bit: 2-element lookup table
|
||||
i8* bp; B b01 = m_i8arrv(&bp, 2); bp[0]=0; bp[1]=1;
|
||||
B i = GRADE_CAT(c2)(m_f64(0), w, b01);
|
||||
SGetU(i)
|
||||
r = bit_sel(x, GetU(i,0), GetU(i,1));
|
||||
decG(i);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
#if SINGELI
|
||||
#define WIDEN(E, X) switch (E) { default:UD; case el_i16:X=toI16Any(X);break; case el_i32:X=toI32Any(X);break; case el_f64:X=toF64Any(X);break; }
|
||||
if (xe > we) {
|
||||
if (xia/4 < wia) { // Narrow x
|
||||
assert(el_i8 <=we && we<=el_i32);
|
||||
assert(el_i16<=xe && xe<=el_f64);
|
||||
i32 pre = -1; pre<<=(8<<(we-el_i8))-1;
|
||||
pre = GRADE_UD(pre,-1-pre); // Smallest value of w's type
|
||||
i32 w0 = o2iG(IGetU(w,0));
|
||||
// Saturation is correct except it can move low values past
|
||||
// pre. Post-adjust with mult×r
|
||||
if (w0 == pre) mult = LE_C2(m_f64(0), m_i32(pre), incG(x));
|
||||
// Narrow x with saturating conversion
|
||||
B xn; void *xp = m_tyarrc(&xn, elWidth(we), x, el2t(we));
|
||||
u8 ind = xe<el_f64 ? (we-el_i8)+(xe-el_i16)
|
||||
: 3 + 2*(we-el_i8) + GRADE_UD(0,1);
|
||||
si_saturate[ind](xp, tyany_ptr(x), xia);
|
||||
decG(x); x = xn;
|
||||
} else {
|
||||
WIDEN(xe, w)
|
||||
we = xe;
|
||||
}
|
||||
}
|
||||
if (we > xe) WIDEN(we, x)
|
||||
#undef WIDEN
|
||||
#else
|
||||
if (!elInt(we) | !elInt(xe)) goto gen;
|
||||
w=toI32Any(w); x=toI32Any(x);
|
||||
#endif
|
||||
} else {
|
||||
// TODO pull copy-scalar part out of Reshape
|
||||
i32* rp; r = m_i32arrc(&rp, x);
|
||||
for (u64 i=0; i<xia; i++) rp[i]=wia;
|
||||
goto done;
|
||||
}
|
||||
} else { // character
|
||||
} else { // w is character
|
||||
if (elNum(xe)) {
|
||||
Arr* ra=allZeroes(xia); arr_shVec(ra);
|
||||
decG(r); r=taga(ra); goto done;
|
||||
Arr* ra=allZeroes(xia); arr_shCopy(ra, x);
|
||||
r=taga(ra); goto done;
|
||||
} else {
|
||||
we = el_c32;
|
||||
w=toC32Any(w); x=toC32Any(x);
|
||||
}
|
||||
}
|
||||
|
||||
#if SINGELI
|
||||
u8 k = elWidthLogBits(we) - 3;
|
||||
u8 rl = wia<128 ? 0 : wia<(1<<15) ? 1 : wia<(1<<31) ? 2 : 3;
|
||||
void *rp = m_tyarrc(&r, 1<<rl, x, el2t(el_i8+rl));
|
||||
si_bins[k*2 + GRADE_UD(0,1)](tyany_ptr(w), wia, tyany_ptr(x), xia, rp, rl);
|
||||
if (!q_N(mult)) r = mul_c2(m_f64(0), mult, r);
|
||||
#else
|
||||
i32* rp; r = m_i32arrc(&rp, x);
|
||||
i32* wi = tyany_ptr(w);
|
||||
i32* xi = tyany_ptr(x);
|
||||
if (CHECK_VALID && !FL_HAS(w,fl)) {
|
||||
for (i64 i = 0; i < (i64)wia-1; i++) if (wi[i] GRADE_UD(>,<) wi[i+1]) thrM(GRADE_CHR": 𝕨 must be sorted"GRADE_UD(," in descending order"));
|
||||
FL_SET(w, fl);
|
||||
}
|
||||
|
||||
for (usz i = 0; i < xia; i++) {
|
||||
i32 c = xi[i];
|
||||
i32 *s = wi-1;
|
||||
for (usz l = wia+1, h; (h=l/2)>0; l-=h) { i32* m = s+h; if (!(c LT *m)) s = m; }
|
||||
rp[i] = s - (wi-1);
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
#if !SINGELI
|
||||
gen:;
|
||||
#endif
|
||||
i32* rp; r = m_i32arrc(&rp, x);
|
||||
|
||||
SGetU(x)
|
||||
SLOW2("𝕨"GRADE_CHR"𝕩", w, x);
|
||||
B* wp = TO_BPTR(w);
|
||||
|
||||
if (CHECK_VALID && !FL_HAS(w,fl)) {
|
||||
for (i64 i = 0; i < (i64)wia-1; i++) if (compare(wp[i], wp[i+1]) GRADE_UD(>,<) 0) thrM(GRADE_CHR": 𝕨 must be sorted"GRADE_UD(," in descending order"));
|
||||
FL_SET(w, fl);
|
||||
}
|
||||
|
||||
for (usz i = 0; i < xia; i++) {
|
||||
B c = GetU(x,i);
|
||||
usz s = 0, e = wia+1;
|
||||
@ -404,6 +562,7 @@ done:
|
||||
return r;
|
||||
}
|
||||
#undef GRADE_CHR
|
||||
#undef LE_C2
|
||||
|
||||
#undef LT
|
||||
#undef FOR
|
||||
|
||||
@ -3,6 +3,11 @@
|
||||
|
||||
// Defines Sort, Grade, and Bins
|
||||
|
||||
#if SINGELI
|
||||
#define SINGELI_FILE bins
|
||||
#include "../utils/includeSingeli.h"
|
||||
#endif
|
||||
|
||||
#define CAT0(A,B) A##_##B
|
||||
#define CAT(A,B) CAT0(A,B)
|
||||
typedef struct BI32p { B k; i32 v; } BI32p;
|
||||
|
||||
@ -24,6 +24,11 @@ def ctz{x:T & isint{T} & width{T}==64} = emit{u8, '__builtin_ctzll', x}
|
||||
def ctz{x:T & isint{T} & width{T}<=32} = emit{u8, '__builtin_ctz', x}
|
||||
def clz{x:T & isint{T} & width{T}==64} = emit{u8, '__builtin_clzll', x}
|
||||
def clz{x:T & isint{T} & width{T}<=32} = emit{u8, '__builtin_clz', x}
|
||||
# count-leading-zeros complement, less type-dependent
|
||||
def clzc{x:T & isint{T} & width{T}==64} = 64-clz{x}
|
||||
def clzc{x:T & isint{T} & width{T}<=32} = 32-clz{x}
|
||||
|
||||
def ceil_log2{n} = clzc{n-1}
|
||||
|
||||
def truncBits{n, v & n<=8} = cast_i{u8, v}
|
||||
def truncBits{n, v & n==16} = cast_i{u16, v}
|
||||
@ -175,6 +180,7 @@ def cvt{...x} = assert{'cvt not supported', show{...x}}
|
||||
def shuf{...x} = assert{'shuf not supported', show{...x}}
|
||||
def shuf16Lo{...x} = assert{'shuf16Lo not supported', show{...x}}
|
||||
def shuf16Hi{...x} = assert{'shuf16Hi not supported', show{...x}}
|
||||
def shufHalves{...x} = assert{'shufHalves not supported', show{...x}}
|
||||
def homAll{...x} = assert{'homAll not supported', show{...x}}
|
||||
def homAny{...x} = assert{'homAny not supported', show{...x}}
|
||||
def homMask{...x} = assert{'homMask not supported', show{...x}}
|
||||
@ -281,6 +287,13 @@ def forNZ{vars,begin,end,iter} = {
|
||||
++i
|
||||
}
|
||||
}
|
||||
def for_backwards{vars,begin,end,iter} = {
|
||||
i:u64 = end
|
||||
while (i > begin) {
|
||||
--i
|
||||
iter{i, vars}
|
||||
}
|
||||
}
|
||||
def forUnroll{exp,unr}{vars,begin,end,iter} = {
|
||||
i:u64 = begin
|
||||
while ((i+unr) <= end) {
|
||||
|
||||
385
src/singeli/src/bins.singeli
Normal file
385
src/singeli/src/bins.singeli
Normal file
@ -0,0 +1,385 @@
|
||||
include './base'
|
||||
include './cbqnDefs'
|
||||
if (hasarch{'AVX2'}) {
|
||||
include './sse'
|
||||
include './avx'
|
||||
include './avx2'
|
||||
}
|
||||
include './mask'
|
||||
include 'util/tup'
|
||||
|
||||
def for_dir{up} = if (up) for else for_backwards
|
||||
|
||||
def for_vec_overlap{vl}{vars,begin==0,n,iter} = {
|
||||
assert{n >= vl}
|
||||
def end = makelabel{}
|
||||
j:u64 = 0
|
||||
while (1) {
|
||||
iter{j, vars}
|
||||
j += vl
|
||||
if (j > n-vl) { if (j == n) goto{end}; j = n-vl }
|
||||
}
|
||||
setlabel{end}
|
||||
}
|
||||
|
||||
# Shift as u16, since x86 is missing 8-bit shifts
|
||||
def shr16{v:V, n} = V~~(to_el{u16, v} >> n)
|
||||
|
||||
# Forward or backwards in-place max-scan
|
||||
# Assumes a whole number of vectors and minimum 0
|
||||
fn max_scan{T, up}(x:*T, len:u64) : void = {
|
||||
def w = width{T}
|
||||
if (hasarch{'AVX2'} and T!=u64) {
|
||||
def op = max
|
||||
# TODO unify with scan.singeli avx2_scan_idem
|
||||
def rev{a} = if (up) a else (tuplen{a}-1)-reverse{a}
|
||||
def maker{T, l} = make{T, rev{l}}
|
||||
def sel8{v, t} = sel{[16]u8, v, maker{[32]i8, t}}
|
||||
def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
|
||||
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,rev{n}}}
|
||||
def spread{a:VT} = {
|
||||
def w = elwidth{VT}
|
||||
def b = w/8
|
||||
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
|
||||
}
|
||||
def shift{k,l} = merge{iota{k},iota{l-k}}
|
||||
def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}}
|
||||
def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a)
|
||||
def pre{a} = {
|
||||
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
|
||||
op{b, sel{[8]i32, spread{b}, maker{[8]i32, 3*(3<iota{8})}}}
|
||||
}
|
||||
def toLast{n:VT} = {
|
||||
if (elwidth{VT}<=32) sel{[8]i32, spread{n}, [8]i32**(up*7)}
|
||||
else shuf{[4]u64, n, up*4b3333}
|
||||
}
|
||||
def vl = 256/w
|
||||
def V = [vl]T
|
||||
p := V**0
|
||||
@for_dir{up} (v in *V~~x over len/vl) { v = op{pre{v}, p}; p = toLast{v} }
|
||||
} else {
|
||||
m:T=0; @for_dir{up} (x over len) { if (x > m) m = x; x = m }
|
||||
}
|
||||
}
|
||||
|
||||
def getsel{...x} = assert{'shuffling not supported', show{...x}}
|
||||
if (hasarch{'AVX2'}) {
|
||||
def getsel{h:H & lvec{H, 16, 8}} = {
|
||||
v := pair{h,h}
|
||||
{i} => sel{H, v, i}
|
||||
}
|
||||
def getsel{v:V & lvec{V, 32, 8}} = {
|
||||
def H = v_half{V}
|
||||
vtop := V**(vcount{V}/2)
|
||||
hs := each{bind{shuf, [4]u64, v}, tup{4b3232, 4b1010}}
|
||||
{i} => homBlend{...each{{h}=>sel{H,h,i}, hs}, V~~i<vtop}
|
||||
}
|
||||
def getsel{v:V & lvec{V, 8, 32}} = { {i} => sel{V, v, i} }
|
||||
}
|
||||
|
||||
# Move evens to half 0 and odds to half 1
|
||||
def uninterleave{x:V & hasarch{'AVX2'}} = {
|
||||
def vl = vcount{V}; def bytes = width{eltype{V}}/8
|
||||
def i = 2*iota{vl/4}
|
||||
def i2= join{table{+, bytes*merge{i,i+1}, iota{bytes}}}
|
||||
t := V~~sel{[16]u8, to_el{u8,x}, make{[32]u8, merge{i2,i2}}}
|
||||
shuf{[4]u64, t, 4b3120}
|
||||
}
|
||||
|
||||
def rtypes = tup{i8, i16, i32, f64}
|
||||
# Return index of smallest possible result type given max result value
|
||||
# (Unused; done in C for now)
|
||||
def get_rtype{len} = {
|
||||
t:u8 = 0
|
||||
def c{T, ...ts} = if (len>maxvalue{T}) { ++t; c{...ts} }
|
||||
def c{T==f64} = {}
|
||||
c{...rtypes}
|
||||
t
|
||||
}
|
||||
def rtype_arr{gen} = {
|
||||
def t = each{gen, rtypes}
|
||||
a:*(type{tupsel{0,t}}) = t
|
||||
}
|
||||
|
||||
# Write the last index of v at t+v, for each unique v in w
|
||||
fn write_indices{I,T}(t:*I, w:*T, n:u64) : void = {
|
||||
def break = makelabel{}
|
||||
i:u64 = 0; while (1) {
|
||||
d:u64 = 16
|
||||
id := i+d
|
||||
wi := undefined{T}
|
||||
if (id >= n) {
|
||||
@for (w over j from i to n) store{t, w, j+1}
|
||||
goto{break}
|
||||
} else if ((wi = load{w, i}) == load{w, id}) {
|
||||
# Gallop
|
||||
md := n - i
|
||||
d2 := undefined{u64}
|
||||
while ((d2=d+d) < md and wi == load{w, i + d2}) d = d2
|
||||
i += d
|
||||
l := n - i; if (l > d) l = d
|
||||
# Last instance of wi in [i,i+l); shrink l
|
||||
while (l > 8) {
|
||||
h := l/2
|
||||
m := i + h
|
||||
if (wi == load{w, m}) i = m
|
||||
l -= h
|
||||
}
|
||||
} else {
|
||||
@unroll (j to 8) store{t, load{w, i+j}, i+j+1}
|
||||
i += 8
|
||||
}
|
||||
}
|
||||
setlabel{break}
|
||||
}
|
||||
fn write_indices{I,T & width{I}==8}(t:*I, w:*T, n:u64) : void = {
|
||||
@for (w over j to n) store{t, w, j+1}
|
||||
}
|
||||
def bins_lookup{I, T, up, w:*T, wn:u64, x:*T, xn:u64, rp:*void} = {
|
||||
# Build table
|
||||
def tc = 1<<width{T}
|
||||
t0:*I = talloc{I, tc}
|
||||
@for (t0 over tc) t0 = 0
|
||||
t:*I = t0 + tc/2
|
||||
write_indices{I,T}(t, *T~~w, wn)
|
||||
# Vector bit-table
|
||||
def use_vectab = if (hasarch{'AVX2'} and I==i8 and T==i8) 1 else 0
|
||||
def done = makelabel{}
|
||||
if (use_vectab) bins_vectab_i8{up, w, wn, x, xn, rp, t0, t, done}
|
||||
# Main scalar table
|
||||
max_scan{I, up}(t0, tc)
|
||||
@for (r in *I~~rp, x over xn) r = load{t, x}
|
||||
if (use_vectab) setlabel{done}
|
||||
tfree{t0}
|
||||
}
|
||||
|
||||
def bins_vectab_i8{up, w, wn, x, xn, rp, t0, t, done & hasarch{'AVX2'}} = {
|
||||
assert{wn < 128} # Total must fit in i8
|
||||
def vl = 32
|
||||
def T = i8
|
||||
def V = [vl]T; def H = v_half{V}
|
||||
def U = [vl]u8
|
||||
|
||||
# Convert to bit table
|
||||
def no_bittab = makelabel{}
|
||||
def nb = 256/vl
|
||||
nu:u8 = 0; def addu{b} = { nu+=popc{b}; b } # Number of uniques
|
||||
vb := U~~make{[nb](ty_u{vl}),
|
||||
@collect (t in *V~~t0 over nb) addu{homMask{t > V**0}}
|
||||
}
|
||||
dup := promote{u64,nu} < wn
|
||||
# Unique index to w index conversion
|
||||
ui := undefined{V}; ui1 := undefined{V}; ui2 := each{undefined,tup{V,V}}
|
||||
if (dup) {
|
||||
def maxu = 2*vl
|
||||
if (nu > maxu) goto{no_bittab}
|
||||
# We'll subtract 1 when indexing so the initial 0 isn't needed
|
||||
tui:*i8 = copy{maxu, 0}; i:T = 0
|
||||
@for (tui over promote{u64,nu}) { i = load{t, load{w, i}}; tui = i }
|
||||
def tv = bind{load, *V~~tui}
|
||||
ui = tv{0}
|
||||
if (nu > 16) ui1 = shuf{[4]u64, ui, 4b3232}
|
||||
ui = shuf{[4]u64, ui, 4b1010}
|
||||
if (nu > vl) ui2 = each{bind{shuf, [4]u64, tv{1}}, tup{4b1010, 4b3232}}
|
||||
}
|
||||
# Popcount on 8-bit values
|
||||
def sums{n} = if (n==1) tup{0} else { def s=sums{n/2}; merge{s,s+1} }
|
||||
def sum4 = getsel{make{H, sums{vl/2}}}
|
||||
bot4 := U**0x0f
|
||||
def vpopc{v} = {
|
||||
def s{b} = sum4{b&bot4}
|
||||
s{shr16{v,4}} + s{v}
|
||||
}
|
||||
# Bit table
|
||||
def swap{v} = shuf{[4]u64, v, 4b1032} # For signedness
|
||||
def sel_b = getsel{swap{vb}}
|
||||
# Masks for filtering bit table
|
||||
def ms = if (up) 256-(1<<(1+iota{8})) else (1<<iota{8})-1
|
||||
def sel_m = getsel{make{H, merge{ms - 256*(ms>127), 8**0}}}
|
||||
# Exact values for multiples of 8
|
||||
store{*U~~t0, 0, vpopc{vb}}
|
||||
st:i8=0; @for_dir{up} (t0 over 256/8) { st += t0; t0 = st }
|
||||
def sel_c = getsel{swap{load{*V~~t0, 0} - V**dup}}
|
||||
# Top 5 bits select bytes from tables; bottom 3 select from mask
|
||||
bot3 := U**0x07
|
||||
@for_vec_overlap{vl} (j to xn) {
|
||||
xv := load{*U~~(x+j), 0}
|
||||
xb := xv & bot3
|
||||
xt := shr16{xv &~ bot3, 3}
|
||||
ind := sel_c{xt} - vpopc{sel_b{xt} & U~~sel_m{xb}}
|
||||
if (dup) {
|
||||
i0 := V~~ind # Can contain -1
|
||||
def isel{u} = sel{H, u, i0}
|
||||
ind = isel{ui}
|
||||
if (nu > 16) {
|
||||
b := V~~(to_el{u16, i0} << (7 - lb{vl/2}))
|
||||
ind = topBlend{ind, isel{ui1}, b}
|
||||
if (nu > 32) ind = homBlend{topBlend{...each{isel,ui2}, b}, ind, i0 < V**vl}
|
||||
}
|
||||
}
|
||||
store{*U~~(*T~~rp+j), 0, ind}
|
||||
}
|
||||
goto{done}
|
||||
setlabel{no_bittab}
|
||||
}
|
||||
|
||||
# Binary search within vector registers
|
||||
def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
|
||||
assert{wn > 1}; assert{wn < maxwn}
|
||||
def wd = width{T}
|
||||
def I = if (wd<32) u8 else u32; def wi = width{I}
|
||||
def lanes = hasarch{'AVX2'} & (I==u8)
|
||||
def isub = wd/wi; def bb = bind{base,1<<wi}
|
||||
def vl = 256/wd; def svl = vl>>lanes
|
||||
def V = [vl]T
|
||||
def U = [vl](ty_u{T})
|
||||
def lt = if (up) <; else >
|
||||
# Number of steps
|
||||
log := ceil_log2{wn+1}
|
||||
gap := 1<<log - cast_i{u8, wn}
|
||||
# Fill with minimum value at the beginning
|
||||
def pre = (if (up) minvalue else maxvalue){T}
|
||||
wg := *V~~(w-gap)
|
||||
wv0:= homBlend{load{wg}, V**pre, maskOf{V,gap}}
|
||||
# For multiple lanes, interleave like transpose
|
||||
def maxstep = lb{maxwn}
|
||||
def lstep = lb{svl}
|
||||
def ex = maxstep - lstep
|
||||
wv := if (lanes) wv0 else tup{wv0,wv0}
|
||||
wv2 := wv # Compiler complains if uninitialized
|
||||
if (ex>=1 and wn >= svl) {
|
||||
--gap # Allows subtracting < instead of adding <=
|
||||
def un = uninterleave
|
||||
def tr_half{a, b} = each{bind{shufHalves,a,b}, tup{16b20, 16b31}}
|
||||
def un{{a,b}} = tr_half{un{a},un{b}}
|
||||
if (not lanes) tupsel{1,wv} = load{wg, 1}
|
||||
wv = un{wv}
|
||||
if (ex>=2 and wn >= 2*svl) {
|
||||
assert{lanes} # Different transpose pattern needed
|
||||
gap -= 2
|
||||
tup{wv, wv2} = each{un, tr_half{wv, un{load{wg, 1}}}}
|
||||
}
|
||||
}
|
||||
def ms{v}{h} = getsel{to_el{I, if (lanes) half{v,h} else tupsel{h,v}}}
|
||||
def selw = ms{wv}{0}; def selw1 = if (ex>=1) ms{wv}{1} else 'undef'
|
||||
def selw2 = if (ex>=2) each{ms{wv2}, iota{2}} else 'undef'
|
||||
# Offset at end
|
||||
off := U~~V**i8~~(gap - 1)
|
||||
# Midpoint bits for each step
|
||||
def lowbits = bb{copy{isub,isub}}
|
||||
bits := each{{j} => U**(lowbits << j), iota{lstep}}
|
||||
# Unroll sizes up to a full lane, handling extra lanes conditionally
|
||||
# in the largest one
|
||||
@unroll (klog from 2 to min{maxstep,lstep}+1) {
|
||||
def last = klog==lstep
|
||||
def this = if (not last) log==klog else log>=klog
|
||||
if (this) @for_vec_overlap{vl} (j to xn) {
|
||||
xv:= load{*V~~(x+j), 0}
|
||||
s := U**bb{iota{isub}} # Select sequential bytes within each U
|
||||
def ltx{se,ind} = lt{xv, V~~se{to_el{I,ind}}}
|
||||
@unroll (j to klog) {
|
||||
m := s | tupsel{klog-1-j,bits}
|
||||
s = homBlend{m, s, ltx{selw, m}}
|
||||
}
|
||||
r := if (isub==1) s else s>>(lb{isub}+wd-wi)
|
||||
# Extra selection lanes
|
||||
if (last and ex>=1 and log>=klog+1) {
|
||||
r += r
|
||||
c := ltx{selw1,s}
|
||||
if (ex>=2 and log>=klog+2) {
|
||||
r += r
|
||||
each{{se} => c += ltx{se,s}, selw2}
|
||||
}
|
||||
r += c
|
||||
}
|
||||
r -= off
|
||||
rn := if (T==i8) r
|
||||
else if (T==i16) half{narrow{u8, r}, 0}
|
||||
else extract{to_el{i64, narrow{u8, r}}, 0}
|
||||
store{*type{rn}~~(*i8~~rp+j), 0, rn}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def unroll_sizes = tup{4,1}
|
||||
fn write{T,k}(r:*void, i:u64, ...vs:k**u64) : void = {
|
||||
each{{j,v} => store{*T~~r, i+j, cast_i{T,v}}, iota{k}, vs}
|
||||
}
|
||||
def wr_arrs = each{{k} => rtype_arr{{T} => write{T,k}}, unroll_sizes}
|
||||
|
||||
def bin_search_branchless{up, w, wn, x, n, res, rtype} = {
|
||||
def lt = if (up) <; else >
|
||||
ws := w - 1
|
||||
l0 := wn + 1
|
||||
# Take a list of indices in x/res to allow unrolling
|
||||
def search{inds} = {
|
||||
xs:= each{bind{load,x}, inds} # Values
|
||||
ss:= each{{_}=>ws, inds} # Initial lower bound
|
||||
l := l0; h := undefined{u64} # Interval size l, same for all values
|
||||
while ((h=l/2) > 0) {
|
||||
# Branchless update
|
||||
def bin1{s, x, m} = { if (not lt{x, load{m}}) s = m }
|
||||
each{bin1, ss, xs, each{bind{+,h}, ss}}
|
||||
l -= h
|
||||
}
|
||||
each{{s} => u64~~(s - ws), ss}
|
||||
}
|
||||
# Unroll by 4 then 1
|
||||
def search{i, k} = search{each{bind{+,i}, iota{k}}}
|
||||
j:u64 = 0
|
||||
def searches{k, wr_arr} = {
|
||||
wr := load{wr_arr, rtype}
|
||||
while (j+k <= n) { wr(res, j, ...search{j, k}); j+=k }
|
||||
}
|
||||
each{searches, unroll_sizes, wr_arrs}
|
||||
}
|
||||
|
||||
fn bins{T, up}(w:*void, wn:u64, x:*void, xn:u64, rp:*void, rty:u8) : void = {
|
||||
def param = tup{up, *T~~w, wn, *T~~x, xn, rp}
|
||||
def lookup{k} = {
|
||||
if (rty == k) bins_lookup{tupsel{k,rtypes}, T, ...param}
|
||||
else if (k+1 < tuplen{rtypes}) lookup{k+1}
|
||||
}
|
||||
# For >=8 i8 values, vector bit-table is as good as binary search
|
||||
def wn_vec = if (T==i8) 8 else 2*256/width{T}
|
||||
if (hasarch{'AVX2'} and T<=i32 and wn < wn_vec and xn >= 256/width{T}) {
|
||||
bin_search_vec{T, ...param, wn_vec}
|
||||
# Lookup table threshold has to account for cost of
|
||||
# populating the table (proportional to wn until it's large), and
|
||||
# initializing the table (constant, much higher for i16)
|
||||
} else if (T==i8 and xn>=32 and (xn>=512 or xn >= wn>>6 + 32)) {
|
||||
lookup{0}
|
||||
} else if (T==i16 and xn>=512 and (xn>=1<<14 or xn >= wn>>6 + (u64~~3<<(12+rty))/promote{u64,ceil_log2{wn}+2})) {
|
||||
lookup{0}
|
||||
} else {
|
||||
bin_search_branchless{...param, rty}
|
||||
}
|
||||
}
|
||||
|
||||
exportT{
|
||||
'si_bins',
|
||||
join{table{bins, tup{i8,i16,i32,f64}, tup{1,0}}}
|
||||
}
|
||||
|
||||
# Utility for narrowing binary search right argument
|
||||
include './f64'
|
||||
require{'math.h'}
|
||||
fn saturate{F,T,...up}(dst:*void, src:*void, n:u64) : void = {
|
||||
# Auto-vectorizes, although not that well for f64
|
||||
def a = minvalue{T}; af := cast_i{F,a}
|
||||
def b = maxvalue{T}; bf := cast_i{F,b}
|
||||
@for (d in *T~~dst, xf in *F~~src over n) {
|
||||
x := if (F==f64) (if (tupsel{0,up}) floor else ceil){xf} else xf
|
||||
d = cast_i{T, x}
|
||||
if (x<af) d = a
|
||||
if (x>bf) d = b
|
||||
}
|
||||
}
|
||||
|
||||
exportT{
|
||||
'si_saturate',
|
||||
each{{a}=>saturate{...a}, merge{
|
||||
tup{tup{i16,i8}, tup{i32,i8}, tup{i32,i16}},
|
||||
join{table{bind{tup,f64}, tup{i8,i16,i32}, tup{1,0}}}
|
||||
}}
|
||||
}
|
||||
@ -31,3 +31,11 @@ def cbqn_elType{T & T==u16} = 6
|
||||
def cbqn_elType{T & T==u32} = 7
|
||||
|
||||
def cbqn_tyArrOffset{} = emit{u64, 'offsetof', 'TyArr', 'a'}
|
||||
|
||||
def talloc{T, len} = emit{*T, 'TALLOCP', fmt_type{T}, len}
|
||||
def tfree{ptr} = emit{void, 'TFREE', ptr}
|
||||
def fmt_type{T} = {
|
||||
def w = match (width{T}) { {_==8}=>'8'; {_==16}=>'16'; {_==32}=>'32'; {_==64}=>'64' }
|
||||
merge{quality{T}, w}
|
||||
}
|
||||
def fmt_type{T & isptr{T}} = merge{'*',fmt_type{eltype{T}}}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user