commit
b143892f21
@ -97,9 +97,27 @@ extern void (*const si_scan_min_i16)(int16_t* v0,int16_t* v1,uint64_t v2);
|
||||
if (e==n) {break;} k=e; \
|
||||
}
|
||||
#define WRITE_SPARSE(T) WRITE_SPARSE_##T
|
||||
extern i8 (*const avx2_count_i8)(usz*, i8*, u64, i8);
|
||||
#define SINGELI_COUNT_OR(T) \
|
||||
if (1==sizeof(T)) avx2_count_i8(c0o, (i8*)xp, n, -128); else
|
||||
extern i8 (*const simd_count_i8)(u16*, u16*, void*, u64, i8);
|
||||
#define COUNTING_SORT_i8 \
|
||||
usz C=1<<8; \
|
||||
TALLOC(u16, c0, C+(n>>15)+1); \
|
||||
u16 *c0o=c0+C/2; u16 *ov=c0+C; \
|
||||
for (usz j=0; j<C; j++) c0[j]=0; \
|
||||
simd_count_i8(c0o, ov, xp, n, -128); \
|
||||
if (n/COUNT_THRESHOLD <= C) { /* Scan-based */ \
|
||||
i8 j=GRADE_UD(-C/2,C/2-1); \
|
||||
usz ij; while ((ij=c0o[j])==0) GRADE_UD(j++,j--); \
|
||||
WRITE_SPARSE(i8) \
|
||||
TFREE(c0) \
|
||||
} else { /* Branchy, and ov may have entries */ \
|
||||
TALLOC(usz, cw, C); \
|
||||
NOUNROLL for (usz i=0; i<C; i++) cw[i]=c0[i]; \
|
||||
u16 oe=-1; \
|
||||
for (usz i=0; ov[i]!=oe; i++) cw[ov[i]]+= 1<<15; \
|
||||
TFREE(c0) \
|
||||
FOR(j,C) for (usz c=cw[j]; c--; ) *rp++ = j-C/2; \
|
||||
TFREE(cw) \
|
||||
}
|
||||
#else
|
||||
#define COUNT_THRESHOLD 16
|
||||
#define WRITE_SPARSE(T) \
|
||||
@ -107,14 +125,14 @@ extern i8 (*const avx2_count_i8)(usz*, i8*, u64, i8);
|
||||
usz js = j; \
|
||||
while (ij<n) { rp[ij]GRADE_UD(++,--); ij+=c0o[GRADE_UD(++j,--j)]; } \
|
||||
for (usz i=0; i<n; i++) js=rp[i]+=js;
|
||||
#define SINGELI_COUNT_OR(T)
|
||||
#define COUNTING_SORT_i8 COUNTING_SORT(i8)
|
||||
#endif
|
||||
|
||||
#define COUNTING_SORT(T) \
|
||||
usz C=1<<(8*sizeof(T)); \
|
||||
TALLOC(usz, c0, C); usz *c0o=c0+C/2; \
|
||||
for (usz j=0; j<C; j++) c0[j]=0; \
|
||||
SINGELI_COUNT_OR(T) for (usz i=0; i<n; i++) c0o[xp[i]]++; \
|
||||
for (usz i=0; i<n; i++) c0o[xp[i]]++; \
|
||||
if (n/(COUNT_THRESHOLD*sizeof(T)) <= C) { /* Scan-based */ \
|
||||
T j=GRADE_UD(-C/2,C/2-1); \
|
||||
usz ij; while ((ij=c0o[j])==0) GRADE_UD(j++,j--); \
|
||||
@ -211,7 +229,7 @@ B SORT_C1(B t, B x) {
|
||||
} else if (n < 256) {
|
||||
RADIX_SORT_i8(u8, SORT);
|
||||
} else {
|
||||
COUNTING_SORT(i8);
|
||||
COUNTING_SORT_i8;
|
||||
}
|
||||
} else if (xe==el_i16) {
|
||||
i16* xp = i16any_ptr(x);
|
||||
@ -247,7 +265,7 @@ B SORT_C1(B t, B x) {
|
||||
#undef SORT_C1
|
||||
#undef INSERTION_SORT
|
||||
#undef COUNTING_SORT
|
||||
#undef SINGELI_COUNT_OR
|
||||
#undef COUNTING_SORT_i8
|
||||
#if SINGELI_AVX2
|
||||
#undef WRITE_SPARSE_i8
|
||||
#undef WRITE_SPARSE_i16
|
||||
|
||||
@ -54,14 +54,21 @@
|
||||
// Indices inverse (/⁼), a lot like Group
|
||||
// Always gives a squeezed result for integer 𝕩
|
||||
// Boolean 𝕩: just count 1s
|
||||
// Long i8 and i16 𝕩: count into zeroed buffer before anything else
|
||||
// Only zero positive part; if total is too small there were negatives
|
||||
// Cutoff is set so short 𝕩 gives a result of the same type
|
||||
// Without SINGELI_SIMD, just write to large-type table and squeeze
|
||||
// COULD do many /⁼ optimizations without SIMD
|
||||
// Scan for strictly ascending 𝕩
|
||||
// COULD vectorize with find-compare
|
||||
// Unsigned maximum for integers to avoid a separate negative check
|
||||
// If (≠÷⌈´)𝕩 is small, find result type with a sparse u8 table
|
||||
// COULD use a u16 table for i32 𝕩 to detect i16 result
|
||||
// SHOULD vectorize, maybe with find-compare
|
||||
// Sorted indices: i8 counter and index+count overflow
|
||||
// Work in blocks of 128, try galloping if one has start equal to end
|
||||
// Otherwise use runs-adaptive count (not sums, they're rarely better)
|
||||
// Long i8 and i16 𝕩: allocate full-range to skip initial range check
|
||||
// If (≠÷⌈´)𝕩 is small, detect u1 and i8 result with a sparse u8 table
|
||||
// General-case i8 to i32 𝕩: dedicated SIMD functions
|
||||
// i8 and i16 𝕩: i16 counter and index overflow (implicit count 1<<15)
|
||||
// Flush to overflow every 1<<15 writes
|
||||
// Get range in 2KB blocks to enable count by compare and sum
|
||||
// Run detection used partly to mitigate write stalls from repeats
|
||||
// COULD also alternate writes to multiple tables if 𝕩 is long enough
|
||||
|
||||
#include "../core.h"
|
||||
#include "../utils/mut.h"
|
||||
@ -812,12 +819,60 @@ B slash_c2(B t, B w, B x) {
|
||||
return c2rt(slash, w, x);
|
||||
}
|
||||
|
||||
#if SINGELI_SIMD
|
||||
static B finish_small_count(B r, u16* ov) {
|
||||
// Need to add 1<<15 to r at i for each index i in ov
|
||||
u16 e = -1; // ov end marker
|
||||
if (*ov == e) {
|
||||
r = num_squeeze(r);
|
||||
} else {
|
||||
r = taga(cpyI32Arr(r)); i32* rp = tyany_ptr(r);
|
||||
usz on = 0; u16 ovi;
|
||||
for (usz i=0; (ovi=ov[i])!=e; i++) {
|
||||
i32 rv = (rp[ovi]+= 1<<15);
|
||||
if (RARE(rv < 0)) {
|
||||
rp[ovi] = rv ^ (1<<31);
|
||||
ov[on++] = ovi;
|
||||
}
|
||||
}
|
||||
if (RARE(on > 0)) { // Overflowed i32!
|
||||
r = taga(cpyF64Arr(r)); f64* rp = tyany_ptr(r);
|
||||
for (usz i=0; i<on; i++) rp[ov[i]]+= 1U<<31;
|
||||
}
|
||||
FL_SET(r, fl_squoze);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
static B finish_sorted_count(B r, usz* ov, usz* oc, usz on) {
|
||||
// Overflow values in ov are sorted but not unique
|
||||
// Set mo to the greatest sum of oc for equal ov values
|
||||
usz mo = 0, pv = 0, c = 0;
|
||||
for (usz i=0; i<on; i++) {
|
||||
usz sv = pv; pv = ov[i];
|
||||
c = c*(sv==pv) + oc[i];
|
||||
if (c>mo) mo=c;
|
||||
}
|
||||
// Since mo is a multiple of 128 and all of r is less than 128,
|
||||
// values in r can't affect the result type
|
||||
#define RESIZE(T, UT) \
|
||||
r = taga(cpy##UT##Arr(r)); T* rp = tyany_ptr(r); \
|
||||
for (usz i=0; i<on; i++) rp[ov[i]]+= oc[i];
|
||||
if (mo == 0); // No overflow, r is correct already
|
||||
else if (mo < I16_MAX) { RESIZE(i16, I16) }
|
||||
else if (mo < I32_MAX) { RESIZE(i32, I32) }
|
||||
else { RESIZE(f64, F64) }
|
||||
#undef RESIZE
|
||||
return FL_SET(r, fl_squoze); // Relies on having checked for boolean
|
||||
}
|
||||
#endif
|
||||
|
||||
B slash_im(B t, B x) {
|
||||
if (!isArr(x) || RNK(x)!=1) thrM("/⁼: Argument must be a list");
|
||||
u8 xe = TI(x,elType);
|
||||
usz xia = IA(x);
|
||||
if (xia==0) { decG(x); return emptyIVec(); }
|
||||
B r;
|
||||
retry:
|
||||
switch(xe) { default: UD;
|
||||
case el_bit: {
|
||||
usz sum = bit_sum(bitany_ptr(x), xia);
|
||||
@ -826,76 +881,100 @@ B slash_im(B t, B x) {
|
||||
rp[sum>0] = sum; rp[0] = xia - sum;
|
||||
r = num_squeeze(r); break;
|
||||
}
|
||||
#define IIND_INT(N) \
|
||||
#if SINGELI_SIMD
|
||||
#define INIT_RES(N,RIA) \
|
||||
i##N* rp; r = m_i##N##arrv(&rp, RIA); \
|
||||
for (usz i=0; i<RIA; i++) rp[i]=0;
|
||||
#define TRY_SMALL_OUT(N) \
|
||||
if (xp[0]<0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||
usz a=1; while (a<xia && xp[a]>xp[a-1]) a++; \
|
||||
u##N max=xp[a-1]; \
|
||||
usz rmax=xia; \
|
||||
if (a<xia) { \
|
||||
if (FL_HAS(x,fl_asc)) { \
|
||||
usz ria = xp[xia-1] + 1; \
|
||||
usz os = xia/128; \
|
||||
INIT_RES(8,ria) \
|
||||
TALLOC(usz, ov, 2*os); usz* oc = ov+os; \
|
||||
usz on = si_count_sorted_i##N((u8*)rp, ov, oc, xp, xia); \
|
||||
r = finish_sorted_count(r, ov, oc, on); \
|
||||
TFREE(ov); \
|
||||
break; \
|
||||
} \
|
||||
for (usz i=a; i<xia; i++) { u##N c=xp[i]; if (c>max) max=c; } \
|
||||
if ((i##N)max<0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||
usz ria = max + 1; \
|
||||
if (xia < ria/8) { \
|
||||
u8 maxcount = 0; \
|
||||
TALLOC(u8, tab, ria); \
|
||||
for (usz i=0; i<xia; i++) tab[xp[i]]=0; \
|
||||
for (usz i=0; i<xia; i++) maxcount|=tab[xp[i]]++; \
|
||||
for (usz i=0; i<xia; i++) tab[xp[i]]=0; \
|
||||
for (usz i=0; i<xia; i++) maxcount|=++tab[xp[i]]; \
|
||||
TFREE(tab); \
|
||||
if (maxcount==0) a=xia; \
|
||||
else if (N>=16 && maxcount<127) { \
|
||||
i8* rp; r = m_i8arrv(&rp, ria); for (usz i=0; i<ria; i++) rp[i]=0; \
|
||||
for (usz i = 0; i < xia; i++) rp[xp[i]]++; \
|
||||
break; \
|
||||
} \
|
||||
if (maxcount<=1) a=xia; \
|
||||
else if (N>=16 && maxcount<128) rmax=127; \
|
||||
} \
|
||||
} \
|
||||
usz ria = (usz)max + 1; \
|
||||
if (a==xia) { /* Unique argument */ \
|
||||
usz ria = max + 1; \
|
||||
u64* rp; r = m_bitarrv(&rp, ria); \
|
||||
for (usz i=0; i<BIT_N(ria); i++) rp[i]=0; \
|
||||
for (usz i=0; i<xia; i++) bitp_set(rp, xp[i], 1); \
|
||||
break; \
|
||||
} \
|
||||
usz ria = (usz)max + 1; \
|
||||
i##N* rp; r = m_i##N##arrv(&rp, ria); for (usz i=0; i<ria; i++) rp[i]=0; \
|
||||
for (usz i = 0; i < xia; i++) rp[xp[i]]++; \
|
||||
r = num_squeeze(r);
|
||||
#define CASE_SMALL(N) \
|
||||
case el_i##N: { \
|
||||
i##N* xp = i##N##any_ptr(x); \
|
||||
usz m=1<<N; \
|
||||
if (xia < m/2) { \
|
||||
IIND_INT(N) \
|
||||
} else SINGELI_COUNT_OR(N) { \
|
||||
TALLOC(usz, t, m); \
|
||||
for (usz j=0; j<m/2; j++) t[j]=0; \
|
||||
for (usz i=0; i<xia; i++) t[(u##N)xp[i]]++; \
|
||||
t[m/2]=xia; usz ria=0; for (u64 s=0; s<xia; ria++) s+=t[ria]; \
|
||||
if (ria>m/2) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||
i32* rp; r = m_i32arrv(&rp, ria); vfor (usz i=0; i<ria; i++) rp[i]=t[i]; \
|
||||
TFREE(t); \
|
||||
r = num_squeeze(r); \
|
||||
} \
|
||||
break; \
|
||||
if (rmax<128) { /* xia<128 or xia<ria/8, fine to process x slowly */ \
|
||||
INIT_RES(8,ria) \
|
||||
for (usz i = 0; i < xia; i++) rp[xp[i]]++; \
|
||||
r = num_squeeze(r); break; \
|
||||
}
|
||||
#define CASE_SMALL(N) \
|
||||
case el_i##N: { \
|
||||
i##N* xp = i##N##any_ptr(x); \
|
||||
usz sa = 1<<(N-1); \
|
||||
if (xia < sa || FL_HAS(x,fl_asc)) { \
|
||||
TRY_SMALL_OUT(N) \
|
||||
assert(N != 8); \
|
||||
sa = ria; \
|
||||
} \
|
||||
INIT_RES(16,sa) \
|
||||
usz os = xia>>15; \
|
||||
TALLOC(u16, ov, os+1); \
|
||||
i##N max = simd_count_i##N((u16*)rp, (u16*)ov, xp, xia, 0); \
|
||||
if (max < 0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||
usz ria = (usz)max + 1; \
|
||||
if (ria < sa) r = C2(take, m_f64(ria), r); \
|
||||
r = finish_small_count(r, ov); \
|
||||
TFREE(ov); \
|
||||
break; \
|
||||
}
|
||||
#if SINGELI_SIMD
|
||||
#define SINGELI_COUNT_OR(N) if (N==8) { \
|
||||
TALLOC(usz, t, m/2); \
|
||||
for (usz j=0; j<m/2; j++) t[j]=0; \
|
||||
i8 max = avx2_count_i8(t, (i8*)xp, xia, 0); \
|
||||
if (max < 0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||
usz ria=max+1; \
|
||||
i32* rp; r = m_i32arrv(&rp, ria); vfor (usz i=0; i<ria; i++) rp[i]=t[i]; \
|
||||
TFREE(t); \
|
||||
r = num_squeeze(r); \
|
||||
} else
|
||||
#else
|
||||
#define SINGELI_COUNT_OR(N)
|
||||
#endif
|
||||
CASE_SMALL(8) CASE_SMALL(16)
|
||||
#undef CASE_SMALL
|
||||
#undef SINGELI_COUNT_OR
|
||||
case el_i32: { i32* xp = i32any_ptr(x); IIND_INT(32) r = num_squeeze(r); break; }
|
||||
#undef IIND_INT
|
||||
#undef CASE_SMALL
|
||||
case el_i32: {
|
||||
i32* xp = i32any_ptr(x);
|
||||
TRY_SMALL_OUT(32)
|
||||
if (xia>I32_MAX) thrM("/⁼: Argument too large");
|
||||
INIT_RES(32,ria)
|
||||
simd_count_i32_i32(rp, xp, xia);
|
||||
r = num_squeeze(r); break;
|
||||
}
|
||||
#undef TRY_SMALL_OUT
|
||||
#undef INIT_RES
|
||||
#else
|
||||
#define CASE(N) case el_i##N: { \
|
||||
i##N* xp = i##N##any_ptr(x); \
|
||||
u##N max=xp[0]; \
|
||||
for (usz i=1; i<xia; i++) { u##N c=xp[i]; if (c>max) max=c; } \
|
||||
if ((i##N)max<0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||
usz ria = max + 1; \
|
||||
TALLOC(usz, t, ria); \
|
||||
for (usz j=0; j<ria; j++) t[j]=0; \
|
||||
for (usz i = 0; i < xia; i++) t[xp[i]]++; \
|
||||
if (xia<=I32_MAX) { i32* rp; r = m_i32arrv(&rp, ria); vfor (usz i=0; i<ria; i++) rp[i]=t[i]; } \
|
||||
else { f64* rp; r = m_f64arrv(&rp, ria); vfor (usz i=0; i<ria; i++) rp[i]=t[i]; } \
|
||||
TFREE(t); \
|
||||
r = num_squeeze(r); break; }
|
||||
CASE(8) CASE(16) CASE(32)
|
||||
#undef CASE
|
||||
#endif
|
||||
case el_f64: {
|
||||
f64* xp = f64any_ptr(x);
|
||||
usz i,j; f64 max=-1;
|
||||
@ -906,27 +985,19 @@ B slash_im(B t, B x) {
|
||||
u64* rp; r = m_bitarrv(&rp, ria); for (usz i=0; i<BIT_N(ria); i++) rp[i]=0;
|
||||
for (usz i = 0; i < xia; i++) bitp_set(rp, xp[i], 1);
|
||||
} else {
|
||||
if (xia>I32_MAX) thrM("/⁼: Argument too large");
|
||||
i32* rp; r = m_i32arrv(&rp, ria); for (usz i=0; i<ria; i++) rp[i]=0;
|
||||
for (usz i = 0; i < xia; i++) rp[(usz)xp[i]]++;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case el_c8: case el_c16: case el_c32: case el_B: {
|
||||
SLOW1("/⁼", x);
|
||||
x = num_squeezeChk(x);
|
||||
xe = TI(x,elType);
|
||||
if (elNum(xe)) goto retry;
|
||||
B* xp = TO_BPTR(x);
|
||||
usz i,j; i64 max=-1;
|
||||
for (i = 0; i < xia; i++) { i64 c=o2i64(xp[i]); if (c<=max) break; max=c; }
|
||||
for (j = i; j < xia; j++) { i64 c=o2i64(xp[j]); max=c>max?c:max; if (c<0) thrM("/⁼: Argument cannot contain negative numbers"); }
|
||||
if (max > USZ_MAX-1) thrOOM();
|
||||
usz ria = max+1;
|
||||
if (i==xia) {
|
||||
u64* rp; r = m_bitarrv(&rp, ria); for (usz i=0; i<BIT_N(ria); i++) rp[i]=0;
|
||||
for (usz i = 0; i < xia; i++) bitp_set(rp, o2i64G(xp[i]), 1);
|
||||
} else {
|
||||
i32* rp; r = m_i32arrv(&rp, ria); for (usz i=0; i<ria; i++) rp[i]=0;
|
||||
for (usz i = 0; i < xia; i++) rp[o2i64G(xp[i])]++;
|
||||
}
|
||||
break;
|
||||
for (usz i=0; i<xia; i++) o2i64(xp[i]);
|
||||
UD;
|
||||
}
|
||||
}
|
||||
decG(x); return r;
|
||||
|
||||
@ -3,22 +3,23 @@ include './vecfold'
|
||||
|
||||
if_inline (hasarch{'SSE2'}) {
|
||||
fn sum_vec{T}(v:T) = vfold{+, fold{+, mzip128{v, T**0}}}
|
||||
def fold_addw{v:T=[_](u8)} = sum_vec{T}(v)
|
||||
def fold_addw{v:T=[_]E if E<=u32} = sum_vec{T}(v)
|
||||
}
|
||||
|
||||
def inc{ptr, ind, v} = store{ptr, ind, v + load{ptr, ind}}
|
||||
def inc{ptr:*T, ind, v} = store{ptr, ind, trunc{T,v} + load{ptr, ind}}
|
||||
def inc{ptr, ind} = inc{ptr, ind, 1}
|
||||
|
||||
# Write counts /⁼x to tab and return ⌈´x
|
||||
fn count{T}(tab:*usz, x:*T, n:u64, min_allowed:T) : T = {
|
||||
# Write counts (2⋆15)|/⁼x to tab, overflows to ov, and return ⌈´x
|
||||
fn count{T if T<=i16}(tab:*u16, ov:*u16, xp:*void, n:u64, min_allowed:T) : T = {
|
||||
def vbits = arch_defvw
|
||||
def vec = vbits/width{T}
|
||||
def uT = ty_u{T}
|
||||
def TU = ty_u{T}
|
||||
def V = [vec]T
|
||||
def block = (2048*8) / vbits # Target vectors per block
|
||||
def b_max = block + block/4 # Last block max length
|
||||
assert{b_max < 1<<width{T}} # Don't overflow count in vector section
|
||||
mx:T = min_allowed # Maximum of x
|
||||
x := *T~~xp
|
||||
mx:T = min_allowed # Maximum of x
|
||||
i:u64 = 0
|
||||
while (i < n) {
|
||||
# Number of elements to handle in this iteration
|
||||
@ -28,44 +29,201 @@ fn count{T}(tab:*usz, x:*T, n:u64, min_allowed:T) : T = {
|
||||
r0:u64 = 0 # Elements actually handled by vector case
|
||||
|
||||
# Find range to check for suitability; return a negative if found
|
||||
# Also record number of differences dc
|
||||
# (double-counts at index vec but it doesn't need to be exact)
|
||||
xv := *V~~x
|
||||
jv := load{xv}; mv := jv
|
||||
@for (xv over _ from 1 to b) { jv = min{jv, xv}; mv = max{mv, xv} }
|
||||
ne := jv != load{*V~~(x+1)}; dc := -ne
|
||||
# Quickly skip ahead if initial values are all equal
|
||||
a:u64 = 1
|
||||
if (not homAny{ne} and b>=4) {
|
||||
def eq_k{k} = homAll{tree_fold{&, @unroll(x in xv+a over k) x==jv}}
|
||||
def skip_eq{k} = if (eq_k{k}) { a=2*k; skip_eq{2*k} }
|
||||
def skip_eq{k==4} = while (a<=b-k and eq_k{k}) a+=k
|
||||
skip_eq{1}
|
||||
}
|
||||
# Now start analysis
|
||||
@for (xv, xp in *V~~(x-1) over _ from a to b) {
|
||||
jv = min{jv, xv}; mv = max{mv, xv}
|
||||
dc -= xp != xv
|
||||
}
|
||||
@for (x over _ from rv to r) { if (x<min_allowed) return{x}; if (x>mx) mx=x }
|
||||
jt := vfold{min, jv}
|
||||
mt := vfold{max, mv}
|
||||
if (jt < min_allowed) return{jt}
|
||||
if (mt > mx) mx = mt
|
||||
|
||||
nc := uT~~(mt - jt) # Number of counts to perform: last is implicit
|
||||
if (nc <= 24*vbits/128) {
|
||||
# Fast cases
|
||||
dt := promote{u64, fold_addw{dc}}
|
||||
nc := TU~~(mt - jt) # Number of counts to perform: last is implicit
|
||||
if (dt < b * (vec/2) and (b + dt)*4 < b * promote{u64,nc}) {
|
||||
r0 = count_with_runs{x, tab, r}
|
||||
} else if (nc <= 24*vbits/128) {
|
||||
r0 = rv
|
||||
j0 := promote{u64, uT~~jt} # Starting count
|
||||
m := promote{u64, nc} # Number of iterations
|
||||
total := trunc{usz, r0} # To compute last count
|
||||
def count_each{js, num} = {
|
||||
j := @collect (k to num) trunc{T, js+k}
|
||||
c := copy{length{j}, [vec]uT ** 0}
|
||||
e := each{{j}=>V**j, j}
|
||||
@for (xv over b) each{{c,e} => c -= xv == e, c, e}
|
||||
def add_sum{c, j} = {
|
||||
s := promote{usz, fold_addw{c}}
|
||||
total -= s; inc{tab, j, s}
|
||||
}
|
||||
each{add_sum, c, j}
|
||||
count_by_sum{T, V, [vec]TU, xv, b, tab, r0,
|
||||
promote{u64, TU~~jt}, # Starting count
|
||||
promote{u64, nc} # Number of iterations
|
||||
}
|
||||
m4 := m / 4
|
||||
@for (j4 to m4) count_each{j0 + 4*j4, 4}
|
||||
@for (j from 4*m4 to m) count_each{j0 + j, 1}
|
||||
inc{tab, trunc{T, j0 + m}, trunc{usz,total}}
|
||||
}
|
||||
|
||||
# Scalar fallback and cleanup
|
||||
@for (x over _ from r0 to r) inc{tab, x}
|
||||
i += r
|
||||
x += r
|
||||
|
||||
# Keep counts below 1<<15 with the overflow list
|
||||
# Count from the end to include i==n and handle a long last block nicely
|
||||
if ((i-n)%(1<<15) < block*vec and i >= 1<<15) {
|
||||
ov += flush_counts(tab+min_allowed, ov, cast_i{usz,ty_u{mx+min_allowed}} + 1)
|
||||
}
|
||||
}
|
||||
store{ov, 0, maxvalue{u16}} # End marker: note x values fit in i16
|
||||
mx
|
||||
}
|
||||
|
||||
export{'avx2_count_i8', count{i8}}
|
||||
fn flush_counts(tab:*u16, ov:*u16, n:usz) : usz = {
|
||||
def vl = arch_defvw/16
|
||||
def V = [vl]u16
|
||||
def bot = 1<<15 - 1
|
||||
on:usz = 0
|
||||
@for (t in *V~~tab over jv to cdiv{n, vl}) if (rare{topAny{t}}) {
|
||||
o := if (hasarch{'X86_64'}) topMask{t} else homMask{t > V**bot}
|
||||
if (jv == n/vl) o &= type{o}~~1<<(n%vl) - 1
|
||||
while (o > 0) {
|
||||
jv := jv*vl + cast_i{usz, ctz{o}}
|
||||
store{tab, jv, load{tab, jv} & bot}
|
||||
store{ov, on, trunc{u16, jv}}; ++on
|
||||
o &= o-1
|
||||
}
|
||||
}
|
||||
on
|
||||
}
|
||||
|
||||
# Sum comparisons against each value (except one) in the range
|
||||
def count_by_sum{T, V, U, xv, b, tab, r0, j0, m} = {
|
||||
total := trunc{usz, r0} # To compute last count
|
||||
def count_each{js, num} = {
|
||||
j := @collect (k to num) trunc{T, js+k}
|
||||
c := length{j} ** U**0
|
||||
e := each{{j}=>V**j, j}
|
||||
@for (xv over b) each{{c,e} => c -= xv == e, c, e}
|
||||
def add_sum{c, j} = {
|
||||
s := promote{usz, fold_addw{c}}
|
||||
total -= s; inc{tab, j, s}
|
||||
}
|
||||
each{add_sum, c, j}
|
||||
}
|
||||
m4 := m / 4
|
||||
@for (j4 to m4) count_each{j0 + 4*j4, 4}
|
||||
@for (j from 4*m4 to m) count_each{j0 + j, 1}
|
||||
inc{tab, trunc{T, j0 + m}, total}
|
||||
}
|
||||
|
||||
# Count adjacent equal elements at once, breaking at w-element groups
|
||||
# May read up to index n from x, hitting one element that's not counted
|
||||
def count_with_runs{x, tab, n} = {
|
||||
def w = width{ux}
|
||||
m0:ux = 1 << (w-1) # Last element in each chunk ends a run
|
||||
bw := n / w
|
||||
@for (i to bw) {
|
||||
xo := x + i*w
|
||||
m := m0; mark_run_ends{xo, m}
|
||||
inc_marked_runs{xo, tab, m, m0}
|
||||
}
|
||||
bw * w # Number of elements handled
|
||||
}
|
||||
# Switch to the normal scalar count if there aren't enough runs
|
||||
def count_adapt_runs{x0, tab, n} = {
|
||||
def w = width{ux}
|
||||
m0:ux = 1 << (w-1)
|
||||
x := x0; r := n
|
||||
while (r > 0) {
|
||||
def skip_runs = makelabel{}
|
||||
b:usz = w
|
||||
if (rare{b > r}) { b = r; goto{skip_runs} }
|
||||
m := m0; mark_run_ends{x, m}
|
||||
if (popc{m} < w/2) {
|
||||
inc_marked_runs{x, tab, m, m0}
|
||||
} else {
|
||||
setlabel{skip_runs}
|
||||
@for (x over b) inc{tab, x}
|
||||
}
|
||||
x += b; r -= b
|
||||
}
|
||||
}
|
||||
def mark_run_ends{x:*T, m:(ux)} = {
|
||||
def vec = arch_defvw/width{T}
|
||||
def V = [vec]T
|
||||
@unroll (j to width{ux} / vec) {
|
||||
def jv = j*vec
|
||||
def lv{k} = load{*V~~(x + k)}
|
||||
m |= promote{ux, homMask{lv{jv} != lv{jv+1}}} << jv
|
||||
}
|
||||
}
|
||||
def inc_marked_runs{x, tab:*T, m, m0} = {
|
||||
def w = width{ux}
|
||||
# Iterate over runs marked in m
|
||||
jp:T = - T~~1
|
||||
while (m > m0) @unroll (2) {
|
||||
j := trunc{T, ctz{m}}
|
||||
inc{tab, load{x, j}, j - jp}
|
||||
jp = j; m &= m-1
|
||||
}
|
||||
# One step if popc{m} was odd, reducing branch mispredictions above
|
||||
inc{tab, load{x, w-1}, ((w-1) - jp) & -trunc{T, m>>(w-1)}}
|
||||
}
|
||||
|
||||
# No count_by_sum: build each run mask then decide whether to use it
|
||||
fn count_i32_i32(tab:*i32, x:*i32, n:usz) : void = count_adapt_runs{x, tab, n}
|
||||
|
||||
# For i←/⁼x, store r←128|i, and i-r sparsely: x is ∧(/r)∾oc/ov
|
||||
# ov is sorted but may not be unique, and oc contains multiples of 128
|
||||
# Return the shared length of ov and oc
|
||||
fn count_sorted{T}(r:*u8, ov:*usz, oc:*usz, x:*T, n:usz) : usz = {
|
||||
def V = [arch_defvw/width{T}]T
|
||||
def block = 128
|
||||
i:usz = 0
|
||||
on:usz = 0
|
||||
def overflow{xu,c} = { store{ov, on, xu}; store{oc, on, c}; ++on }
|
||||
while (i < n) {
|
||||
rem := n - i
|
||||
xo := x + i
|
||||
xi := load{xo}
|
||||
def overflow{c} = overflow{cast_i{usz,xi}, c}
|
||||
xe := xo-1; def bxi{j} = xi == load{xe, j}
|
||||
if (block <= rem and bxi{block}) {
|
||||
# Gallop to find last block ending in xi
|
||||
d:usz = block
|
||||
d2 := undefined{usz}
|
||||
while ((d2=d+d) <= rem and bxi{d2}) d = d2
|
||||
l := min{(rem &~ (block-1)) - d, d}
|
||||
# Target is in [d,d+l); shrink l
|
||||
while (l > block) {
|
||||
h := (l/2) &~ (block-1)
|
||||
m := d + h
|
||||
if (bxi{m}) d = m
|
||||
l -= h
|
||||
}
|
||||
overflow{d}
|
||||
rem -= d; if (rem == 0) return{on}
|
||||
i += d; xo += d; xi = load{xo}
|
||||
}
|
||||
# Count the next block normally
|
||||
rem = min{rem, usz~~block} # TODO get rid of the need of the usz~~ here
|
||||
count_adapt_runs{xo, r, rem}
|
||||
rxi := load{r, xi}
|
||||
if (rxi >= block) {
|
||||
store{r, xi, rxi - block}
|
||||
overflow{block}
|
||||
}
|
||||
i += rem
|
||||
}
|
||||
on
|
||||
}
|
||||
|
||||
export{'simd_count_i8', count{i8}}
|
||||
export{'simd_count_i16', count{i16}}
|
||||
export{'simd_count_i32_i32', count_i32_i32}
|
||||
export{'si_count_sorted_i8', count_sorted{i8}}
|
||||
export{'si_count_sorted_i16', count_sorted{i16}}
|
||||
export{'si_count_sorted_i32', count_sorted{i32}}
|
||||
|
||||
@ -15,6 +15,8 @@ def ntyp{S, ...S2, T if w128{T}} = merge{S, 'q', ...S2, '_', nty{T}}
|
||||
def ntyp{S, ...S2, T if w64{T}} = merge{S, ...S2, '_', nty{T}}
|
||||
def ntyp0{S, T} = merge{S, '_', nty{T}}
|
||||
|
||||
def __neg{a:T if nvecu{T}} = T~~(-ty_s{T}~~a)
|
||||
|
||||
def __lt{a:T, 0 if nvecs{T} or nvecf{T}} = emit{ty_u{T}, ntyp{'vcltz', T}, a}
|
||||
def __le{a:T, 0 if nvecs{T} or nvecf{T}} = emit{ty_u{T}, ntyp{'vclez', T}, a}
|
||||
def __gt{a:T, 0 if nvecs{T} or nvecf{T}} = emit{ty_u{T}, ntyp{'vcgtz', T}, a}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user