Unify radix sort and radix lookup prefix sums

This commit is contained in:
Marshall Lochbaum 2022-10-22 16:29:05 -04:00
parent 2207d9f1bb
commit b9d5f10d4a
3 changed files with 43 additions and 51 deletions

View File

@ -74,6 +74,7 @@ extern void (*const avx2_scan_min16)(int16_t* v0,int16_t* v1,uint64_t v2);
TFREE(c0)
// Radix sorting
#include "radix.h"
#define INC(P,I) GRADE_UD((P+1)[I]++,P[I]--)
#define ROFF GRADE_UD(1,0) // Radix offset
@ -81,7 +82,7 @@ extern void (*const avx2_scan_min16)(int16_t* v0,int16_t* v1,uint64_t v2);
#define CHOOSE_SG_GRADE(S,G) G
#define RADIX_SORT_i8(T, TYP) \
TALLOC(T, c0, 256+ROFF); T* c0o=c0+128; \
TALLOC(T, c0, 256+ROFF); T* c0o=c0+128; \
for (usz j=0; j<256; j++) c0[j]=0; \
GRADE_UD(,c0[0]=n;) \
for (usz i=0; i<n; i++) INC(c0o,xp[i]); \
@ -132,37 +133,6 @@ extern void (*const avx2_scan_min16)(int16_t* v0,int16_t* v1,uint64_t v2);
) \
TFREE(alloc)
#define PRE(K) s##K=c##K[j]+=s##K
#define RADIX_SUM_1(T) T s0=0; for(usz j=0;j<256;j++) { PRE(0); }
#define RADIX_SUM_2(T) GRADE_UD(c1[0]=0;,) T s0=0, s1=0; for(usz j=0;j<256;j++) { PRE(0); PRE(1); }
#define RADIX_SUM_4(T) GRADE_UD(c1[0]=c2[0]=c3[0]=0;,) T s0=0, s1=0, s2=0, s3=0; for(usz j=0;j<256;j++) { PRE(0); PRE(1); PRE(2); PRE(3); }
#if SINGELI
extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
#define RADIX_SUM_1_u8 avx2_scan_pluswrap_u8 (c0,c0, 256,0);
#define RADIX_SUM_2_u8 avx2_scan_pluswrap_u8 (c0,c0,2*256,0);
#define RADIX_SUM_2_u32 avx2_scan_pluswrap_u32(c0,c0,2*256,0);
#define RADIX_SUM_4_u8 avx2_scan_pluswrap_u8 (c0,c0,4*256,0);
#define RADIX_SUM_4_u32 avx2_scan_pluswrap_u32(c0,c0,4*256,0);
#else
#define RADIX_SUM_1_u8 RADIX_SUM_1(u8)
#define RADIX_SUM_2_u8 RADIX_SUM_2(u8)
#define RADIX_SUM_2_u32 RADIX_SUM_2(u32)
#define RADIX_SUM_4_u8 RADIX_SUM_4(u8)
#define RADIX_SUM_4_u32 RADIX_SUM_4(u32)
#endif
#if SINGELI && !USZ_64
#define RADIX_SUM_1_usz avx2_scan_pluswrap_u32(c0,c0, 256,0);
#define RADIX_SUM_2_usz avx2_scan_pluswrap_u32(c0,c0,2*256,0);
#define RADIX_SUM_4_usz avx2_scan_pluswrap_u32(c0,c0,4*256,0);
#else
#define RADIX_SUM_1_usz RADIX_SUM_1(usz)
#define RADIX_SUM_2_usz RADIX_SUM_2(usz)
#define RADIX_SUM_4_usz RADIX_SUM_4(usz)
#endif
#define SORT_C1 CAT(GRADE_UD(and,or),c1)
B SORT_C1(B t, B x) {
if (isAtm(x) || RNK(x)==0) thrM(GRADE_UD("","")": Argument cannot have rank 0");
@ -413,17 +383,6 @@ done:
#undef RADIX_SORT_i8
#undef RADIX_SORT_i16
#undef RADIX_SORT_i32
#undef RADIX_SUM_1
#undef RADIX_SUM_2
#undef RADIX_SUM_4
#undef RADIX_SUM_1_u8
#undef RADIX_SUM_1_usz
#undef RADIX_SUM_2_u8
#undef RADIX_SUM_2_usz
#undef RADIX_SUM_2_u32
#undef RADIX_SUM_4_u8
#undef RADIX_SUM_4_usz
#undef RADIX_SUM_4_u32
#undef GRADE_CAT
#undef GRADE_NEG
#undef GRADE_UD

36
src/builtins/radix.h Normal file
View File

@ -0,0 +1,36 @@
#pragma once
// Radix sorting utilities
// These are leaky macros and assume counts are c0, c1,...
// which must be adjacent in memory
#define RDX_PRE(K) s##K=c##K[j]+=s##K
#define RDX_SUM_1(T) T s0=0; for(usz j=0;j<256;j++) { RDX_PRE(0); }
#define RDX_SUM_2(T) GRADE_UD(c1[0]=0;,) T s0=0, s1=0; for(usz j=0;j<256;j++) { RDX_PRE(0); RDX_PRE(1); }
#define RDX_SUM_4(T) GRADE_UD(c1[0]=c2[0]=c3[0]=0;,) T s0=0, s1=0, s2=0, s3=0; for(usz j=0;j<256;j++) { RDX_PRE(0); RDX_PRE(1); RDX_PRE(2); RDX_PRE(3); }
#if SINGELI
extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
#define RADIX_SUM_1_u8 avx2_scan_pluswrap_u8 (c0,c0, 256,0);
#define RADIX_SUM_2_u8 avx2_scan_pluswrap_u8 (c0,c0,2*256,0);
#define RADIX_SUM_2_u32 avx2_scan_pluswrap_u32(c0,c0,2*256,0);
#define RADIX_SUM_4_u8 avx2_scan_pluswrap_u8 (c0,c0,4*256,0);
#define RADIX_SUM_4_u32 avx2_scan_pluswrap_u32(c0,c0,4*256,0);
#else
#define RADIX_SUM_1_u8 RDX_SUM_1(u8)
#define RADIX_SUM_2_u8 RDX_SUM_2(u8)
#define RADIX_SUM_2_u32 RDX_SUM_2(u32)
#define RADIX_SUM_4_u8 RDX_SUM_4(u8)
#define RADIX_SUM_4_u32 RDX_SUM_4(u32)
#endif
#if SINGELI && !USZ_64
#define RADIX_SUM_1_usz avx2_scan_pluswrap_u32(c0,c0, 256,0);
#define RADIX_SUM_2_usz avx2_scan_pluswrap_u32(c0,c0,2*256,0);
#define RADIX_SUM_4_usz avx2_scan_pluswrap_u32(c0,c0,4*256,0);
#else
#define RADIX_SUM_1_usz RDX_SUM_1(usz)
#define RADIX_SUM_2_usz RDX_SUM_2(usz)
#define RADIX_SUM_4_usz RDX_SUM_4(usz)
#endif

View File

@ -6,17 +6,14 @@
B not_c1(B t, B x);
B shape_c1(B t, B x);
#include "radix.h"
#define RADIX_LOOKUP_i32(INIT, SETTAB) \
/* Count keys */ \
for (usz j=0; j<2*rx; j++) c0[j] = 0; \
for (usz i=0; i<n; i++) { u32 v=v0[i]; c0[(u8)(v>>24)]++; c1[(u8)(v>>16)]++; } \
/* Exclusive prefix sum */ \
usz s0=0, s1=0; \
for (usz j=0; j<rx; j++) { \
usz p0 = s0, p1 = s1; \
s0 += c0[j]; s1 += c1[j]; \
c0[j] = p0; c1[j] = p1; \
} \
c1[0] = -n; \
for (usz i=0; i<n; i++) { u32 v=v0[i]; (c0+1)[(u8)(v>>24)]++; (c1+1)[(u8)(v>>16)]++; } \
/* Inclusive prefix sum; note c offsets above */ \
RADIX_SUM_2_u32; \
/* Radix moves */ \
for (usz i=0; i<n; i++) { u32 v=v0[i]; u8 k=k0[i]=(u8)(v>>24); usz c=c0[k]++; v1[c]=v; } \
for (usz i=0; i<n; i++) { u32 v=v1[i]; u8 k=k1[i]=(u8)(v>>16); usz c=c1[k]++; v2[c]=v; } \