Merge pull request #44 from mlochbaum/rep

Replicate, Indices, Reshape, sorting
This commit is contained in:
dzaima 2022-09-24 19:21:22 +03:00 committed by GitHub
commit 4d42e19c27
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 391 additions and 198 deletions

View File

@ -30,51 +30,72 @@
rp[j] = xi; \
}
#if SINGELI
extern void (*const avx2_scan_max8)(int8_t* v0,int8_t* v1,uint64_t v2);
extern void (*const avx2_scan_min8)(int8_t* v0,int8_t* v1,uint64_t v2);
extern void (*const avx2_scan_max16)(int16_t* v0,int16_t* v1,uint64_t v2);
extern void (*const avx2_scan_min16)(int16_t* v0,int16_t* v1,uint64_t v2);
#define COUNT_THRESHOLD 32
#define WRITE_SPARSE_i8 \
for (usz i=0; i<n; i++) rp[i]=j; \
while (ij<n) { rp[ij]=GRADE_UD(++j,--j); ij+=c0o[j]; } \
GRADE_UD(avx2_scan_max8,avx2_scan_min8)(rp,rp,n);
#define WRITE_SPARSE_i16 \
usz b = 1<<10; \
for (usz k=0; ; ) { \
usz e = b<n-k? k+b : n; \
for (usz i=k; i<e; i++) rp[i]=j; \
while (ij<e) { rp[ij]=GRADE_UD(++j,--j); ij+=c0o[j]; } \
GRADE_UD(avx2_scan_max16,avx2_scan_min16)(rp+k,rp+k,e-k); \
if (e==n) {break;} k=e; \
}
#define WRITE_SPARSE(T) WRITE_SPARSE_##T
#else
#define COUNT_THRESHOLD 16
#define WRITE_SPARSE(T) \
for (usz i=0; i<n; i++) rp[i]=0; \
usz js = j; \
while (ij<n) { rp[ij]GRADE_UD(++,--); ij+=c0o[GRADE_UD(++j,--j)]; } \
for (usz i=0; i<n; i++) js=rp[i]+=js;
#endif
#define COUNTING_SORT(T) \
usz C=1<<(8*sizeof(T)); \
TALLOC(usz, c0, C); usz *c0o=c0+C/2; \
for (usz j=0; j<C; j++) c0[j]=0; \
for (usz i=0; i<n; i++) c0o[xp[i]]++; \
if (n/16 <= C) { /* Sum-based */ \
for (usz i=0; i<n; i++) rp[i]=0; \
usz j=GRADE_UD(0,C-1), i; \
while ((i=c0[j])==0) GRADE_UD(j++,j--); \
usz js = j - C/2; \
while (i<n) { rp[i]++; i+=c0[GRADE_UD(++j,--j)]; } \
for (usz i=0; i<n; i++) js=rp[i]+=js; \
} else { /* Branchy */ \
FOR(j,C) for (usz c=c0[j]; c--; ) *rp++ = j-C/2; \
} \
usz C=1<<(8*sizeof(T)); \
TALLOC(usz, c0, C); usz *c0o=c0+C/2; \
for (usz j=0; j<C; j++) c0[j]=0; \
for (usz i=0; i<n; i++) c0o[xp[i]]++; \
if (n/(COUNT_THRESHOLD*sizeof(T)) <= C) { /* Scan-based */ \
T j=GRADE_UD(-C/2,C/2-1); \
usz ij; while ((ij=c0o[j])==0) GRADE_UD(j++,j--); \
WRITE_SPARSE(T) \
} else { /* Branchy */ \
FOR(j,C) for (usz c=c0[j]; c--; ) *rp++ = j-C/2; \
} \
TFREE(c0)
// Radix sorting
#define PRE(T,K) usz p##K=s##K; s##K+=c##K[j]; c##K[j]=p##K
// 8-bit prefix sum by SWAR
#define PRE_UD(K,SL,SRE) \
u64 p##K=s##K; s##K+=((u64*)c##K)[j]; \
s##K+=s##K SL 8; s##K+=s##K SL 16; s##K+=s##K SL 32; \
((u64*)c##K)[j] = p##K|(s##K SL 8); s##K SRE 56
#define PRE64(K) GRADE_UD(PRE_UD(K,<<,>>=), PRE_UD(K,>>,<<=))
#define INC(P,I) GRADE_UD((P+1)[I]++,P[I]--)
#define ROFF GRADE_UD(1,0) // Radix offset
#define CHOOSE_SG_SORT(S,G) S
#define CHOOSE_SG_GRADE(S,G) G
#define RADIX_SORT_i8(T, TYP) \
TALLOC(T, c0, 256); T *c0o=c0+128; \
TALLOC(T, c0, 256+ROFF); T* c0o=c0+128; \
for (usz j=0; j<256; j++) c0[j]=0; \
for (usz i=0; i<n; i++) c0o[xp[i]]++; \
RADIX_SUM_1_##T \
GRADE_UD(,c0[0]=n;) \
for (usz i=0; i<n; i++) INC(c0o,xp[i]); \
RADIX_SUM_1_##T; \
for (usz i=0; i<n; i++) { i8 xi=xp[i]; \
rp[c0o[xi]++]=CHOOSE_SG_##TYP(xi,i); } \
TFREE(c0)
#define RADIX_SUM_1_u8 { u64 s0=0; FOR(j, 256/8) { PRE64(0); } }
#define RADIX_SUM_1_usz { usz s0=0; FOR(j,256) { PRE(usz,0); } }
#define RADIX_SORT_i16(T, TYP, I) \
TALLOC(u8, alloc, 2*256*sizeof(T) + n*(2 + CHOOSE_SG_##TYP(0,sizeof(I)))); \
T *c0=(T*)alloc; T *c1=c0+256; T *c1o=c1+128; \
TALLOC(u8, alloc, (2*256+ROFF)*sizeof(T) + n*(2 + CHOOSE_SG_##TYP(0,sizeof(I)))); \
T* c0=(T*)alloc; T* c1=c0+256; T* c1o=c1+128; \
for (usz j=0; j<2*256; j++) c0[j]=0; \
for (usz i=0; i<n; i++) { i16 v=xp[i]; c0[(u8)v]++; c1o[(i8)(v>>8)]++; } \
c1[0]=GRADE_UD(-n,c0[0]=n); \
for (usz i=0; i<n; i++) { i16 v=xp[i]; INC(c0,(u8)v); INC(c1o,(i8)(v>>8)); } \
RADIX_SUM_2_##T; \
i16 *r0 = (i16*)(c0+2*256); \
CHOOSE_SG_##TYP( \
@ -86,18 +107,15 @@
for (usz i=0; i<n; i++) { i16 v=r0[i]; rp[c1o[(i8)(v>>8)]++]=g0[i]; } \
) \
TFREE(alloc)
#define RADIX_SUM_2_u8 u64 s0=0, s1=0; FOR(j,256/8) { PRE64(0); PRE64(1); }
#define RADIX_SUM_2(T) T s0=0, s1=0; FOR(j,256) { PRE(T,0); PRE(T,1); }
#define RADIX_SUM_2_usz RADIX_SUM_2(usz)
#define RADIX_SUM_2_u32 RADIX_SUM_2(u32)
#define RADIX_SORT_i32(T, TYP, I) \
TALLOC(u8, alloc, 4*256*sizeof(T) + n*(4 + CHOOSE_SG_##TYP(0,4+sizeof(I)))); \
TALLOC(u8, alloc, (4*256+ROFF)*sizeof(T) + n*(4 + CHOOSE_SG_##TYP(0,4+sizeof(I)))); \
T *c0=(T*)alloc, *c1=c0+256, *c2=c1+256, *c3=c2+256, *c3o=c3+128; \
for (usz j=0; j<4*256; j++) c0[j]=0; \
c1[0]=c2[0]=c3[0]=GRADE_UD(-n,c0[0]=n); \
for (usz i=0; i<n; i++) { i32 v=xp[i]; \
c0 [(u8)v ]++; c1 [(u8)(v>> 8)]++; \
c2 [(u8)(v>>16)]++; c3o[(i8)(v>>24)]++; } \
INC(c0 ,(u8)v ); INC(c1 ,(u8)(v>> 8)); \
INC(c2 ,(u8)(v>>16)); INC(c3o,(i8)(v>>24)); } \
RADIX_SUM_4_##T; \
i32 *r0 = (i32*)(c0+4*256); \
CHOOSE_SG_##TYP( \
@ -113,10 +131,37 @@
for (usz i=0; i<n; i++) { i32 v=r0[i]; T c=c3o[(i8)(v>>24)]++; rp[c]=g0[i]; } \
) \
TFREE(alloc)
#define RADIX_SUM_4_u8 u64 s0=0, s1=0, s2=0, s3=0; FOR(j, 256/8) { PRE64(0); PRE64(1); PRE64(2); PRE64(3); }
#define RADIX_SUM_4(T) T s0=0, s1=0, s2=0, s3=0; FOR(j, 256) { PRE(u32,0); PRE(u32,1); PRE(u32,2); PRE(u32,3); }
#define RADIX_SUM_4_usz RADIX_SUM_4(usz)
#define RADIX_SUM_4_u32 RADIX_SUM_4(u32)
#define PRE(K) s##K=c##K[j]+=s##K
#define RADIX_SUM_1(T) T s0=0; for(usz j=0;j<256;j++) { PRE(0); }
#define RADIX_SUM_2(T) GRADE_UD(c1[0]=0;,) T s0=0, s1=0; for(usz j=0;j<256;j++) { PRE(0); PRE(1); }
#define RADIX_SUM_4(T) GRADE_UD(c1[0]=c2[0]=c3[0]=0;,) T s0=0, s1=0, s2=0, s3=0; for(usz j=0;j<256;j++) { PRE(0); PRE(1); PRE(2); PRE(3); }
#if SINGELI
extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
#define RADIX_SUM_1_u8 avx2_scan_pluswrap_u8 (c0,c0, 256,0);
#define RADIX_SUM_2_u8 avx2_scan_pluswrap_u8 (c0,c0,2*256,0);
#define RADIX_SUM_2_u32 avx2_scan_pluswrap_u32(c0,c0,2*256,0);
#define RADIX_SUM_4_u8 avx2_scan_pluswrap_u8 (c0,c0,4*256,0);
#define RADIX_SUM_4_u32 avx2_scan_pluswrap_u32(c0,c0,4*256,0);
#else
#define RADIX_SUM_1_u8 RADIX_SUM_1(u8)
#define RADIX_SUM_2_u8 RADIX_SUM_2(u8)
#define RADIX_SUM_2_u32 RADIX_SUM_2(u32)
#define RADIX_SUM_4_u8 RADIX_SUM_4(u8)
#define RADIX_SUM_4_u32 RADIX_SUM_4(u32)
#endif
#if SINGELI && !USZ_64
#define RADIX_SUM_1_usz avx2_scan_pluswrap_u32(c0,c0, 256,0);
#define RADIX_SUM_2_usz avx2_scan_pluswrap_u32(c0,c0,2*256,0);
#define RADIX_SUM_4_usz avx2_scan_pluswrap_u32(c0,c0,4*256,0);
#else
#define RADIX_SUM_1_usz RADIX_SUM_1(usz)
#define RADIX_SUM_2_usz RADIX_SUM_2(usz)
#define RADIX_SUM_4_usz RADIX_SUM_4(usz)
#endif
#define SORT_C1 CAT(GRADE_UD(and,or),c1)
B SORT_C1(B t, B x) {
@ -140,9 +185,9 @@ B SORT_C1(B t, B x) {
} else if (xe==el_i8) {
i8* xp = i8any_ptr(x);
i8* rp; r = m_i8arrv(&rp, n);
if (n<16) {
if (n < 16) {
INSERTION_SORT(i8);
} else if (n<256) {
} else if (n < 256) {
RADIX_SORT_i8(u8, SORT);
} else {
COUNTING_SORT(i8);
@ -150,7 +195,7 @@ B SORT_C1(B t, B x) {
} else if (xe==el_i16) {
i16* xp = i16any_ptr(x);
i16* rp; r = m_i16arrv(&rp, n);
if (n < 24) {
if (n < 20) {
INSERTION_SORT(i16);
} else if (n < 256) {
RADIX_SORT_i16(u8, SORT,);
@ -162,7 +207,7 @@ B SORT_C1(B t, B x) {
} else if (xe==el_i32) {
i32* xp = i32any_ptr(x);
i32* rp; r = m_i32arrv(&rp, n);
if (n < 40) {
if (n < 32) {
INSERTION_SORT(i32);
} else if (n < 256) {
RADIX_SORT_i32(u8, SORT,);
@ -181,6 +226,10 @@ B SORT_C1(B t, B x) {
#undef SORT_C1
#undef INSERTION_SORT
#undef COUNTING_SORT
#if SINGELI
#undef WRITE_SPARSE_i8
#undef WRITE_SPARSE_i16
#endif
#define GRADE_CHR GRADE_UD("⍋","⍒")
@ -356,21 +405,23 @@ done:
#undef LT
#undef FOR
#undef PRE
#undef PRE_UD
#undef INC
#undef ROFF
#undef PRE64
#undef CHOOSE_SG_SORT
#undef CHOOSE_SG_GRADE
#undef RADIX_SORT_i8
#undef RADIX_SORT_i16
#undef RADIX_SORT_i32
#undef RADIX_SUM_1
#undef RADIX_SUM_2
#undef RADIX_SUM_4
#undef RADIX_SUM_1_u8
#undef RADIX_SUM_1_usz
#undef RADIX_SORT_i16
#undef RADIX_SUM_2_u8
#undef RADIX_SUM_2
#undef RADIX_SUM_2_usz
#undef RADIX_SUM_2_u32
#undef RADIX_SORT_i32
#undef RADIX_SUM_4_u8
#undef RADIX_SUM_4
#undef RADIX_SUM_4_usz
#undef RADIX_SUM_4_u32
#undef GRADE_CAT

View File

@ -104,6 +104,13 @@ static B truncReshape(B x, usz xia, usz nia, ur nr, ShArr* sh) { // consumes all
arr_shSetU(ra, nr, sh);
return r;
}
static void fill_words(void* rp, u64 v, u64 bytes) {
usz wds = bytes/8;
usz ext = bytes%8;
u64* p = rp;
for (usz i=0; i<wds; i++) p[i] = v;
if (ext) memcpy(p+wds, &v, ext);
}
B shape_c2(B t, B w, B x) {
usz xia = isArr(x)? IA(x) : 1;
usz nia = 1;
@ -181,63 +188,121 @@ B shape_c2(B t, B w, B x) {
decG(w);
}
B xf;
if (isAtm(x)) {
xf = asFill(inc(x));
// goes to unit
} else {
Arr* r;
if (isArr(x)) {
if (nia <= xia) {
return truncReshape(x, xia, nia, nr, sh);
} else {
xf = getFillQ(x);
if (xia<=1) {
if (RARE(xia==0)) {
thrM("⥊: Empty 𝕩 and non-empty result");
// if (noFill(xf)) thrM("⥊: No fill for empty array");
// dec(x);
// x = inc(xf);
} else {
B n = IGet(x,0);
decG(x);
x = n;
}
if (xia <= 1) {
if (RARE(xia == 0)) thrM("⥊: Empty 𝕩 and non-empty result");
B n = IGet(x,0);
decG(x);
x = n;
goto unit;
}
if (xia <= nia/2) x = any_squeeze(x);
MAKE_MUT(m, nia); mut_init(m, TI(x,elType));
MUTG_INIT(m);
i64 div = nia/xia;
i64 mod = nia%xia;
for (i64 i = 0; i < div; i++) mut_copyG(m, i*xia, x, 0, xia);
mut_copyG(m, div*xia, x, 0, mod);
u8 xl = arrTypeBitsLog(TY(x));
u8 xt = arrNewType(TY(x));
u8* rp;
u64 bi, bf; // Bytes present, bytes wanted
if (xl == 0) { // Bits
u64* rq; r = m_bitarrp(&rq, nia);
rp = (u8*)rq;
usz nw = BIT_N(nia);
u64* xp = bitarr_ptr(x);
u64 b = xia;
if (b % 8) {
if (b < 64) {
// Need to avoid calling bit_cpy with arguments <64 bits apart
u64 v = xp[0] & (~(u64)0 >> (64-b));
do { v |= v<<b; b*=2; } while (b%8 && b<64);
rq[0] = v;
if (b>64 && nia>64) rq[1] = v>>(64-b/2);
} else {
memcpy(rq, xp, (b+7)/8);
}
for (; b%8; b*=2) {
if (b>nw*32) {
if (b<nia) bit_cpy(rq, b, rq, 0, nia-b);
b = 64*nw; // Ensure bi>=bf since bf is rounded up
break;
}
bit_cpy(rq, b, rq, 0, b);
}
} else {
memcpy(rp, xp, b/8);
}
bi = b/8;
bf = 8*nw;
if (bi == 1) { memset(rp, rp[0], bf); bi=bf; }
} else {
if (TI(x,elType) == el_B) {
B xf = getFillQ(x);
MAKE_MUT(m, nia); mut_init(m, el_B);
MUTG_INIT(m);
i64 div = nia/xia;
i64 mod = nia%xia;
for (i64 i = 0; i < div; i++) mut_copyG(m, i*xia, x, 0, xia);
mut_copyG(m, div*xia, x, 0, mod);
decG(x);
Arr* ra = mut_fp(m);
arr_shSetU(ra, nr, sh);
return withFill(taga(ra), xf);
}
u8 xk = xl - 3;
rp = m_tyarrp(&r, 1<<xk, nia, xt);
bi = (u64)xia<<xk;
bf = (u64)nia<<xk;
memcpy(rp, tyany_ptr(x), bi);
}
decG(x);
Arr* ra = mut_fp(m);
arr_shSetU(ra, nr, sh);
return withFill(taga(ra), xf);
if (bi<=8 && !(bi & (bi-1))) {
// Divisor of 8: write words
usz b = bi*8;
u64 v = *(u64*)rp & (~(u64)0 >> (64-b));
while (b<64) { v |= v<<b; b*=2; }
fill_words(rp, v, bf);
} else {
// Double up to length l, then copy in blocks
u64 l = 1<<15; if (l>bf) l=bf;
for (; bi<=l/2; bi+=bi) memcpy(rp+bi, rp, bi);
u64 e=bi; for (; e+bi<=bf; e+=bi) memcpy(rp+e, rp, bi);
if (e<bf) memcpy(rp+e, rp, bf-e);
}
}
} else {
unit:
#define FILL(E,T,V) T* rp; r = m_##E##arrp(&rp,nia); fill_words(rp, V, (u64)nia*sizeof(T));
if (isF64(x)) {
i32 n = (i32)x.f;
if (RARE(n!=x.f)) {
FILL(f64,f64,x.u)
} else if (n==(i8)n) { // memset can be faster than writing words
u8 b = n;
i8* rp; u64 nb = nia;
if (b <= 1) { r = m_bitarrp((u64**)&rp,nia); nb = 8*BIT_N(nia); b=-b; }
else { r = m_i8arrp ( &rp,nia); }
memset(rp, b, nb);
} else {
if(n==(i16)n) { FILL(i16,i16,(u16)n*0x0001000100010001) }
else { FILL(i32,i32,(u32)n*0x0000000100000001) }
}
} else if (isC32(x)) {
u32 c = o2cG(x);
if (c==(u8 )c) { u8* rp; r = m_c8arrp(&rp,nia); memset(rp, c, nia); }
else if (c==(u16)c) { FILL(c16,u16,c*0x0001000100010001) }
else { FILL(c32,u32,c*0x0000000100000001) }
} else {
incBy(x, nia); // in addition with the existing reference, this covers the filled amount & asFill
r = m_fillarrp(nia);
if (sizeof(B)==8) fill_words(fillarr_ptr(r), x.u, (u64)nia*8);
else for (usz i = 0; i < nia; i++) fillarr_ptr(r)[i] = x;
fillarr_setFill(r, asFill(x));
}
#undef FILL
}
unit:
if (isF64(x)) { decA(xf);
i32 n = (i32)x.f;
if (RARE(n!=x.f)) { f64* rp; Arr* r = m_f64arrp(&rp,nia); arr_shSetU(r,nr,sh); for (u64 i=0; i<nia; i++) rp[i]=x.f; return taga(r); }
else if(n==(n&1)) { Arr* r=n?allOnes(nia):allZeroes(nia); arr_shSetU(r,nr,sh); return taga(r); }
else if(n==(i8 )n) { i8* rp; Arr* r = m_i8arrp (&rp,nia); arr_shSetU(r,nr,sh); for (u64 i=0; i<nia; i++) rp[i]=n ; return taga(r); }
else if(n==(i16)n) { i16* rp; Arr* r = m_i16arrp(&rp,nia); arr_shSetU(r,nr,sh); for (u64 i=0; i<nia; i++) rp[i]=n ; return taga(r); }
else { i32* rp; Arr* r = m_i32arrp(&rp,nia); arr_shSetU(r,nr,sh); for (u64 i=0; i<nia; i++) rp[i]=n ; return taga(r); }
}
if (isC32(x)) { decA(xf);
u32* rp; Arr* r = m_c32arrp(&rp, nia); arr_shSetU(r, nr, sh);
u32 c = o2cG(x);
for (u64 i = 0; i < nia; i++) rp[i] = c;
return taga(r);
}
Arr* r = m_fillarrp(nia); arr_shSetU(r, nr, sh);
B* rp = fillarr_ptr(r);
if (nia) incBy(x, nia-1);
else dec(x);
for (u64 i = 0; i < nia; i++) rp[i] = x;
fillarr_setFill(r, xf);
arr_shSetU(r,nr,sh);
return taga(r);
}

View File

@ -104,6 +104,17 @@
#endif
#endif
#if SINGELI
extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
extern void (*const avx2_scan_pluswrap_u16)(uint16_t* v0,uint16_t* v1,uint64_t v2,uint16_t v3);
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
#define avx2_scan_pluswrap_u64(V0,V1,V2,V3) for (usz i=k; i<e; i++) js=rp[i]+=js;
#define PLUS_SCAN(T) avx2_scan_pluswrap_##T(rp+k,rp+k,e-k,js); js=rp[e-1];
extern void (*const avx2_scan_max32)(int32_t* v0,int32_t* v1,uint64_t v2);
#else
#define PLUS_SCAN(T) for (usz i=k; i<e; i++) js=rp[i]+=js;
#endif
// Dense Where, still significantly worse than SIMD
// Assumes modifiable DST
#define WHERE_DENSE(SRC, DST, LEN, OFF) do { \
@ -437,6 +448,62 @@ static B compress(B w, B x, usz wia, u8 xl, u8 xt) {
return r;
}
// Replicate using plus/max/xor-scan
#define SCAN_CORE(WV, UPD, SET, SCAN) \
usz b = 1<<10; \
for (usz k=0, j=0, ij=WV; ; ) { \
usz e = b<s-k? k+b : s; \
SET; for (usz i=k; i<e; i++) rp[i]=0; \
while (ij<e) { j++; UPD; ij+=WV; } \
SCAN; \
if (e==s) {break;} k=e; \
}
#define SUM_CORE(T, WV, PREP, INC) \
SCAN_CORE(WV, PREP; rp[ij]+=INC, , PLUS_SCAN(T))
#if SINGELI
#define IND_BY_SCAN \
SCAN_CORE(xp[j], rp[ij]=j, rp[k]=j, avx2_scan_max32(rp+k,rp+k,e-k))
#else
#define IND_BY_SCAN usz js=0; SUM_CORE(i32, xp[j], , 1)
#endif
#define REP_BY_SCAN(T, WV) \
T* xp = xv; T* rp = rv; \
T js=xp[0], px=js; \
SUM_CORE(T, WV, T sx=px, (px=xp[j])-sx)
#define BOOL_REP_XOR_SCAN(WV) \
usz b = 1<<12; \
u64 xx=xp[0], xs=xx>>63, js=-(xx&1); xx^=xx<<1; \
for (usz k=0, j=0, ij=WV; ; ) { \
usz e = b<s-k? k+b : s; \
usz eb = (e-1)/64+1; \
for (usz i=k/64; i<eb; i++) rp[i]=0; \
while (ij<e) { \
xx>>=1; j++; if (j%64==0) { u64 v=xp[j/64]; xx=v^(v<<1)^xs; xs=v>>63; } \
rp[ij/64]^=(-(xx&1))<<(ij%64); ij+=WV; \
} \
for (usz i=k/64; i<eb; i++) js=-((rp[i]^=js)>>63); \
if (e==s) {break;} k=e; \
}
// Basic boolean loop with overwriting
#define BOOL_REP_OVER(WV, LEN) \
u64 ri=0, rc=0, xc=0; usz j=0; \
for (usz i = 0; i < LEN; i++) { \
u64 v = -(u64)bitp_get(xp,i); \
rc ^= (v^xc) << (ri%64); \
xc = v; \
ri += WV; usz e = ri/64; \
if (j < e) { \
rp[j++] = rc; \
while (j < e) rp[j++] = v; \
rc = v; \
} \
} \
if (ri%64) rp[j] = rc;
extern B rt_slash;
B slash_c1(B t, B x) {
if (RARE(isAtm(x)) || RARE(RNK(x)!=1)) thrF("/: Argument must have rank 1 (%H ≡ ≢𝕩)", x);
@ -465,17 +532,8 @@ B slash_c1(B t, B x) {
for (u64 j = 0; j < c; j++) *rp++ = i;
}
} else {
if (s/16 <= xia) { // Sparse case: type of x matters
#define SPARSE_IND(T) \
T* xp = T##any_ptr(x); \
usz b = 1<<10; \
for (usz k=0, j=0, js=0, ij=xp[0]; ; ) { \
usz e = b<s-k? k+b : s; \
for (usz i=k; i<e; i++) rp[i]=0; \
while (ij<e) { rp[ij]++; ij+=xp[++j]; } \
for (usz i=k; i<e; i++) js=rp[i]+=js; \
if (e==s) {break;} k=e; \
}
if (s/32 <= xia) { // Sparse case: type of x matters
#define SPARSE_IND(T) T* xp = T##any_ptr(x); IND_BY_SCAN
i32* rp; r = m_i32arrv(&rp, s);
if (xe == el_i8 ) { SPARSE_IND(i8 ); }
else if (xe == el_i16) { SPARSE_IND(i16); }
@ -502,18 +560,30 @@ B slash_c1(B t, B x) {
}
B slash_c2(B t, B w, B x) {
B r;
if (isArr(w) && RNK(w)==1 && depth(w)==1) {
usz wia = IA(w);
i32 wv = -1;
usz wia;
if (isArr(w)) {
if (depth(w)>1) goto base;
ur wr = RNK(w);
if (wr>1) thrF("/: Simple 𝕨 must have rank 0 or 1 (%i≡=𝕨)", wr);
if (wr<1) { B v=IGet(w, 0); decG(w); w=v; goto atom; }
wia = IA(w);
if (wia==0) { decG(w); return isArr(x)? x : m_atomUnit(x); }
if (isAtm(x) || RNK(x)==0) thrM("/: 𝕩 must have rank at least 1 for simple 𝕨");
ur xr = RNK(x);
usz xlen = *SH(x);
} else {
atom:
if (!q_i32(w)) goto base;
wv = o2i(w);
}
if (isAtm(x) || RNK(x)==0) thrM("/: 𝕩 must have rank at least 1 for simple 𝕨");
ur xr = RNK(x);
usz xlen = *SH(x);
u8 xl = cellWidthLog(x);
u8 xt = arrNewType(TY(x));
B r;
if (wv < 0) { // Array w
if (RARE(wia!=xlen)) thrF("/: Lengths of components of 𝕨 must match 𝕩 (%s ≠ %s)", wia, xlen);
u8 xl = cellWidthLog(x);
u8 xt = arrNewType(TY(x));
u8 we = TI(w,elType);
if (!elInt(we)) {
w=any_squeeze(w); we=TI(w,elType);
@ -555,30 +625,16 @@ B slash_c2(B t, B w, B x) {
// Make shape if needed; all cases below use it
usz* rsh = NULL;
if (xr > 1) {
usz* sh = rsh = m_shArr(xr)->a;
sh[0] = s;
shcpy(sh+1, SH(x)+1, xr-1);
rsh = m_shArr(xr)->a;
rsh[0] = s;
shcpy(rsh+1, SH(x)+1, xr-1);
}
if (xl == 0) {
u64* xp = bitarr_ptr(x);
u64* rp; r = m_bitarrv(&rp, s); if (rsh) { SPRNK(a(r),xr); SH(r) = rsh; }
if (s/256 <= wia) {
#define SPARSE_REP(T) \
T* wp = T##any_ptr(w); \
usz b = 1<<12; \
u64 xx=xp[0], xs=xx>>63, js=-(xx&1); xx^=xx<<1; \
for (usz k=0, j=0, ij=wp[0]; ; ) { \
usz e = b<s-k? k+b : s; \
usz eb = (e-1)/64+1; \
for (usz i=k/64; i<eb; i++) rp[i]=0; \
while (ij<e) { \
xx>>=1; j++; if (j%64==0) { u64 v=xp[j/64]; xx=v^(v<<1)^xs; xs=v>>63; } \
rp[ij/64]^=(-(xx&1))<<(ij%64); ij+=wp[j]; \
} \
for (usz i=k/64; i<eb; i++) js=-((rp[i]^=js)>>63); \
if (e==s) {break;} k=e; \
}
if (s/1024 <= wia) {
#define SPARSE_REP(T) T* wp=T##any_ptr(w); BOOL_REP_XOR_SCAN(wp[j])
if (we==el_i8 ) { SPARSE_REP(i8 ); }
else if (we==el_i16) { SPARSE_REP(i16); }
else { SPARSE_REP(i32); }
@ -586,37 +642,15 @@ B slash_c2(B t, B w, B x) {
} else {
if (we < el_i32) w = taga(cpyI32Arr(w));
i32* wp = i32any_ptr(w);
u64 ri=0, rc=0, xc=0; usz j=0;
for (usz i = 0; i < wia; i++) {
u64 v = -(u64)bitp_get(xp,i);
rc ^= (v^xc) << (ri%64);
xc = v;
ri += wp[i]; usz e = ri/64;
if (j < e) {
rp[j++] = rc;
while (j < e) rp[j++] = v;
rc = v;
}
}
if (ri%64) rp[j] = rc;
BOOL_REP_OVER(wp[i], wia)
}
} else {
u8 xk = xl-3;
void* rv = m_tyarrv(&r, 1<<xk, s, xt);
if (rsh) { Arr* ra=a(r); SPRNK(ra,xr); PSH(ra) = rsh; PIA(ra) = s*arr_csz(x); }
void* xv = tyany_ptr(x);
if (s/32 <= wia) { // Sparse case: use both types
#define CASE(L,XT) case L: { \
XT* xp = xv; XT* rp = rv; \
usz b = 1<<10; \
XT js=xp[0], px=js; \
for (usz k=0, j=0, ij=wp[0]; ; ) { \
usz e = b<s-k? k+b : s; \
for (usz i=k; i<e; i++) rp[i]=0; \
while (ij<e) { j++; XT sx=px; rp[ij]^=sx^(px=xp[j]); ij+=wp[j]; } \
for (usz i=k; i<e; i++) js=rp[i]^=js; \
if (e==s) {break;} k=e; \
} break; }
if ((xk<3? s/64 : s/32) <= wia) { // Sparse case: use both types
#define CASE(L,XT) case L: { REP_BY_SCAN(XT, wp[j]) break; }
#define SPARSE_REP(WT) \
WT* wp = WT##any_ptr(w); \
switch (xk) { default: UD; CASE(0,u8) CASE(1,u16) CASE(2,u32) CASE(3,u64) }
@ -640,33 +674,47 @@ B slash_c2(B t, B w, B x) {
}
}
goto decWX_ret;
}
if (isArr(x) && RNK(x)==1 && q_i32(w)) {
usz xia = IA(x);
i32 wv = o2i(w);
if (wv<=0) {
if (wv<0) thrM("/: 𝕨 cannot be negative");
return taga(arr_shVec(TI(x,slice)(x, 0, 0)));
} else {
if (wv <= 1) {
if (wv < 0) thrM("/: 𝕨 cannot be negative");
return wv ? x : taga(arr_shVec(TI(x,slice)(x, 0, 0)));
}
if (TI(x,elType)==el_i32) {
i32* xp = i32any_ptr(x);
i32* rp; r = m_i32arrv(&rp, xia*wv);
for (usz i = 0; i < xia; i++) {
for (i64 j = 0; j < wv; j++) *rp++ = xp[i];
}
goto decX_ret;
} else {
if (xlen == 0) return x;
usz s = xlen * wv;
if (xl>6 || (xl<3 && xl!=0) || TI(x,elType)==el_B) {
if (xr != 1) goto base;
SLOW2("𝕨/𝕩", w, x);
B xf = getFillQ(x);
HArr_p r0 = m_harrUv(xia*wv);
HArr_p r0 = m_harrUv(s);
SGetU(x)
for (usz i = 0; i < xia; i++) {
for (usz i = 0; i < xlen; i++) {
B cx = incBy(GetU(x, i), wv);
for (i64 j = 0; j < wv; j++) *r0.a++ = cx;
}
r = withFill(r0.b, xf);
goto decX_ret;
}
if (xl == 0) {
u64* xp = bitarr_ptr(x);
u64* rp; r = m_bitarrv(&rp, s);
if (wv <= 256) { BOOL_REP_XOR_SCAN(wv) }
else { BOOL_REP_OVER(wv, xlen) }
goto decX_ret;
} else {
u8 xk = xl-3;
void* rv = m_tyarrv(&r, 1<<xk, s, xt);
void* xv = tyany_ptr(x);
#define CASE(L,T) case L: { REP_BY_SCAN(T, wv) break; }
switch (xk) { default: UD; CASE(0,u8) CASE(1,u16) CASE(2,u32) CASE(3,u64) }
#undef CASE
}
if (xr > 1) {
usz* rsh = m_shArr(xr)->a;
rsh[0] = s;
shcpy(rsh+1, SH(x)+1, xr-1);
Arr* ra=a(r); SPRNK(ra,xr); PSH(ra)=rsh; PIA(ra)=s*arr_csz(x);
}
goto decX_ret;
}
base:
return c2(rt_slash, w, x);

View File

@ -10,20 +10,42 @@ def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
def base{b,l} = { if (0==tuplen{l}) 0; else tupsel{0,l}+b*base{b,slice{l,1}} }
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,n}}
# Fill last 4 bytes with last element, in each lane
def spread{a:VT} = {
def w = width{eltype{VT}}
def b = w/8
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
}
def scan_loop{T, init, x:*T, r:*T, len:u64, scan, scan_last} = {
def step = 256/width{T}
def V = [step]T
p:= broadcast{V, init}
xv:= *V ~~ x
rv:= *V ~~ r
e:= len/step
@for (xv, rv over e) rv = scan{xv,p}
q:= len & (step-1)
if (q) maskstoreF{rv, maskOf{V, q}, e, scan_last{load{xv,e}, p}}
}
def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
def last{v, p} = op{pre{v}, p}
def scan{v, p} = {
n:= last{v, p}
p = sel{[8]i32, spread{n}, broadcast{[8]i32, 7}}
n
}
scan_loop{T, init, x, r, len, scan, last}
}
# Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈
avx2_scan_idem{T, op, id}(x:*T, r:*T, len:u64) : void = {
def w = width{T}
# Within each lane, scan using shifts by powers of 2. First k elements
# when shifting by k don't need to change, so leave them alone.
def w = width{T}
def shift{k,l} = merge{iota{k},iota{l-k}}
def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}}
def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a)
# Fill last 4 bytes with last element, in each lane
def spread{a} = {
def b = w/8
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
}
# Prefix op on entire AVX register
def pre{a} = {
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
@ -31,19 +53,7 @@ avx2_scan_idem{T, op, id}(x:*T, r:*T, len:u64) : void = {
op{b, sel{[8]i32, spread{b}, make{[8]i32, 3*(3<iota{8})}}}
}
def step = 256/w
def V = [step]T
p:= broadcast{V, id}
xv:= *V ~~ x
rv:= *V ~~ r
e:= len/step
@for (xv, rv over e) {
n:= op{pre{xv}, p}
p = sel{[8]i32, spread{n}, broadcast{[8]i32, 7}}
rv = n
}
q:= len & (step-1)
if (q) maskstoreF{rv, maskOf{V, q}, e, op{pre{load{xv,e}}, p}}
scan_post{T, id, x, r, len, op, pre}
}
def avx2_scan_idem{T, op} = {
def m = 1 << (width{T}-1)
@ -56,6 +66,25 @@ def avx2_scan_idem{T, op} = {
'avx2_scan_min32' = avx2_scan_idem{i32, min}
'avx2_scan_max32' = avx2_scan_idem{i32, max}
# Associative scan
avx2_scan_assoc_0{T, op}(x:*T, r:*T, len:u64, init:T) : void = {
# Prefix op on entire AVX register
def pre{a} = {
# Within each lane, scan using shifts by powers of 2.
# Assumes identity is 0.
def w = width{T}
def c32{k, a} = (if (w<=8*k) op{a, shl{[16]u8, a, k}}; else a)
b:= c32{8, c32{4, c32{2, c32{1, a}}}}
# After lanewise scan, broadcast end of lane 0 to entire lane 1
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
}
scan_post{T, init, x, r, len, op, pre}
}
'avx2_scan_pluswrap_u8' = avx2_scan_assoc_0{u8 , +}
'avx2_scan_pluswrap_u16' = avx2_scan_assoc_0{u16, +}
'avx2_scan_pluswrap_u32' = avx2_scan_assoc_0{u32, +}
# Boolean cumulative sum
avx2_bcs32(x:*u64, r:*i32, l:u64) : void = {
rv:= *[8]u32~~r