Switch to forward inclusive sums for all radix sorting, and use Singeli

Removes SWAR for 8-bit counts, since reverse sorting uses negatives
This commit is contained in:
Marshall Lochbaum 2022-09-22 15:40:01 -04:00
parent 8e1d8bb42c
commit ec12dd4502

View File

@ -74,33 +74,28 @@ extern void (*const avx2_scan_min16)(int16_t* v0,int16_t* v1,uint64_t v2);
TFREE(c0)
// Radix sorting
#define PRE(T,K) usz p##K=s##K; s##K+=c##K[j]; c##K[j]=p##K
// 8-bit prefix sum by SWAR
#define PRE_UD(K,SL,SRE) \
u64 p##K=s##K; s##K+=((u64*)c##K)[j]; \
s##K+=s##K SL 8; s##K+=s##K SL 16; s##K+=s##K SL 32; \
((u64*)c##K)[j] = p##K|(s##K SL 8); s##K SRE 56
#define PRE64(K) GRADE_UD(PRE_UD(K,<<,>>=), PRE_UD(K,>>,<<=))
#define INC(P,I) GRADE_UD((P+1)[I]++,P[I]--)
#define ROFF GRADE_UD(1,0) // Radix offset
#define CHOOSE_SG_SORT(S,G) S
#define CHOOSE_SG_GRADE(S,G) G
#define RADIX_SORT_i8(T, TYP) \
TALLOC(T, c0, 256); T *c0o=c0+128; \
TALLOC(T, c0, 256+ROFF); T* c0o=c0+128; \
for (usz j=0; j<256; j++) c0[j]=0; \
for (usz i=0; i<n; i++) c0o[xp[i]]++; \
RADIX_SUM_1_##T \
GRADE_UD(,c0[0]=n;) \
for (usz i=0; i<n; i++) INC(c0o,xp[i]); \
RADIX_SUM_1_##T; \
for (usz i=0; i<n; i++) { i8 xi=xp[i]; \
rp[c0o[xi]++]=CHOOSE_SG_##TYP(xi,i); } \
TFREE(c0)
#define RADIX_SUM_1_u8 { u64 s0=0; FOR(j, 256/8) { PRE64(0); } }
#define RADIX_SUM_1_usz { usz s0=0; FOR(j,256) { PRE(usz,0); } }
#define RADIX_SORT_i16(T, TYP, I) \
TALLOC(u8, alloc, 2*256*sizeof(T) + n*(2 + CHOOSE_SG_##TYP(0,sizeof(I)))); \
T *c0=(T*)alloc; T *c1=c0+256; T *c1o=c1+128; \
TALLOC(u8, alloc, (2*256+ROFF)*sizeof(T) + n*(2 + CHOOSE_SG_##TYP(0,sizeof(I)))); \
T* c0=(T*)alloc; T* c1=c0+256; T* c1o=c1+128; \
for (usz j=0; j<2*256; j++) c0[j]=0; \
for (usz i=0; i<n; i++) { i16 v=xp[i]; c0[(u8)v]++; c1o[(i8)(v>>8)]++; } \
c1[0]=GRADE_UD(-n,c0[0]=n); \
for (usz i=0; i<n; i++) { i16 v=xp[i]; INC(c0,(u8)v); INC(c1o,(i8)(v>>8)); } \
RADIX_SUM_2_##T; \
i16 *r0 = (i16*)(c0+2*256); \
CHOOSE_SG_##TYP( \
@ -112,18 +107,15 @@ extern void (*const avx2_scan_min16)(int16_t* v0,int16_t* v1,uint64_t v2);
for (usz i=0; i<n; i++) { i16 v=r0[i]; rp[c1o[(i8)(v>>8)]++]=g0[i]; } \
) \
TFREE(alloc)
#define RADIX_SUM_2_u8 u64 s0=0, s1=0; FOR(j,256/8) { PRE64(0); PRE64(1); }
#define RADIX_SUM_2(T) T s0=0, s1=0; FOR(j,256) { PRE(T,0); PRE(T,1); }
#define RADIX_SUM_2_usz RADIX_SUM_2(usz)
#define RADIX_SUM_2_u32 RADIX_SUM_2(u32)
#define RADIX_SORT_i32(T, TYP, I) \
TALLOC(u8, alloc, 4*256*sizeof(T) + n*(4 + CHOOSE_SG_##TYP(0,4+sizeof(I)))); \
TALLOC(u8, alloc, (4*256+ROFF)*sizeof(T) + n*(4 + CHOOSE_SG_##TYP(0,4+sizeof(I)))); \
T *c0=(T*)alloc, *c1=c0+256, *c2=c1+256, *c3=c2+256, *c3o=c3+128; \
for (usz j=0; j<4*256; j++) c0[j]=0; \
c1[0]=c2[0]=c3[0]=GRADE_UD(-n,c0[0]=n); \
for (usz i=0; i<n; i++) { i32 v=xp[i]; \
c0 [(u8)v ]++; c1 [(u8)(v>> 8)]++; \
c2 [(u8)(v>>16)]++; c3o[(i8)(v>>24)]++; } \
INC(c0 ,(u8)v ); INC(c1 ,(u8)(v>> 8)); \
INC(c2 ,(u8)(v>>16)); INC(c3o,(i8)(v>>24)); } \
RADIX_SUM_4_##T; \
i32 *r0 = (i32*)(c0+4*256); \
CHOOSE_SG_##TYP( \
@ -139,10 +131,37 @@ extern void (*const avx2_scan_min16)(int16_t* v0,int16_t* v1,uint64_t v2);
for (usz i=0; i<n; i++) { i32 v=r0[i]; T c=c3o[(i8)(v>>24)]++; rp[c]=g0[i]; } \
) \
TFREE(alloc)
#define RADIX_SUM_4_u8 u64 s0=0, s1=0, s2=0, s3=0; FOR(j, 256/8) { PRE64(0); PRE64(1); PRE64(2); PRE64(3); }
#define RADIX_SUM_4(T) T s0=0, s1=0, s2=0, s3=0; FOR(j, 256) { PRE(u32,0); PRE(u32,1); PRE(u32,2); PRE(u32,3); }
#define RADIX_SUM_4_usz RADIX_SUM_4(usz)
#define RADIX_SUM_4_u32 RADIX_SUM_4(u32)
#define PRE(K) s##K=c##K[j]+=s##K
#define RADIX_SUM_1(T) T s0=0; for(usz j=0;j<256;j++) { PRE(0); }
#define RADIX_SUM_2(T) GRADE_UD(c1[0]=0;,) T s0=0, s1=0; for(usz j=0;j<256;j++) { PRE(0); PRE(1); }
#define RADIX_SUM_4(T) GRADE_UD(c1[0]=c2[0]=c3[0]=0;,) T s0=0, s1=0, s2=0, s3=0; for(usz j=0;j<256;j++) { PRE(0); PRE(1); PRE(2); PRE(3); }
#if SINGELI
extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
#define RADIX_SUM_1_u8 avx2_scan_pluswrap_u8 (c0,c0, 256,0);
#define RADIX_SUM_2_u8 avx2_scan_pluswrap_u8 (c0,c0,2*256,0);
#define RADIX_SUM_2_u32 avx2_scan_pluswrap_u32(c0,c0,2*256,0);
#define RADIX_SUM_4_u8 avx2_scan_pluswrap_u8 (c0,c0,4*256,0);
#define RADIX_SUM_4_u32 avx2_scan_pluswrap_u32(c0,c0,4*256,0);
#else
#define RADIX_SUM_1_u8 RADIX_SUM_1(u8)
#define RADIX_SUM_2_u8 RADIX_SUM_2(u8)
#define RADIX_SUM_2_u32 RADIX_SUM_2(u32)
#define RADIX_SUM_4_u8 RADIX_SUM_4(u8)
#define RADIX_SUM_4_u32 RADIX_SUM_4(u32)
#endif
#if SINGELI && !USZ_64
#define RADIX_SUM_1_usz avx2_scan_pluswrap_u32(c0,c0, 256,0);
#define RADIX_SUM_2_usz avx2_scan_pluswrap_u32(c0,c0,2*256,0);
#define RADIX_SUM_4_usz avx2_scan_pluswrap_u32(c0,c0,4*256,0);
#else
#define RADIX_SUM_1_usz RADIX_SUM_1(usz)
#define RADIX_SUM_2_usz RADIX_SUM_2(usz)
#define RADIX_SUM_4_usz RADIX_SUM_4(usz)
#endif
#define SORT_C1 CAT(GRADE_UD(and,or),c1)
B SORT_C1(B t, B x) {
@ -166,9 +185,9 @@ B SORT_C1(B t, B x) {
} else if (xe==el_i8) {
i8* xp = i8any_ptr(x);
i8* rp; r = m_i8arrv(&rp, n);
if (n<16) {
if (n < 16) {
INSERTION_SORT(i8);
} else if (n<256) {
} else if (n < 256) {
RADIX_SORT_i8(u8, SORT);
} else {
COUNTING_SORT(i8);
@ -176,7 +195,7 @@ B SORT_C1(B t, B x) {
} else if (xe==el_i16) {
i16* xp = i16any_ptr(x);
i16* rp; r = m_i16arrv(&rp, n);
if (n < 24) {
if (n < 20) {
INSERTION_SORT(i16);
} else if (n < 256) {
RADIX_SORT_i16(u8, SORT,);
@ -188,7 +207,7 @@ B SORT_C1(B t, B x) {
} else if (xe==el_i32) {
i32* xp = i32any_ptr(x);
i32* rp; r = m_i32arrv(&rp, n);
if (n < 40) {
if (n < 32) {
INSERTION_SORT(i32);
} else if (n < 256) {
RADIX_SORT_i32(u8, SORT,);
@ -386,21 +405,23 @@ done:
#undef LT
#undef FOR
#undef PRE
#undef PRE_UD
#undef INC
#undef ROFF
#undef PRE64
#undef CHOOSE_SG_SORT
#undef CHOOSE_SG_GRADE
#undef RADIX_SORT_i8
#undef RADIX_SORT_i16
#undef RADIX_SORT_i32
#undef RADIX_SUM_1
#undef RADIX_SUM_2
#undef RADIX_SUM_4
#undef RADIX_SUM_1_u8
#undef RADIX_SUM_1_usz
#undef RADIX_SORT_i16
#undef RADIX_SUM_2_u8
#undef RADIX_SUM_2
#undef RADIX_SUM_2_usz
#undef RADIX_SUM_2_u32
#undef RADIX_SORT_i32
#undef RADIX_SUM_4_u8
#undef RADIX_SUM_4
#undef RADIX_SUM_4_usz
#undef RADIX_SUM_4_u32
#undef GRADE_CAT