Do general i8 and i16 /⁼ counts to i16 buffer, plus overflow list
This commit is contained in:
parent
fb5ee179cb
commit
0bbb335893
@ -97,9 +97,27 @@ extern void (*const si_scan_min_i16)(int16_t* v0,int16_t* v1,uint64_t v2);
|
||||
if (e==n) {break;} k=e; \
|
||||
}
|
||||
#define WRITE_SPARSE(T) WRITE_SPARSE_##T
|
||||
extern i8 (*const simd_count_i8)(usz*, i8*, u64, i8);
|
||||
#define SINGELI_COUNT_OR(T) \
|
||||
if (1==sizeof(T)) simd_count_i8(c0o, (i8*)xp, n, -128); else
|
||||
extern i8 (*const simd_count_i8)(u16*, u16*, void*, u64, i8);
|
||||
#define COUNTING_SORT_i8 \
|
||||
usz C=1<<8; \
|
||||
TALLOC(u16, c0, C+(n>>15)+1); \
|
||||
u16 *c0o=c0+C/2; u16 *ov=c0+C; \
|
||||
for (usz j=0; j<C; j++) c0[j]=0; \
|
||||
simd_count_i8(c0o, ov, xp, n, -128); \
|
||||
if (n/COUNT_THRESHOLD <= C) { /* Scan-based */ \
|
||||
i8 j=GRADE_UD(-C/2,C/2-1); \
|
||||
usz ij; while ((ij=c0o[j])==0) GRADE_UD(j++,j--); \
|
||||
WRITE_SPARSE(i8) \
|
||||
TFREE(c0) \
|
||||
} else { /* Branchy, and ov may have entries */ \
|
||||
TALLOC(usz, cw, C); \
|
||||
NOUNROLL for (usz i=0; i<C; i++) cw[i]=c0[i]; \
|
||||
u16 oe=-1; \
|
||||
for (usz i=0; ov[i]!=oe; i++) cw[ov[i]]+= 1<<15; \
|
||||
TFREE(c0) \
|
||||
FOR(j,C) for (usz c=cw[j]; c--; ) *rp++ = j-C/2; \
|
||||
TFREE(cw) \
|
||||
}
|
||||
#else
|
||||
#define COUNT_THRESHOLD 16
|
||||
#define WRITE_SPARSE(T) \
|
||||
@ -107,14 +125,14 @@ extern i8 (*const simd_count_i8)(usz*, i8*, u64, i8);
|
||||
usz js = j; \
|
||||
while (ij<n) { rp[ij]GRADE_UD(++,--); ij+=c0o[GRADE_UD(++j,--j)]; } \
|
||||
for (usz i=0; i<n; i++) js=rp[i]+=js;
|
||||
#define SINGELI_COUNT_OR(T)
|
||||
#define COUNTING_SORT_i8 COUNTING_SORT(i8)
|
||||
#endif
|
||||
|
||||
#define COUNTING_SORT(T) \
|
||||
usz C=1<<(8*sizeof(T)); \
|
||||
TALLOC(usz, c0, C); usz *c0o=c0+C/2; \
|
||||
for (usz j=0; j<C; j++) c0[j]=0; \
|
||||
SINGELI_COUNT_OR(T) for (usz i=0; i<n; i++) c0o[xp[i]]++; \
|
||||
for (usz i=0; i<n; i++) c0o[xp[i]]++; \
|
||||
if (n/(COUNT_THRESHOLD*sizeof(T)) <= C) { /* Scan-based */ \
|
||||
T j=GRADE_UD(-C/2,C/2-1); \
|
||||
usz ij; while ((ij=c0o[j])==0) GRADE_UD(j++,j--); \
|
||||
@ -211,7 +229,7 @@ B SORT_C1(B t, B x) {
|
||||
} else if (n < 256) {
|
||||
RADIX_SORT_i8(u8, SORT);
|
||||
} else {
|
||||
COUNTING_SORT(i8);
|
||||
COUNTING_SORT_i8;
|
||||
}
|
||||
} else if (xe==el_i16) {
|
||||
i16* xp = i16any_ptr(x);
|
||||
@ -247,7 +265,7 @@ B SORT_C1(B t, B x) {
|
||||
#undef SORT_C1
|
||||
#undef INSERTION_SORT
|
||||
#undef COUNTING_SORT
|
||||
#undef SINGELI_COUNT_OR
|
||||
#undef COUNTING_SORT_i8
|
||||
#if SINGELI_AVX2
|
||||
#undef WRITE_SPARSE_i8
|
||||
#undef WRITE_SPARSE_i16
|
||||
|
||||
@ -813,6 +813,29 @@ B slash_c2(B t, B w, B x) {
|
||||
}
|
||||
|
||||
#if SINGELI_SIMD
|
||||
static B finish_small_count(B r, u16* ov) {
|
||||
// Need to add 1<<15 to r at i for each index i in ov
|
||||
u16 e = -1; // ov end marker
|
||||
if (*ov == e) {
|
||||
r = num_squeeze(r);
|
||||
} else {
|
||||
r = taga(cpyI32Arr(r)); i32* rp = tyany_ptr(r);
|
||||
usz on = 0; u16 ovi;
|
||||
for (usz i=0; (ovi=ov[i])!=e; i++) {
|
||||
i32 rv = (rp[ovi]+= 1<<15);
|
||||
if (RARE(rv < 0)) {
|
||||
rp[ovi] = rv ^ (1<<31);
|
||||
ov[on++] = ovi;
|
||||
}
|
||||
}
|
||||
if (RARE(on > 0)) { // Overflowed i32!
|
||||
r = taga(cpyF64Arr(r)); f64* rp = tyany_ptr(r);
|
||||
for (usz i=0; i<on; i++) rp[ov[i]]+= 1<<31;
|
||||
}
|
||||
FL_SET(r, fl_squoze);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
static B finish_sorted_count(B r, usz* ov, usz* oc, usz on) {
|
||||
// Overflow values in ov are sorted but not unique
|
||||
// Set mo to the greatest sum of oc for equal ov values
|
||||
@ -832,7 +855,7 @@ static B finish_sorted_count(B r, usz* ov, usz* oc, usz on) {
|
||||
else if (mo < I32_MAX) { RESIZE(i32, I32) }
|
||||
else { RESIZE(f64, F64) }
|
||||
#undef RESIZE
|
||||
return r;
|
||||
return FL_SET(r, fl_squoze); // Relies on having checked for boolean
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -852,15 +875,19 @@ B slash_im(B t, B x) {
|
||||
r = num_squeeze(r); break;
|
||||
}
|
||||
#if SINGELI_SIMD
|
||||
#define INIT_RES(N,RIA) \
|
||||
i##N* rp; r = m_i##N##arrv(&rp, RIA); \
|
||||
for (usz i=0; i<RIA; i++) rp[i]=0;
|
||||
#define TRY_SMALL_OUT(N) \
|
||||
if (xp[0]<0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||
usz a=1; while (a<xia && xp[a]>xp[a-1]) a++; \
|
||||
u##N max=xp[a-1]; \
|
||||
usz rmax=xia; \
|
||||
if (a<xia) { \
|
||||
if (FL_HAS(x,fl_asc)) { \
|
||||
usz ria = xp[xia-1] + 1; \
|
||||
usz os = xia/128; \
|
||||
INIT_RES(8) \
|
||||
INIT_RES(8,ria) \
|
||||
TALLOC(usz, ov, 2*os); usz* oc = ov+os; \
|
||||
usz on = si_count_sorted_i##N((u8*)rp, ov, oc, xp, xia); \
|
||||
r = finish_sorted_count(r, ov, oc, on); \
|
||||
@ -877,56 +904,53 @@ B slash_im(B t, B x) {
|
||||
for (usz i=0; i<xia; i++) maxcount|=++tab[xp[i]]; \
|
||||
TFREE(tab); \
|
||||
if (maxcount<=1) a=xia; \
|
||||
else if (N>=16 && maxcount<128) { INIT_RES(8) FILL_RES break; } \
|
||||
else if (N>=16 && maxcount<128) rmax=127; \
|
||||
} \
|
||||
} \
|
||||
usz ria = (usz)max + 1; \
|
||||
if (a==xia) { /* Unique argument */ \
|
||||
usz ria = max + 1; \
|
||||
u64* rp; r = m_bitarrv(&rp, ria); \
|
||||
for (usz i=0; i<BIT_N(ria); i++) rp[i]=0; \
|
||||
for (usz i=0; i<xia; i++) bitp_set(rp, xp[i], 1); \
|
||||
break; \
|
||||
} \
|
||||
usz ria = (usz)max + 1;
|
||||
#define INIT_RES(N) \
|
||||
i##N* rp; r = m_i##N##arrv(&rp, ria); \
|
||||
for (usz i=0; i<ria; i++) rp[i]=0;
|
||||
#define FILL_RES \
|
||||
for (usz i = 0; i < xia; i++) rp[xp[i]]++;
|
||||
if (rmax<128) { /* xia<128 or xia<ria/8, fine to process x slowly */ \
|
||||
INIT_RES(8,ria) \
|
||||
for (usz i = 0; i < xia; i++) rp[xp[i]]++; \
|
||||
r = num_squeeze(r); break; \
|
||||
}
|
||||
#define CASE_SMALL(N) \
|
||||
case el_i##N: { \
|
||||
i##N* xp = i##N##any_ptr(x); \
|
||||
usz m=1<<N; \
|
||||
usz sa = m/2; \
|
||||
if (xia < sa || FL_HAS(x,fl_asc)) { \
|
||||
TRY_SMALL_OUT(N) \
|
||||
if (N==16 && ria<sa && ria+ria/2+64<=xia) { sa=ria; goto small_range##N; } \
|
||||
INIT_RES(N) FILL_RES \
|
||||
} else { \
|
||||
small_range##N: TALLOC(usz, t, sa); \
|
||||
for (usz j=0; j<sa; j++) t[j]=0; \
|
||||
i##N max = simd_count_i##N(t, xp, xia, 0); \
|
||||
if (max < 0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||
usz ria=max+1; \
|
||||
i32* rp; r = m_i32arrv(&rp, ria); vfor (usz i=0; i<ria; i++) rp[i]=t[i]; \
|
||||
TFREE(t); \
|
||||
} \
|
||||
r = num_squeeze(r); \
|
||||
break; \
|
||||
case el_i##N: { \
|
||||
i##N* xp = i##N##any_ptr(x); \
|
||||
usz sa = 1<<(N-1); \
|
||||
if (xia < sa || FL_HAS(x,fl_asc)) { \
|
||||
TRY_SMALL_OUT(N) \
|
||||
if (N==8) UD; \
|
||||
sa = ria; \
|
||||
} \
|
||||
INIT_RES(16,sa) \
|
||||
usz os = xia>>15; \
|
||||
TALLOC(u16, ov, os+1); \
|
||||
i##N max = simd_count_i##N((u16*)rp, (u16*)ov, xp, xia, 0); \
|
||||
if (max < 0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||
usz ria = (usz)max + 1; \
|
||||
if (ria < sa) r = C2(take, m_f64(ria), r); \
|
||||
r = finish_small_count(r, ov); \
|
||||
TFREE(ov); \
|
||||
break; \
|
||||
}
|
||||
CASE_SMALL(8) CASE_SMALL(16)
|
||||
#undef CASE_SMALL
|
||||
case el_i32: {
|
||||
i32* xp = i32any_ptr(x);
|
||||
TRY_SMALL_OUT(32)
|
||||
INIT_RES(32)
|
||||
INIT_RES(32,ria)
|
||||
simd_count_i32_i32(rp, xp, xia);
|
||||
r = num_squeeze(r);
|
||||
break;
|
||||
}
|
||||
#undef TRY_SMALL_OUT
|
||||
#undef INIT_RES
|
||||
#undef FILL_RES
|
||||
#else
|
||||
#define CASE(N) case el_i##N: { \
|
||||
i##N* xp = i##N##any_ptr(x); \
|
||||
|
||||
@ -6,31 +6,24 @@ if_inline (hasarch{'SSE2'}) {
|
||||
def fold_addw{v:T=[_]E if E<=u32} = sum_vec{T}(v)
|
||||
}
|
||||
|
||||
def inc{ptr, ind, v} = store{ptr, ind, v + load{ptr, ind}}
|
||||
def inc{ptr:*T, ind, v} = store{ptr, ind, trunc{T,v} + load{ptr, ind}}
|
||||
def inc{ptr, ind} = inc{ptr, ind, 1}
|
||||
|
||||
def block_loop{V=[vec]T, n, iter} = {
|
||||
def block = (2048*8) / width{V} # Target vectors per block
|
||||
def b_max = block + block/4 # Last block max length
|
||||
assert{b_max < 1<<width{T}} # Don't overflow count in vector section
|
||||
i:u64 = 0
|
||||
while (i < n) {
|
||||
# Number of elements to handle in this iteration
|
||||
r:u64 = n - i; if (r > vec*b_max) r = vec*block
|
||||
iter{r}
|
||||
i += r
|
||||
}
|
||||
}
|
||||
|
||||
# Write counts /⁼x to tab and return ⌈´x
|
||||
fn count{T}(tab:*usz, xp:*void, n:u64, min_allowed:T) : T = {
|
||||
# Write counts (2⋆15)|/⁼x to tab, overflows to ov, and return ⌈´x
|
||||
fn count{T if T<=i16}(tab:*u16, ov:*u16, xp:*void, n:u64, min_allowed:T) : T = {
|
||||
def vbits = arch_defvw
|
||||
def vec = vbits/width{T}
|
||||
def uT = ty_u{T}
|
||||
def V = [vec]T
|
||||
def block = (2048*8) / vbits # Target vectors per block
|
||||
def b_max = block + block/4 # Last block max length
|
||||
assert{b_max < 1<<width{T}} # Don't overflow count in vector section
|
||||
x := *T~~xp
|
||||
mx:T = min_allowed # Maximum of x
|
||||
block_loop{V, n, {r} => { # Handle r elements
|
||||
i:u64 = 0
|
||||
while (i < n) {
|
||||
# Number of elements to handle in this iteration
|
||||
r:u64 = n - i; if (r > vec*b_max) r = vec*block
|
||||
b := r / vec # Vector case does b full vectors if it runs
|
||||
rv:= b * vec
|
||||
r0:u64 = 0 # Elements actually handled by vector case
|
||||
@ -65,11 +58,37 @@ fn count{T}(tab:*usz, xp:*void, n:u64, min_allowed:T) : T = {
|
||||
|
||||
# Scalar fallback and cleanup
|
||||
@for (x over _ from r0 to r) inc{tab, x}
|
||||
i += r
|
||||
x += r
|
||||
}}
|
||||
|
||||
# Keep counts below 1<<15 with the overflow list
|
||||
# Count from the end to include i==n and handle a long last block nicely
|
||||
if ((i-n)%(1<<15) < block*vec and i >= 1<<15) {
|
||||
ov += flush_counts(tab+min_allowed, ov, cast_i{usz,ty_u{mx+min_allowed}} + 1)
|
||||
}
|
||||
}
|
||||
store{ov, 0, maxvalue{u16}} # End marker: note x values fit in i16
|
||||
mx
|
||||
}
|
||||
|
||||
fn flush_counts(tab:*u16, ov:*u16, n:usz) : usz = {
|
||||
def vl = arch_defvw/16
|
||||
def V = [vl]u16
|
||||
def bot = 1<<15 - 1
|
||||
on:usz = 0
|
||||
@for (t in *V~~tab over jv to cdiv{n, vl}) if (rare{topAny{t}}) {
|
||||
o := if (hasarch{'X86_64'}) topMask{t} else homMask{t > V**bot}
|
||||
if (jv == n/vl) o &= type{o}~~1<<(n%vl) - 1
|
||||
while (o > 0) {
|
||||
jv := jv*vl + cast_i{usz, ctz{o}}
|
||||
store{tab, jv, load{tab, jv} & bot}
|
||||
store{ov, on, trunc{u16, jv}}; ++on
|
||||
o &= o-1
|
||||
}
|
||||
}
|
||||
on
|
||||
}
|
||||
|
||||
# Sum comparisons against each value (except one) in the range
|
||||
def count_by_sum{T, V, U, xv, b, tab, r0, j0, m} = {
|
||||
total := trunc{usz, r0} # To compute last count
|
||||
@ -87,7 +106,7 @@ def count_by_sum{T, V, U, xv, b, tab, r0, j0, m} = {
|
||||
m4 := m / 4
|
||||
@for (j4 to m4) count_each{j0 + 4*j4, 4}
|
||||
@for (j from 4*m4 to m) count_each{j0 + j, 1}
|
||||
inc{tab, trunc{T, j0 + m}, trunc{usz,total}}
|
||||
inc{tab, trunc{T, j0 + m}, total}
|
||||
}
|
||||
|
||||
# Count adjacent equal elements at once, breaking at w-element groups
|
||||
@ -137,7 +156,7 @@ def inc_marked_runs{x, tab:*T, m, m0} = {
|
||||
jp:T = - T~~1
|
||||
while (m > m0) @unroll (2) {
|
||||
j := trunc{T, ctz{m}}
|
||||
inc{tab, load{x, j}, cast_i{T, j - jp}}
|
||||
inc{tab, load{x, j}, j - jp}
|
||||
jp = j; m &= m-1
|
||||
}
|
||||
# One step if popc{m} was odd, reducing branch mispredictions above
|
||||
|
||||
Loading…
Reference in New Issue
Block a user