Do general i8 and i16 /⁼ counts to i16 buffer, plus overflow list

This commit is contained in:
Marshall Lochbaum 2024-11-16 20:56:25 -05:00
parent fb5ee179cb
commit 0bbb335893
3 changed files with 119 additions and 58 deletions

View File

@ -97,9 +97,27 @@ extern void (*const si_scan_min_i16)(int16_t* v0,int16_t* v1,uint64_t v2);
if (e==n) {break;} k=e; \
}
#define WRITE_SPARSE(T) WRITE_SPARSE_##T
extern i8 (*const simd_count_i8)(usz*, i8*, u64, i8);
#define SINGELI_COUNT_OR(T) \
if (1==sizeof(T)) simd_count_i8(c0o, (i8*)xp, n, -128); else
extern i8 (*const simd_count_i8)(u16*, u16*, void*, u64, i8);
#define COUNTING_SORT_i8 \
usz C=1<<8; \
TALLOC(u16, c0, C+(n>>15)+1); \
u16 *c0o=c0+C/2; u16 *ov=c0+C; \
for (usz j=0; j<C; j++) c0[j]=0; \
simd_count_i8(c0o, ov, xp, n, -128); \
if (n/COUNT_THRESHOLD <= C) { /* Scan-based */ \
i8 j=GRADE_UD(-C/2,C/2-1); \
usz ij; while ((ij=c0o[j])==0) GRADE_UD(j++,j--); \
WRITE_SPARSE(i8) \
TFREE(c0) \
} else { /* Branchy, and ov may have entries */ \
TALLOC(usz, cw, C); \
NOUNROLL for (usz i=0; i<C; i++) cw[i]=c0[i]; \
u16 oe=-1; \
for (usz i=0; ov[i]!=oe; i++) cw[ov[i]]+= 1<<15; \
TFREE(c0) \
FOR(j,C) for (usz c=cw[j]; c--; ) *rp++ = j-C/2; \
TFREE(cw) \
}
#else
#define COUNT_THRESHOLD 16
#define WRITE_SPARSE(T) \
@ -107,14 +125,14 @@ extern i8 (*const simd_count_i8)(usz*, i8*, u64, i8);
usz js = j; \
while (ij<n) { rp[ij]GRADE_UD(++,--); ij+=c0o[GRADE_UD(++j,--j)]; } \
for (usz i=0; i<n; i++) js=rp[i]+=js;
#define SINGELI_COUNT_OR(T)
#define COUNTING_SORT_i8 COUNTING_SORT(i8)
#endif
#define COUNTING_SORT(T) \
usz C=1<<(8*sizeof(T)); \
TALLOC(usz, c0, C); usz *c0o=c0+C/2; \
for (usz j=0; j<C; j++) c0[j]=0; \
SINGELI_COUNT_OR(T) for (usz i=0; i<n; i++) c0o[xp[i]]++; \
for (usz i=0; i<n; i++) c0o[xp[i]]++; \
if (n/(COUNT_THRESHOLD*sizeof(T)) <= C) { /* Scan-based */ \
T j=GRADE_UD(-C/2,C/2-1); \
usz ij; while ((ij=c0o[j])==0) GRADE_UD(j++,j--); \
@ -211,7 +229,7 @@ B SORT_C1(B t, B x) {
} else if (n < 256) {
RADIX_SORT_i8(u8, SORT);
} else {
COUNTING_SORT(i8);
COUNTING_SORT_i8;
}
} else if (xe==el_i16) {
i16* xp = i16any_ptr(x);
@ -247,7 +265,7 @@ B SORT_C1(B t, B x) {
#undef SORT_C1
#undef INSERTION_SORT
#undef COUNTING_SORT
#undef SINGELI_COUNT_OR
#undef COUNTING_SORT_i8
#if SINGELI_AVX2
#undef WRITE_SPARSE_i8
#undef WRITE_SPARSE_i16

View File

@ -813,6 +813,29 @@ B slash_c2(B t, B w, B x) {
}
#if SINGELI_SIMD
static B finish_small_count(B r, u16* ov) {
// Need to add 1<<15 to r at i for each index i in ov
u16 e = -1; // ov end marker
if (*ov == e) {
r = num_squeeze(r);
} else {
r = taga(cpyI32Arr(r)); i32* rp = tyany_ptr(r);
usz on = 0; u16 ovi;
for (usz i=0; (ovi=ov[i])!=e; i++) {
i32 rv = (rp[ovi]+= 1<<15);
if (RARE(rv < 0)) {
rp[ovi] = rv ^ (1<<31);
ov[on++] = ovi;
}
}
if (RARE(on > 0)) { // Overflowed i32!
r = taga(cpyF64Arr(r)); f64* rp = tyany_ptr(r);
for (usz i=0; i<on; i++) rp[ov[i]]+= 1<<31;
}
FL_SET(r, fl_squoze);
}
return r;
}
static B finish_sorted_count(B r, usz* ov, usz* oc, usz on) {
// Overflow values in ov are sorted but not unique
// Set mo to the greatest sum of oc for equal ov values
@ -832,7 +855,7 @@ static B finish_sorted_count(B r, usz* ov, usz* oc, usz on) {
else if (mo < I32_MAX) { RESIZE(i32, I32) }
else { RESIZE(f64, F64) }
#undef RESIZE
return r;
return FL_SET(r, fl_squoze); // Relies on having checked for boolean
}
#endif
@ -852,15 +875,19 @@ B slash_im(B t, B x) {
r = num_squeeze(r); break;
}
#if SINGELI_SIMD
#define INIT_RES(N,RIA) \
i##N* rp; r = m_i##N##arrv(&rp, RIA); \
for (usz i=0; i<RIA; i++) rp[i]=0;
#define TRY_SMALL_OUT(N) \
if (xp[0]<0) thrM("/⁼: Argument cannot contain negative numbers"); \
usz a=1; while (a<xia && xp[a]>xp[a-1]) a++; \
u##N max=xp[a-1]; \
usz rmax=xia; \
if (a<xia) { \
if (FL_HAS(x,fl_asc)) { \
usz ria = xp[xia-1] + 1; \
usz os = xia/128; \
INIT_RES(8) \
INIT_RES(8,ria) \
TALLOC(usz, ov, 2*os); usz* oc = ov+os; \
usz on = si_count_sorted_i##N((u8*)rp, ov, oc, xp, xia); \
r = finish_sorted_count(r, ov, oc, on); \
@ -877,56 +904,53 @@ B slash_im(B t, B x) {
for (usz i=0; i<xia; i++) maxcount|=++tab[xp[i]]; \
TFREE(tab); \
if (maxcount<=1) a=xia; \
else if (N>=16 && maxcount<128) { INIT_RES(8) FILL_RES break; } \
else if (N>=16 && maxcount<128) rmax=127; \
} \
} \
usz ria = (usz)max + 1; \
if (a==xia) { /* Unique argument */ \
usz ria = max + 1; \
u64* rp; r = m_bitarrv(&rp, ria); \
for (usz i=0; i<BIT_N(ria); i++) rp[i]=0; \
for (usz i=0; i<xia; i++) bitp_set(rp, xp[i], 1); \
break; \
} \
usz ria = (usz)max + 1;
#define INIT_RES(N) \
i##N* rp; r = m_i##N##arrv(&rp, ria); \
for (usz i=0; i<ria; i++) rp[i]=0;
#define FILL_RES \
for (usz i = 0; i < xia; i++) rp[xp[i]]++;
if (rmax<128) { /* xia<128 or xia<ria/8, fine to process x slowly */ \
INIT_RES(8,ria) \
for (usz i = 0; i < xia; i++) rp[xp[i]]++; \
r = num_squeeze(r); break; \
}
#define CASE_SMALL(N) \
case el_i##N: { \
i##N* xp = i##N##any_ptr(x); \
usz m=1<<N; \
usz sa = m/2; \
if (xia < sa || FL_HAS(x,fl_asc)) { \
TRY_SMALL_OUT(N) \
if (N==16 && ria<sa && ria+ria/2+64<=xia) { sa=ria; goto small_range##N; } \
INIT_RES(N) FILL_RES \
} else { \
small_range##N: TALLOC(usz, t, sa); \
for (usz j=0; j<sa; j++) t[j]=0; \
i##N max = simd_count_i##N(t, xp, xia, 0); \
if (max < 0) thrM("/⁼: Argument cannot contain negative numbers"); \
usz ria=max+1; \
i32* rp; r = m_i32arrv(&rp, ria); vfor (usz i=0; i<ria; i++) rp[i]=t[i]; \
TFREE(t); \
} \
r = num_squeeze(r); \
break; \
case el_i##N: { \
i##N* xp = i##N##any_ptr(x); \
usz sa = 1<<(N-1); \
if (xia < sa || FL_HAS(x,fl_asc)) { \
TRY_SMALL_OUT(N) \
if (N==8) UD; \
sa = ria; \
} \
INIT_RES(16,sa) \
usz os = xia>>15; \
TALLOC(u16, ov, os+1); \
i##N max = simd_count_i##N((u16*)rp, (u16*)ov, xp, xia, 0); \
if (max < 0) thrM("/⁼: Argument cannot contain negative numbers"); \
usz ria = (usz)max + 1; \
if (ria < sa) r = C2(take, m_f64(ria), r); \
r = finish_small_count(r, ov); \
TFREE(ov); \
break; \
}
CASE_SMALL(8) CASE_SMALL(16)
#undef CASE_SMALL
case el_i32: {
i32* xp = i32any_ptr(x);
TRY_SMALL_OUT(32)
INIT_RES(32)
INIT_RES(32,ria)
simd_count_i32_i32(rp, xp, xia);
r = num_squeeze(r);
break;
}
#undef TRY_SMALL_OUT
#undef INIT_RES
#undef FILL_RES
#else
#define CASE(N) case el_i##N: { \
i##N* xp = i##N##any_ptr(x); \

View File

@ -6,31 +6,24 @@ if_inline (hasarch{'SSE2'}) {
def fold_addw{v:T=[_]E if E<=u32} = sum_vec{T}(v)
}
def inc{ptr, ind, v} = store{ptr, ind, v + load{ptr, ind}}
def inc{ptr:*T, ind, v} = store{ptr, ind, trunc{T,v} + load{ptr, ind}}
def inc{ptr, ind} = inc{ptr, ind, 1}
def block_loop{V=[vec]T, n, iter} = {
def block = (2048*8) / width{V} # Target vectors per block
def b_max = block + block/4 # Last block max length
assert{b_max < 1<<width{T}} # Don't overflow count in vector section
i:u64 = 0
while (i < n) {
# Number of elements to handle in this iteration
r:u64 = n - i; if (r > vec*b_max) r = vec*block
iter{r}
i += r
}
}
# Write counts /⁼x to tab and return ⌈´x
fn count{T}(tab:*usz, xp:*void, n:u64, min_allowed:T) : T = {
# Write counts (2⋆15)|/⁼x to tab, overflows to ov, and return ⌈´x
fn count{T if T<=i16}(tab:*u16, ov:*u16, xp:*void, n:u64, min_allowed:T) : T = {
def vbits = arch_defvw
def vec = vbits/width{T}
def uT = ty_u{T}
def V = [vec]T
def block = (2048*8) / vbits # Target vectors per block
def b_max = block + block/4 # Last block max length
assert{b_max < 1<<width{T}} # Don't overflow count in vector section
x := *T~~xp
mx:T = min_allowed # Maximum of x
block_loop{V, n, {r} => { # Handle r elements
i:u64 = 0
while (i < n) {
# Number of elements to handle in this iteration
r:u64 = n - i; if (r > vec*b_max) r = vec*block
b := r / vec # Vector case does b full vectors if it runs
rv:= b * vec
r0:u64 = 0 # Elements actually handled by vector case
@ -65,11 +58,37 @@ fn count{T}(tab:*usz, xp:*void, n:u64, min_allowed:T) : T = {
# Scalar fallback and cleanup
@for (x over _ from r0 to r) inc{tab, x}
i += r
x += r
}}
# Keep counts below 1<<15 with the overflow list
# Count from the end to include i==n and handle a long last block nicely
if ((i-n)%(1<<15) < block*vec and i >= 1<<15) {
ov += flush_counts(tab+min_allowed, ov, cast_i{usz,ty_u{mx+min_allowed}} + 1)
}
}
store{ov, 0, maxvalue{u16}} # End marker: note x values fit in i16
mx
}
fn flush_counts(tab:*u16, ov:*u16, n:usz) : usz = {
def vl = arch_defvw/16
def V = [vl]u16
def bot = 1<<15 - 1
on:usz = 0
@for (t in *V~~tab over jv to cdiv{n, vl}) if (rare{topAny{t}}) {
o := if (hasarch{'X86_64'}) topMask{t} else homMask{t > V**bot}
if (jv == n/vl) o &= type{o}~~1<<(n%vl) - 1
while (o > 0) {
jv := jv*vl + cast_i{usz, ctz{o}}
store{tab, jv, load{tab, jv} & bot}
store{ov, on, trunc{u16, jv}}; ++on
o &= o-1
}
}
on
}
# Sum comparisons against each value (except one) in the range
def count_by_sum{T, V, U, xv, b, tab, r0, j0, m} = {
total := trunc{usz, r0} # To compute last count
@ -87,7 +106,7 @@ def count_by_sum{T, V, U, xv, b, tab, r0, j0, m} = {
m4 := m / 4
@for (j4 to m4) count_each{j0 + 4*j4, 4}
@for (j from 4*m4 to m) count_each{j0 + j, 1}
inc{tab, trunc{T, j0 + m}, trunc{usz,total}}
inc{tab, trunc{T, j0 + m}, total}
}
# Count adjacent equal elements at once, breaking at w-element groups
@ -137,7 +156,7 @@ def inc_marked_runs{x, tab:*T, m, m0} = {
jp:T = - T~~1
while (m > m0) @unroll (2) {
j := trunc{T, ctz{m}}
inc{tab, load{x, j}, cast_i{T, j - jp}}
inc{tab, load{x, j}, j - jp}
jp = j; m &= m-1
}
# One step if popc{m} was odd, reducing branch mispredictions above