Merge scan-based sparse Indices/Replicate code with macros

This commit is contained in:
Marshall Lochbaum 2022-09-22 22:09:56 -04:00
parent d647978c8f
commit 7f6cf06eea

View File

@ -448,6 +448,30 @@ static B compress(B w, B x, usz wia, u8 xl, u8 xt) {
return r;
}
#define SCAN_CORE(WV, UPD, SET, SCAN) \
usz b = 1<<10; \
for (usz k=0, j=0, ij=WV; ; ) { \
usz e = b<s-k? k+b : s; \
SET; for (usz i=k; i<e; i++) rp[i]=0; \
while (ij<e) { j++; UPD; ij+=WV; } \
SCAN; \
if (e==s) {break;} k=e; \
}
#define SUM_CORE(T, WV, PREP, INC) \
SCAN_CORE(WV, PREP; rp[ij]+=INC, , PLUS_SCAN(T))
#if SINGELI
#define IND_BY_SCAN \
SCAN_CORE(xp[j], rp[ij]=j, rp[k]=j, avx2_scan_max32(rp+k,rp+k,e-k))
#else
#define IND_BY_SCAN usz js=0; SUM_CORE(i32, xp[j], , 1)
#endif
#define REP_BY_SCAN(T, WV) \
T* xp = xv; T* rp = rv; \
T js=xp[0], px=js; \
SUM_CORE(T, WV, T sx=px, (px=xp[j])-sx)
extern B rt_slash;
B slash_c1(B t, B x) {
if (RARE(isAtm(x)) || RARE(RNK(x)!=1)) thrF("/: Argument must have rank 1 (%H ≡ ≢𝕩)", x);
@ -477,30 +501,7 @@ B slash_c1(B t, B x) {
}
} else {
if (s/32 <= xia) { // Sparse case: type of x matters
#if SINGELI
#define SPARSE_IND(T) \
T* xp = T##any_ptr(x); \
usz b = 1<<10; \
for (usz k=0, j=0, ij=xp[0]; ; ) { \
usz e = b<s-k? k+b : s; \
for (usz i=k; i<e; i++) rp[i]=0; \
rp[k]=j; \
while (ij<e) { rp[ij]=++j; ij+=xp[j]; } \
avx2_scan_max32(rp+k,rp+k,e-k); \
if (e==s) {break;} k=e; \
}
#else
#define SPARSE_IND(T) \
T* xp = T##any_ptr(x); \
usz b = 1<<10; \
for (usz k=0, j=0, js=0, ij=xp[0]; ; ) { \
usz e = b<s-k? k+b : s; \
for (usz i=k; i<e; i++) rp[i]=0; \
while (ij<e) { rp[ij]++; ij+=xp[++j]; } \
PLUS_SCAN(i32) \
if (e==s) {break;} k=e; \
}
#endif
#define SPARSE_IND(T) T* xp = T##any_ptr(x); IND_BY_SCAN
i32* rp; r = m_i32arrv(&rp, s);
if (xe == el_i8 ) { SPARSE_IND(i8 ); }
else if (xe == el_i16) { SPARSE_IND(i16); }
@ -643,17 +644,7 @@ B slash_c2(B t, B w, B x) {
if (rsh) { Arr* ra=a(r); SPRNK(ra,xr); PSH(ra) = rsh; PIA(ra) = s*arr_csz(x); }
void* xv = tyany_ptr(x);
if (s/64 <= wia) { // Sparse case: use both types
#define CASE(L,XT) case L: { \
XT* xp = xv; XT* rp = rv; \
usz b = 1<<10; \
XT js=xp[0], px=js; \
for (usz k=0, j=0, ij=wp[0]; ; ) { \
usz e = b<s-k? k+b : s; \
for (usz i=k; i<e; i++) rp[i]=0; \
while (ij<e) { j++; XT sx=px; rp[ij]+=(px=xp[j])-sx; ij+=wp[j]; } \
PLUS_SCAN(XT) \
if (e==s) {break;} k=e; \
} break; }
#define CASE(L,XT) case L: { REP_BY_SCAN(XT, wp[j]) break; }
#define SPARSE_REP(WT) \
WT* wp = WT##any_ptr(w); \
switch (xk) { default: UD; CASE(0,u8) CASE(1,u16) CASE(2,u32) CASE(3,u64) }
@ -706,24 +697,9 @@ B slash_c2(B t, B w, B x) {
Arr* ra=a(r); SPRNK(ra,xr); PSH(ra)=rsh; PIA(ra)=s*arr_csz(x);
}
void* xv = tyany_ptr(x);
#define CONST_REP(T) { \
T* xp = xv; T* rp = rv; \
usz b = 1<<10; \
T js=xp[0], px=js; \
for (usz k=0, j=0, ij=wv; ; ) { \
usz e = b<s-k? k+b : s; \
for (usz i=k; i<e; i++) rp[i]=0; \
while (ij<e) { j++; T sx=px; rp[ij]+=(px=xp[j])-sx; ij+=wv; } \
PLUS_SCAN(T) \
if (e==s) {break;} k=e; \
} goto decX_ret; }
switch (xk) { default: UD;
case 0: CONST_REP(u8 )
case 1: CONST_REP(u16)
case 2: CONST_REP(u32)
case 3: CONST_REP(u64)
}
#undef CONST_REP
#define CASE(L,T) case L: { REP_BY_SCAN(T, wv) goto decX_ret; }
switch (xk) { default: UD; CASE(0,u8) CASE(1,u16) CASE(2,u32) CASE(3,u64) }
#undef CASE
}
base:
return c2(rt_slash, w, x);