Use AVX2 counting for 1-byte counting sort
This commit is contained in:
parent
0e5b98c491
commit
9d7d330a03
@ -81,6 +81,9 @@ extern void (*const avx2_scan_min_i16)(int16_t* v0,int16_t* v1,uint64_t v2);
|
||||
if (e==n) {break;} k=e; \
|
||||
}
|
||||
#define WRITE_SPARSE(T) WRITE_SPARSE_##T
|
||||
extern i8 (*const avx2_count_i8)(usz*, i8*, u64, i8);
|
||||
#define SINGELI_COUNT_OR(T) \
|
||||
if (1==sizeof(T)) avx2_count_i8(c0o, (i8*)xp, n, -128); else
|
||||
#else
|
||||
#define COUNT_THRESHOLD 16
|
||||
#define WRITE_SPARSE(T) \
|
||||
@ -88,13 +91,14 @@ extern void (*const avx2_scan_min_i16)(int16_t* v0,int16_t* v1,uint64_t v2);
|
||||
usz js = j; \
|
||||
while (ij<n) { rp[ij]GRADE_UD(++,--); ij+=c0o[GRADE_UD(++j,--j)]; } \
|
||||
for (usz i=0; i<n; i++) js=rp[i]+=js;
|
||||
#define SINGELI_COUNT_OR(T)
|
||||
#endif
|
||||
|
||||
#define COUNTING_SORT(T) \
|
||||
usz C=1<<(8*sizeof(T)); \
|
||||
TALLOC(usz, c0, C); usz *c0o=c0+C/2; \
|
||||
for (usz j=0; j<C; j++) c0[j]=0; \
|
||||
for (usz i=0; i<n; i++) c0o[xp[i]]++; \
|
||||
SINGELI_COUNT_OR(T) for (usz i=0; i<n; i++) c0o[xp[i]]++; \
|
||||
if (n/(COUNT_THRESHOLD*sizeof(T)) <= C) { /* Scan-based */ \
|
||||
T j=GRADE_UD(-C/2,C/2-1); \
|
||||
usz ij; while ((ij=c0o[j])==0) GRADE_UD(j++,j--); \
|
||||
@ -227,6 +231,7 @@ B SORT_C1(B t, B x) {
|
||||
#undef SORT_C1
|
||||
#undef INSERTION_SORT
|
||||
#undef COUNTING_SORT
|
||||
#undef SINGELI_COUNT_OR
|
||||
#if SINGELI_AVX2
|
||||
#undef WRITE_SPARSE_i8
|
||||
#undef WRITE_SPARSE_i16
|
||||
|
||||
@ -922,7 +922,7 @@ B slash_im(B t, B x) {
|
||||
#define SINGELI_COUNT_OR(N) if (N==8) { \
|
||||
TALLOC(usz, t, m/2); \
|
||||
for (usz j=0; j<m/2; j++) t[j]=0; \
|
||||
i8 max = avx2_count_i8(t, (i8*)xp, xia); \
|
||||
i8 max = avx2_count_i8(t, (i8*)xp, xia, 0); \
|
||||
if (max < 0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||
usz ria=max+1; \
|
||||
i32* rp; r = m_i32arrv(&rp, ria); for (usz i=0; i<ria; i++) rp[i]=t[i]; \
|
||||
|
||||
@ -14,7 +14,7 @@ def inc{ptr, ind, v} = store{ptr, ind, v + load{ptr, ind}}
|
||||
def inc{ptr, ind} = inc{ptr, ind, 1}
|
||||
|
||||
# Write counts /⁼x to tab and return ⌈´x
|
||||
fn count{T}(tab:*usz, x:*T, n:u64) : T = {
|
||||
fn count{T}(tab:*usz, x:*T, n:u64, min_allowed:T) : T = {
|
||||
def vbits = 256
|
||||
def vec = vbits/width{T}
|
||||
def uT = ty_u{T}
|
||||
@ -22,7 +22,7 @@ fn count{T}(tab:*usz, x:*T, n:u64) : T = {
|
||||
def block = (2048*8) / vbits # Target vectors per block
|
||||
def b_max = block + block/4 # Last block max length
|
||||
assert{b_max < 1<<width{T}} # Don't overflow count in vector section
|
||||
mx:T = -1 # Maximum of x
|
||||
mx:T = min_allowed # Maximum of x
|
||||
i:u64 = 0
|
||||
while (i < n) {
|
||||
# Number of elements to handle in this iteration
|
||||
@ -35,22 +35,22 @@ fn count{T}(tab:*usz, x:*T, n:u64) : T = {
|
||||
xv := *V~~x
|
||||
jv := load{xv}; mv := jv
|
||||
@for (xv over _ from 1 to b) { jv = min{jv, xv}; mv = max{mv, xv} }
|
||||
@for (x over _ from rv to r) { if (x<0) return{x}; if (x>mx) mx=x }
|
||||
@for (x over _ from rv to r) { if (x<min_allowed) return{x}; if (x>mx) mx=x }
|
||||
jt := fold{min, jv}
|
||||
mt := fold{max, mv}
|
||||
if (jt < 0) return{jt}
|
||||
if (jt < min_allowed) return{jt}
|
||||
if (mt > mx) mx = mt
|
||||
|
||||
nc := mt - jt # Number of counts to perform: last is implicit
|
||||
nc := uT~~(mt - jt) # Number of counts to perform: last is implicit
|
||||
if (nc <= 48) {
|
||||
r0 = rv
|
||||
j0 := promote{u64, uT~~jt} # Starting count
|
||||
m := promote{u64, uT~~nc} # Number of iterations
|
||||
m := promote{u64, nc} # Number of iterations
|
||||
total := trunc{usz, r0} # To compute last count
|
||||
def count_each{js, num} = {
|
||||
j := @collect (k to num) js+k
|
||||
j := @collect (k to num) trunc{T, js+k}
|
||||
c := copy{tuplen{j}, [vec]uT ** 0}
|
||||
e := each{{j}=>V**trunc{T, j}, j}
|
||||
e := each{{j}=>V**j, j}
|
||||
@for (xv over b) each{{c,e} => c -= xv == e, c, e}
|
||||
def add_sum{c, j} = {
|
||||
s := promote{usz, sum_vec{V}(V~~c)}
|
||||
@ -61,7 +61,7 @@ fn count{T}(tab:*usz, x:*T, n:u64) : T = {
|
||||
m4 := m / 4
|
||||
@for (j4 to m4) count_each{j0 + 4*j4, 4}
|
||||
@for (j from 4*m4 to m) count_each{j0 + j, 1}
|
||||
inc{tab, j0 + m, trunc{usz,total}}
|
||||
inc{tab, trunc{T, j0 + m}, trunc{usz,total}}
|
||||
}
|
||||
|
||||
# Scalar fallback and cleanup
|
||||
|
||||
Loading…
Reference in New Issue
Block a user