Merge pull request #44 from mlochbaum/rep
Replicate, Indices, Reshape, sorting
This commit is contained in:
commit
4d42e19c27
@ -30,51 +30,72 @@
|
||||
rp[j] = xi; \
|
||||
}
|
||||
|
||||
#if SINGELI
|
||||
extern void (*const avx2_scan_max8)(int8_t* v0,int8_t* v1,uint64_t v2);
|
||||
extern void (*const avx2_scan_min8)(int8_t* v0,int8_t* v1,uint64_t v2);
|
||||
extern void (*const avx2_scan_max16)(int16_t* v0,int16_t* v1,uint64_t v2);
|
||||
extern void (*const avx2_scan_min16)(int16_t* v0,int16_t* v1,uint64_t v2);
|
||||
#define COUNT_THRESHOLD 32
|
||||
#define WRITE_SPARSE_i8 \
|
||||
for (usz i=0; i<n; i++) rp[i]=j; \
|
||||
while (ij<n) { rp[ij]=GRADE_UD(++j,--j); ij+=c0o[j]; } \
|
||||
GRADE_UD(avx2_scan_max8,avx2_scan_min8)(rp,rp,n);
|
||||
#define WRITE_SPARSE_i16 \
|
||||
usz b = 1<<10; \
|
||||
for (usz k=0; ; ) { \
|
||||
usz e = b<n-k? k+b : n; \
|
||||
for (usz i=k; i<e; i++) rp[i]=j; \
|
||||
while (ij<e) { rp[ij]=GRADE_UD(++j,--j); ij+=c0o[j]; } \
|
||||
GRADE_UD(avx2_scan_max16,avx2_scan_min16)(rp+k,rp+k,e-k); \
|
||||
if (e==n) {break;} k=e; \
|
||||
}
|
||||
#define WRITE_SPARSE(T) WRITE_SPARSE_##T
|
||||
#else
|
||||
#define COUNT_THRESHOLD 16
|
||||
#define WRITE_SPARSE(T) \
|
||||
for (usz i=0; i<n; i++) rp[i]=0; \
|
||||
usz js = j; \
|
||||
while (ij<n) { rp[ij]GRADE_UD(++,--); ij+=c0o[GRADE_UD(++j,--j)]; } \
|
||||
for (usz i=0; i<n; i++) js=rp[i]+=js;
|
||||
#endif
|
||||
|
||||
#define COUNTING_SORT(T) \
|
||||
usz C=1<<(8*sizeof(T)); \
|
||||
TALLOC(usz, c0, C); usz *c0o=c0+C/2; \
|
||||
for (usz j=0; j<C; j++) c0[j]=0; \
|
||||
for (usz i=0; i<n; i++) c0o[xp[i]]++; \
|
||||
if (n/16 <= C) { /* Sum-based */ \
|
||||
for (usz i=0; i<n; i++) rp[i]=0; \
|
||||
usz j=GRADE_UD(0,C-1), i; \
|
||||
while ((i=c0[j])==0) GRADE_UD(j++,j--); \
|
||||
usz js = j - C/2; \
|
||||
while (i<n) { rp[i]++; i+=c0[GRADE_UD(++j,--j)]; } \
|
||||
for (usz i=0; i<n; i++) js=rp[i]+=js; \
|
||||
} else { /* Branchy */ \
|
||||
FOR(j,C) for (usz c=c0[j]; c--; ) *rp++ = j-C/2; \
|
||||
} \
|
||||
usz C=1<<(8*sizeof(T)); \
|
||||
TALLOC(usz, c0, C); usz *c0o=c0+C/2; \
|
||||
for (usz j=0; j<C; j++) c0[j]=0; \
|
||||
for (usz i=0; i<n; i++) c0o[xp[i]]++; \
|
||||
if (n/(COUNT_THRESHOLD*sizeof(T)) <= C) { /* Scan-based */ \
|
||||
T j=GRADE_UD(-C/2,C/2-1); \
|
||||
usz ij; while ((ij=c0o[j])==0) GRADE_UD(j++,j--); \
|
||||
WRITE_SPARSE(T) \
|
||||
} else { /* Branchy */ \
|
||||
FOR(j,C) for (usz c=c0[j]; c--; ) *rp++ = j-C/2; \
|
||||
} \
|
||||
TFREE(c0)
|
||||
|
||||
// Radix sorting
|
||||
#define PRE(T,K) usz p##K=s##K; s##K+=c##K[j]; c##K[j]=p##K
|
||||
// 8-bit prefix sum by SWAR
|
||||
#define PRE_UD(K,SL,SRE) \
|
||||
u64 p##K=s##K; s##K+=((u64*)c##K)[j]; \
|
||||
s##K+=s##K SL 8; s##K+=s##K SL 16; s##K+=s##K SL 32; \
|
||||
((u64*)c##K)[j] = p##K|(s##K SL 8); s##K SRE 56
|
||||
#define PRE64(K) GRADE_UD(PRE_UD(K,<<,>>=), PRE_UD(K,>>,<<=))
|
||||
#define INC(P,I) GRADE_UD((P+1)[I]++,P[I]--)
|
||||
#define ROFF GRADE_UD(1,0) // Radix offset
|
||||
|
||||
#define CHOOSE_SG_SORT(S,G) S
|
||||
#define CHOOSE_SG_GRADE(S,G) G
|
||||
|
||||
#define RADIX_SORT_i8(T, TYP) \
|
||||
TALLOC(T, c0, 256); T *c0o=c0+128; \
|
||||
TALLOC(T, c0, 256+ROFF); T* c0o=c0+128; \
|
||||
for (usz j=0; j<256; j++) c0[j]=0; \
|
||||
for (usz i=0; i<n; i++) c0o[xp[i]]++; \
|
||||
RADIX_SUM_1_##T \
|
||||
GRADE_UD(,c0[0]=n;) \
|
||||
for (usz i=0; i<n; i++) INC(c0o,xp[i]); \
|
||||
RADIX_SUM_1_##T; \
|
||||
for (usz i=0; i<n; i++) { i8 xi=xp[i]; \
|
||||
rp[c0o[xi]++]=CHOOSE_SG_##TYP(xi,i); } \
|
||||
TFREE(c0)
|
||||
#define RADIX_SUM_1_u8 { u64 s0=0; FOR(j, 256/8) { PRE64(0); } }
|
||||
#define RADIX_SUM_1_usz { usz s0=0; FOR(j,256) { PRE(usz,0); } }
|
||||
|
||||
#define RADIX_SORT_i16(T, TYP, I) \
|
||||
TALLOC(u8, alloc, 2*256*sizeof(T) + n*(2 + CHOOSE_SG_##TYP(0,sizeof(I)))); \
|
||||
T *c0=(T*)alloc; T *c1=c0+256; T *c1o=c1+128; \
|
||||
TALLOC(u8, alloc, (2*256+ROFF)*sizeof(T) + n*(2 + CHOOSE_SG_##TYP(0,sizeof(I)))); \
|
||||
T* c0=(T*)alloc; T* c1=c0+256; T* c1o=c1+128; \
|
||||
for (usz j=0; j<2*256; j++) c0[j]=0; \
|
||||
for (usz i=0; i<n; i++) { i16 v=xp[i]; c0[(u8)v]++; c1o[(i8)(v>>8)]++; } \
|
||||
c1[0]=GRADE_UD(-n,c0[0]=n); \
|
||||
for (usz i=0; i<n; i++) { i16 v=xp[i]; INC(c0,(u8)v); INC(c1o,(i8)(v>>8)); } \
|
||||
RADIX_SUM_2_##T; \
|
||||
i16 *r0 = (i16*)(c0+2*256); \
|
||||
CHOOSE_SG_##TYP( \
|
||||
@ -86,18 +107,15 @@
|
||||
for (usz i=0; i<n; i++) { i16 v=r0[i]; rp[c1o[(i8)(v>>8)]++]=g0[i]; } \
|
||||
) \
|
||||
TFREE(alloc)
|
||||
#define RADIX_SUM_2_u8 u64 s0=0, s1=0; FOR(j,256/8) { PRE64(0); PRE64(1); }
|
||||
#define RADIX_SUM_2(T) T s0=0, s1=0; FOR(j,256) { PRE(T,0); PRE(T,1); }
|
||||
#define RADIX_SUM_2_usz RADIX_SUM_2(usz)
|
||||
#define RADIX_SUM_2_u32 RADIX_SUM_2(u32)
|
||||
|
||||
#define RADIX_SORT_i32(T, TYP, I) \
|
||||
TALLOC(u8, alloc, 4*256*sizeof(T) + n*(4 + CHOOSE_SG_##TYP(0,4+sizeof(I)))); \
|
||||
TALLOC(u8, alloc, (4*256+ROFF)*sizeof(T) + n*(4 + CHOOSE_SG_##TYP(0,4+sizeof(I)))); \
|
||||
T *c0=(T*)alloc, *c1=c0+256, *c2=c1+256, *c3=c2+256, *c3o=c3+128; \
|
||||
for (usz j=0; j<4*256; j++) c0[j]=0; \
|
||||
c1[0]=c2[0]=c3[0]=GRADE_UD(-n,c0[0]=n); \
|
||||
for (usz i=0; i<n; i++) { i32 v=xp[i]; \
|
||||
c0 [(u8)v ]++; c1 [(u8)(v>> 8)]++; \
|
||||
c2 [(u8)(v>>16)]++; c3o[(i8)(v>>24)]++; } \
|
||||
INC(c0 ,(u8)v ); INC(c1 ,(u8)(v>> 8)); \
|
||||
INC(c2 ,(u8)(v>>16)); INC(c3o,(i8)(v>>24)); } \
|
||||
RADIX_SUM_4_##T; \
|
||||
i32 *r0 = (i32*)(c0+4*256); \
|
||||
CHOOSE_SG_##TYP( \
|
||||
@ -113,10 +131,37 @@
|
||||
for (usz i=0; i<n; i++) { i32 v=r0[i]; T c=c3o[(i8)(v>>24)]++; rp[c]=g0[i]; } \
|
||||
) \
|
||||
TFREE(alloc)
|
||||
#define RADIX_SUM_4_u8 u64 s0=0, s1=0, s2=0, s3=0; FOR(j, 256/8) { PRE64(0); PRE64(1); PRE64(2); PRE64(3); }
|
||||
#define RADIX_SUM_4(T) T s0=0, s1=0, s2=0, s3=0; FOR(j, 256) { PRE(u32,0); PRE(u32,1); PRE(u32,2); PRE(u32,3); }
|
||||
#define RADIX_SUM_4_usz RADIX_SUM_4(usz)
|
||||
#define RADIX_SUM_4_u32 RADIX_SUM_4(u32)
|
||||
|
||||
#define PRE(K) s##K=c##K[j]+=s##K
|
||||
#define RADIX_SUM_1(T) T s0=0; for(usz j=0;j<256;j++) { PRE(0); }
|
||||
#define RADIX_SUM_2(T) GRADE_UD(c1[0]=0;,) T s0=0, s1=0; for(usz j=0;j<256;j++) { PRE(0); PRE(1); }
|
||||
#define RADIX_SUM_4(T) GRADE_UD(c1[0]=c2[0]=c3[0]=0;,) T s0=0, s1=0, s2=0, s3=0; for(usz j=0;j<256;j++) { PRE(0); PRE(1); PRE(2); PRE(3); }
|
||||
|
||||
#if SINGELI
|
||||
extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
|
||||
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
|
||||
#define RADIX_SUM_1_u8 avx2_scan_pluswrap_u8 (c0,c0, 256,0);
|
||||
#define RADIX_SUM_2_u8 avx2_scan_pluswrap_u8 (c0,c0,2*256,0);
|
||||
#define RADIX_SUM_2_u32 avx2_scan_pluswrap_u32(c0,c0,2*256,0);
|
||||
#define RADIX_SUM_4_u8 avx2_scan_pluswrap_u8 (c0,c0,4*256,0);
|
||||
#define RADIX_SUM_4_u32 avx2_scan_pluswrap_u32(c0,c0,4*256,0);
|
||||
#else
|
||||
#define RADIX_SUM_1_u8 RADIX_SUM_1(u8)
|
||||
#define RADIX_SUM_2_u8 RADIX_SUM_2(u8)
|
||||
#define RADIX_SUM_2_u32 RADIX_SUM_2(u32)
|
||||
#define RADIX_SUM_4_u8 RADIX_SUM_4(u8)
|
||||
#define RADIX_SUM_4_u32 RADIX_SUM_4(u32)
|
||||
#endif
|
||||
|
||||
#if SINGELI && !USZ_64
|
||||
#define RADIX_SUM_1_usz avx2_scan_pluswrap_u32(c0,c0, 256,0);
|
||||
#define RADIX_SUM_2_usz avx2_scan_pluswrap_u32(c0,c0,2*256,0);
|
||||
#define RADIX_SUM_4_usz avx2_scan_pluswrap_u32(c0,c0,4*256,0);
|
||||
#else
|
||||
#define RADIX_SUM_1_usz RADIX_SUM_1(usz)
|
||||
#define RADIX_SUM_2_usz RADIX_SUM_2(usz)
|
||||
#define RADIX_SUM_4_usz RADIX_SUM_4(usz)
|
||||
#endif
|
||||
|
||||
#define SORT_C1 CAT(GRADE_UD(and,or),c1)
|
||||
B SORT_C1(B t, B x) {
|
||||
@ -140,9 +185,9 @@ B SORT_C1(B t, B x) {
|
||||
} else if (xe==el_i8) {
|
||||
i8* xp = i8any_ptr(x);
|
||||
i8* rp; r = m_i8arrv(&rp, n);
|
||||
if (n<16) {
|
||||
if (n < 16) {
|
||||
INSERTION_SORT(i8);
|
||||
} else if (n<256) {
|
||||
} else if (n < 256) {
|
||||
RADIX_SORT_i8(u8, SORT);
|
||||
} else {
|
||||
COUNTING_SORT(i8);
|
||||
@ -150,7 +195,7 @@ B SORT_C1(B t, B x) {
|
||||
} else if (xe==el_i16) {
|
||||
i16* xp = i16any_ptr(x);
|
||||
i16* rp; r = m_i16arrv(&rp, n);
|
||||
if (n < 24) {
|
||||
if (n < 20) {
|
||||
INSERTION_SORT(i16);
|
||||
} else if (n < 256) {
|
||||
RADIX_SORT_i16(u8, SORT,);
|
||||
@ -162,7 +207,7 @@ B SORT_C1(B t, B x) {
|
||||
} else if (xe==el_i32) {
|
||||
i32* xp = i32any_ptr(x);
|
||||
i32* rp; r = m_i32arrv(&rp, n);
|
||||
if (n < 40) {
|
||||
if (n < 32) {
|
||||
INSERTION_SORT(i32);
|
||||
} else if (n < 256) {
|
||||
RADIX_SORT_i32(u8, SORT,);
|
||||
@ -181,6 +226,10 @@ B SORT_C1(B t, B x) {
|
||||
#undef SORT_C1
|
||||
#undef INSERTION_SORT
|
||||
#undef COUNTING_SORT
|
||||
#if SINGELI
|
||||
#undef WRITE_SPARSE_i8
|
||||
#undef WRITE_SPARSE_i16
|
||||
#endif
|
||||
|
||||
|
||||
#define GRADE_CHR GRADE_UD("⍋","⍒")
|
||||
@ -356,21 +405,23 @@ done:
|
||||
#undef LT
|
||||
#undef FOR
|
||||
#undef PRE
|
||||
#undef PRE_UD
|
||||
#undef INC
|
||||
#undef ROFF
|
||||
#undef PRE64
|
||||
#undef CHOOSE_SG_SORT
|
||||
#undef CHOOSE_SG_GRADE
|
||||
#undef RADIX_SORT_i8
|
||||
#undef RADIX_SORT_i16
|
||||
#undef RADIX_SORT_i32
|
||||
#undef RADIX_SUM_1
|
||||
#undef RADIX_SUM_2
|
||||
#undef RADIX_SUM_4
|
||||
#undef RADIX_SUM_1_u8
|
||||
#undef RADIX_SUM_1_usz
|
||||
#undef RADIX_SORT_i16
|
||||
#undef RADIX_SUM_2_u8
|
||||
#undef RADIX_SUM_2
|
||||
#undef RADIX_SUM_2_usz
|
||||
#undef RADIX_SUM_2_u32
|
||||
#undef RADIX_SORT_i32
|
||||
#undef RADIX_SUM_4_u8
|
||||
#undef RADIX_SUM_4
|
||||
#undef RADIX_SUM_4_usz
|
||||
#undef RADIX_SUM_4_u32
|
||||
#undef GRADE_CAT
|
||||
|
||||
@ -104,6 +104,13 @@ static B truncReshape(B x, usz xia, usz nia, ur nr, ShArr* sh) { // consumes all
|
||||
arr_shSetU(ra, nr, sh);
|
||||
return r;
|
||||
}
|
||||
static void fill_words(void* rp, u64 v, u64 bytes) {
|
||||
usz wds = bytes/8;
|
||||
usz ext = bytes%8;
|
||||
u64* p = rp;
|
||||
for (usz i=0; i<wds; i++) p[i] = v;
|
||||
if (ext) memcpy(p+wds, &v, ext);
|
||||
}
|
||||
B shape_c2(B t, B w, B x) {
|
||||
usz xia = isArr(x)? IA(x) : 1;
|
||||
usz nia = 1;
|
||||
@ -181,63 +188,121 @@ B shape_c2(B t, B w, B x) {
|
||||
decG(w);
|
||||
}
|
||||
|
||||
B xf;
|
||||
if (isAtm(x)) {
|
||||
xf = asFill(inc(x));
|
||||
// goes to unit
|
||||
} else {
|
||||
Arr* r;
|
||||
if (isArr(x)) {
|
||||
if (nia <= xia) {
|
||||
return truncReshape(x, xia, nia, nr, sh);
|
||||
} else {
|
||||
xf = getFillQ(x);
|
||||
if (xia<=1) {
|
||||
if (RARE(xia==0)) {
|
||||
thrM("⥊: Empty 𝕩 and non-empty result");
|
||||
// if (noFill(xf)) thrM("⥊: No fill for empty array");
|
||||
// dec(x);
|
||||
// x = inc(xf);
|
||||
} else {
|
||||
B n = IGet(x,0);
|
||||
decG(x);
|
||||
x = n;
|
||||
}
|
||||
if (xia <= 1) {
|
||||
if (RARE(xia == 0)) thrM("⥊: Empty 𝕩 and non-empty result");
|
||||
B n = IGet(x,0);
|
||||
decG(x);
|
||||
x = n;
|
||||
goto unit;
|
||||
}
|
||||
if (xia <= nia/2) x = any_squeeze(x);
|
||||
|
||||
MAKE_MUT(m, nia); mut_init(m, TI(x,elType));
|
||||
MUTG_INIT(m);
|
||||
i64 div = nia/xia;
|
||||
i64 mod = nia%xia;
|
||||
for (i64 i = 0; i < div; i++) mut_copyG(m, i*xia, x, 0, xia);
|
||||
mut_copyG(m, div*xia, x, 0, mod);
|
||||
u8 xl = arrTypeBitsLog(TY(x));
|
||||
u8 xt = arrNewType(TY(x));
|
||||
u8* rp;
|
||||
u64 bi, bf; // Bytes present, bytes wanted
|
||||
if (xl == 0) { // Bits
|
||||
u64* rq; r = m_bitarrp(&rq, nia);
|
||||
rp = (u8*)rq;
|
||||
usz nw = BIT_N(nia);
|
||||
u64* xp = bitarr_ptr(x);
|
||||
u64 b = xia;
|
||||
if (b % 8) {
|
||||
if (b < 64) {
|
||||
// Need to avoid calling bit_cpy with arguments <64 bits apart
|
||||
u64 v = xp[0] & (~(u64)0 >> (64-b));
|
||||
do { v |= v<<b; b*=2; } while (b%8 && b<64);
|
||||
rq[0] = v;
|
||||
if (b>64 && nia>64) rq[1] = v>>(64-b/2);
|
||||
} else {
|
||||
memcpy(rq, xp, (b+7)/8);
|
||||
}
|
||||
for (; b%8; b*=2) {
|
||||
if (b>nw*32) {
|
||||
if (b<nia) bit_cpy(rq, b, rq, 0, nia-b);
|
||||
b = 64*nw; // Ensure bi>=bf since bf is rounded up
|
||||
break;
|
||||
}
|
||||
bit_cpy(rq, b, rq, 0, b);
|
||||
}
|
||||
} else {
|
||||
memcpy(rp, xp, b/8);
|
||||
}
|
||||
bi = b/8;
|
||||
bf = 8*nw;
|
||||
if (bi == 1) { memset(rp, rp[0], bf); bi=bf; }
|
||||
} else {
|
||||
if (TI(x,elType) == el_B) {
|
||||
B xf = getFillQ(x);
|
||||
MAKE_MUT(m, nia); mut_init(m, el_B);
|
||||
MUTG_INIT(m);
|
||||
i64 div = nia/xia;
|
||||
i64 mod = nia%xia;
|
||||
for (i64 i = 0; i < div; i++) mut_copyG(m, i*xia, x, 0, xia);
|
||||
mut_copyG(m, div*xia, x, 0, mod);
|
||||
decG(x);
|
||||
Arr* ra = mut_fp(m);
|
||||
arr_shSetU(ra, nr, sh);
|
||||
return withFill(taga(ra), xf);
|
||||
}
|
||||
u8 xk = xl - 3;
|
||||
rp = m_tyarrp(&r, 1<<xk, nia, xt);
|
||||
bi = (u64)xia<<xk;
|
||||
bf = (u64)nia<<xk;
|
||||
memcpy(rp, tyany_ptr(x), bi);
|
||||
}
|
||||
decG(x);
|
||||
Arr* ra = mut_fp(m);
|
||||
arr_shSetU(ra, nr, sh);
|
||||
return withFill(taga(ra), xf);
|
||||
if (bi<=8 && !(bi & (bi-1))) {
|
||||
// Divisor of 8: write words
|
||||
usz b = bi*8;
|
||||
u64 v = *(u64*)rp & (~(u64)0 >> (64-b));
|
||||
while (b<64) { v |= v<<b; b*=2; }
|
||||
fill_words(rp, v, bf);
|
||||
} else {
|
||||
// Double up to length l, then copy in blocks
|
||||
u64 l = 1<<15; if (l>bf) l=bf;
|
||||
for (; bi<=l/2; bi+=bi) memcpy(rp+bi, rp, bi);
|
||||
u64 e=bi; for (; e+bi<=bf; e+=bi) memcpy(rp+e, rp, bi);
|
||||
if (e<bf) memcpy(rp+e, rp, bf-e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
unit:
|
||||
#define FILL(E,T,V) T* rp; r = m_##E##arrp(&rp,nia); fill_words(rp, V, (u64)nia*sizeof(T));
|
||||
if (isF64(x)) {
|
||||
i32 n = (i32)x.f;
|
||||
if (RARE(n!=x.f)) {
|
||||
FILL(f64,f64,x.u)
|
||||
} else if (n==(i8)n) { // memset can be faster than writing words
|
||||
u8 b = n;
|
||||
i8* rp; u64 nb = nia;
|
||||
if (b <= 1) { r = m_bitarrp((u64**)&rp,nia); nb = 8*BIT_N(nia); b=-b; }
|
||||
else { r = m_i8arrp ( &rp,nia); }
|
||||
memset(rp, b, nb);
|
||||
} else {
|
||||
if(n==(i16)n) { FILL(i16,i16,(u16)n*0x0001000100010001) }
|
||||
else { FILL(i32,i32,(u32)n*0x0000000100000001) }
|
||||
}
|
||||
} else if (isC32(x)) {
|
||||
u32 c = o2cG(x);
|
||||
if (c==(u8 )c) { u8* rp; r = m_c8arrp(&rp,nia); memset(rp, c, nia); }
|
||||
else if (c==(u16)c) { FILL(c16,u16,c*0x0001000100010001) }
|
||||
else { FILL(c32,u32,c*0x0000000100000001) }
|
||||
} else {
|
||||
incBy(x, nia); // in addition with the existing reference, this covers the filled amount & asFill
|
||||
r = m_fillarrp(nia);
|
||||
if (sizeof(B)==8) fill_words(fillarr_ptr(r), x.u, (u64)nia*8);
|
||||
else for (usz i = 0; i < nia; i++) fillarr_ptr(r)[i] = x;
|
||||
fillarr_setFill(r, asFill(x));
|
||||
}
|
||||
#undef FILL
|
||||
}
|
||||
|
||||
unit:
|
||||
if (isF64(x)) { decA(xf);
|
||||
i32 n = (i32)x.f;
|
||||
if (RARE(n!=x.f)) { f64* rp; Arr* r = m_f64arrp(&rp,nia); arr_shSetU(r,nr,sh); for (u64 i=0; i<nia; i++) rp[i]=x.f; return taga(r); }
|
||||
else if(n==(n&1)) { Arr* r=n?allOnes(nia):allZeroes(nia); arr_shSetU(r,nr,sh); return taga(r); }
|
||||
else if(n==(i8 )n) { i8* rp; Arr* r = m_i8arrp (&rp,nia); arr_shSetU(r,nr,sh); for (u64 i=0; i<nia; i++) rp[i]=n ; return taga(r); }
|
||||
else if(n==(i16)n) { i16* rp; Arr* r = m_i16arrp(&rp,nia); arr_shSetU(r,nr,sh); for (u64 i=0; i<nia; i++) rp[i]=n ; return taga(r); }
|
||||
else { i32* rp; Arr* r = m_i32arrp(&rp,nia); arr_shSetU(r,nr,sh); for (u64 i=0; i<nia; i++) rp[i]=n ; return taga(r); }
|
||||
}
|
||||
if (isC32(x)) { decA(xf);
|
||||
u32* rp; Arr* r = m_c32arrp(&rp, nia); arr_shSetU(r, nr, sh);
|
||||
u32 c = o2cG(x);
|
||||
for (u64 i = 0; i < nia; i++) rp[i] = c;
|
||||
return taga(r);
|
||||
}
|
||||
Arr* r = m_fillarrp(nia); arr_shSetU(r, nr, sh);
|
||||
B* rp = fillarr_ptr(r);
|
||||
if (nia) incBy(x, nia-1);
|
||||
else dec(x);
|
||||
for (u64 i = 0; i < nia; i++) rp[i] = x;
|
||||
fillarr_setFill(r, xf);
|
||||
arr_shSetU(r,nr,sh);
|
||||
return taga(r);
|
||||
}
|
||||
|
||||
|
||||
@ -104,6 +104,17 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if SINGELI
|
||||
extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
|
||||
extern void (*const avx2_scan_pluswrap_u16)(uint16_t* v0,uint16_t* v1,uint64_t v2,uint16_t v3);
|
||||
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
|
||||
#define avx2_scan_pluswrap_u64(V0,V1,V2,V3) for (usz i=k; i<e; i++) js=rp[i]+=js;
|
||||
#define PLUS_SCAN(T) avx2_scan_pluswrap_##T(rp+k,rp+k,e-k,js); js=rp[e-1];
|
||||
extern void (*const avx2_scan_max32)(int32_t* v0,int32_t* v1,uint64_t v2);
|
||||
#else
|
||||
#define PLUS_SCAN(T) for (usz i=k; i<e; i++) js=rp[i]+=js;
|
||||
#endif
|
||||
|
||||
// Dense Where, still significantly worse than SIMD
|
||||
// Assumes modifiable DST
|
||||
#define WHERE_DENSE(SRC, DST, LEN, OFF) do { \
|
||||
@ -437,6 +448,62 @@ static B compress(B w, B x, usz wia, u8 xl, u8 xt) {
|
||||
return r;
|
||||
}
|
||||
|
||||
// Replicate using plus/max/xor-scan
|
||||
#define SCAN_CORE(WV, UPD, SET, SCAN) \
|
||||
usz b = 1<<10; \
|
||||
for (usz k=0, j=0, ij=WV; ; ) { \
|
||||
usz e = b<s-k? k+b : s; \
|
||||
SET; for (usz i=k; i<e; i++) rp[i]=0; \
|
||||
while (ij<e) { j++; UPD; ij+=WV; } \
|
||||
SCAN; \
|
||||
if (e==s) {break;} k=e; \
|
||||
}
|
||||
#define SUM_CORE(T, WV, PREP, INC) \
|
||||
SCAN_CORE(WV, PREP; rp[ij]+=INC, , PLUS_SCAN(T))
|
||||
|
||||
#if SINGELI
|
||||
#define IND_BY_SCAN \
|
||||
SCAN_CORE(xp[j], rp[ij]=j, rp[k]=j, avx2_scan_max32(rp+k,rp+k,e-k))
|
||||
#else
|
||||
#define IND_BY_SCAN usz js=0; SUM_CORE(i32, xp[j], , 1)
|
||||
#endif
|
||||
|
||||
#define REP_BY_SCAN(T, WV) \
|
||||
T* xp = xv; T* rp = rv; \
|
||||
T js=xp[0], px=js; \
|
||||
SUM_CORE(T, WV, T sx=px, (px=xp[j])-sx)
|
||||
|
||||
#define BOOL_REP_XOR_SCAN(WV) \
|
||||
usz b = 1<<12; \
|
||||
u64 xx=xp[0], xs=xx>>63, js=-(xx&1); xx^=xx<<1; \
|
||||
for (usz k=0, j=0, ij=WV; ; ) { \
|
||||
usz e = b<s-k? k+b : s; \
|
||||
usz eb = (e-1)/64+1; \
|
||||
for (usz i=k/64; i<eb; i++) rp[i]=0; \
|
||||
while (ij<e) { \
|
||||
xx>>=1; j++; if (j%64==0) { u64 v=xp[j/64]; xx=v^(v<<1)^xs; xs=v>>63; } \
|
||||
rp[ij/64]^=(-(xx&1))<<(ij%64); ij+=WV; \
|
||||
} \
|
||||
for (usz i=k/64; i<eb; i++) js=-((rp[i]^=js)>>63); \
|
||||
if (e==s) {break;} k=e; \
|
||||
}
|
||||
|
||||
// Basic boolean loop with overwriting
|
||||
#define BOOL_REP_OVER(WV, LEN) \
|
||||
u64 ri=0, rc=0, xc=0; usz j=0; \
|
||||
for (usz i = 0; i < LEN; i++) { \
|
||||
u64 v = -(u64)bitp_get(xp,i); \
|
||||
rc ^= (v^xc) << (ri%64); \
|
||||
xc = v; \
|
||||
ri += WV; usz e = ri/64; \
|
||||
if (j < e) { \
|
||||
rp[j++] = rc; \
|
||||
while (j < e) rp[j++] = v; \
|
||||
rc = v; \
|
||||
} \
|
||||
} \
|
||||
if (ri%64) rp[j] = rc;
|
||||
|
||||
extern B rt_slash;
|
||||
B slash_c1(B t, B x) {
|
||||
if (RARE(isAtm(x)) || RARE(RNK(x)!=1)) thrF("/: Argument must have rank 1 (%H ≡ ≢𝕩)", x);
|
||||
@ -465,17 +532,8 @@ B slash_c1(B t, B x) {
|
||||
for (u64 j = 0; j < c; j++) *rp++ = i;
|
||||
}
|
||||
} else {
|
||||
if (s/16 <= xia) { // Sparse case: type of x matters
|
||||
#define SPARSE_IND(T) \
|
||||
T* xp = T##any_ptr(x); \
|
||||
usz b = 1<<10; \
|
||||
for (usz k=0, j=0, js=0, ij=xp[0]; ; ) { \
|
||||
usz e = b<s-k? k+b : s; \
|
||||
for (usz i=k; i<e; i++) rp[i]=0; \
|
||||
while (ij<e) { rp[ij]++; ij+=xp[++j]; } \
|
||||
for (usz i=k; i<e; i++) js=rp[i]+=js; \
|
||||
if (e==s) {break;} k=e; \
|
||||
}
|
||||
if (s/32 <= xia) { // Sparse case: type of x matters
|
||||
#define SPARSE_IND(T) T* xp = T##any_ptr(x); IND_BY_SCAN
|
||||
i32* rp; r = m_i32arrv(&rp, s);
|
||||
if (xe == el_i8 ) { SPARSE_IND(i8 ); }
|
||||
else if (xe == el_i16) { SPARSE_IND(i16); }
|
||||
@ -502,18 +560,30 @@ B slash_c1(B t, B x) {
|
||||
}
|
||||
|
||||
B slash_c2(B t, B w, B x) {
|
||||
B r;
|
||||
if (isArr(w) && RNK(w)==1 && depth(w)==1) {
|
||||
usz wia = IA(w);
|
||||
i32 wv = -1;
|
||||
usz wia;
|
||||
if (isArr(w)) {
|
||||
if (depth(w)>1) goto base;
|
||||
ur wr = RNK(w);
|
||||
if (wr>1) thrF("/: Simple 𝕨 must have rank 0 or 1 (%i≡=𝕨)", wr);
|
||||
if (wr<1) { B v=IGet(w, 0); decG(w); w=v; goto atom; }
|
||||
wia = IA(w);
|
||||
if (wia==0) { decG(w); return isArr(x)? x : m_atomUnit(x); }
|
||||
if (isAtm(x) || RNK(x)==0) thrM("/: 𝕩 must have rank at least 1 for simple 𝕨");
|
||||
ur xr = RNK(x);
|
||||
usz xlen = *SH(x);
|
||||
} else {
|
||||
atom:
|
||||
if (!q_i32(w)) goto base;
|
||||
wv = o2i(w);
|
||||
}
|
||||
if (isAtm(x) || RNK(x)==0) thrM("/: 𝕩 must have rank at least 1 for simple 𝕨");
|
||||
ur xr = RNK(x);
|
||||
usz xlen = *SH(x);
|
||||
u8 xl = cellWidthLog(x);
|
||||
u8 xt = arrNewType(TY(x));
|
||||
|
||||
B r;
|
||||
if (wv < 0) { // Array w
|
||||
if (RARE(wia!=xlen)) thrF("/: Lengths of components of 𝕨 must match 𝕩 (%s ≠ %s)", wia, xlen);
|
||||
|
||||
u8 xl = cellWidthLog(x);
|
||||
u8 xt = arrNewType(TY(x));
|
||||
|
||||
u8 we = TI(w,elType);
|
||||
if (!elInt(we)) {
|
||||
w=any_squeeze(w); we=TI(w,elType);
|
||||
@ -555,30 +625,16 @@ B slash_c2(B t, B w, B x) {
|
||||
// Make shape if needed; all cases below use it
|
||||
usz* rsh = NULL;
|
||||
if (xr > 1) {
|
||||
usz* sh = rsh = m_shArr(xr)->a;
|
||||
sh[0] = s;
|
||||
shcpy(sh+1, SH(x)+1, xr-1);
|
||||
rsh = m_shArr(xr)->a;
|
||||
rsh[0] = s;
|
||||
shcpy(rsh+1, SH(x)+1, xr-1);
|
||||
}
|
||||
|
||||
if (xl == 0) {
|
||||
u64* xp = bitarr_ptr(x);
|
||||
u64* rp; r = m_bitarrv(&rp, s); if (rsh) { SPRNK(a(r),xr); SH(r) = rsh; }
|
||||
if (s/256 <= wia) {
|
||||
#define SPARSE_REP(T) \
|
||||
T* wp = T##any_ptr(w); \
|
||||
usz b = 1<<12; \
|
||||
u64 xx=xp[0], xs=xx>>63, js=-(xx&1); xx^=xx<<1; \
|
||||
for (usz k=0, j=0, ij=wp[0]; ; ) { \
|
||||
usz e = b<s-k? k+b : s; \
|
||||
usz eb = (e-1)/64+1; \
|
||||
for (usz i=k/64; i<eb; i++) rp[i]=0; \
|
||||
while (ij<e) { \
|
||||
xx>>=1; j++; if (j%64==0) { u64 v=xp[j/64]; xx=v^(v<<1)^xs; xs=v>>63; } \
|
||||
rp[ij/64]^=(-(xx&1))<<(ij%64); ij+=wp[j]; \
|
||||
} \
|
||||
for (usz i=k/64; i<eb; i++) js=-((rp[i]^=js)>>63); \
|
||||
if (e==s) {break;} k=e; \
|
||||
}
|
||||
if (s/1024 <= wia) {
|
||||
#define SPARSE_REP(T) T* wp=T##any_ptr(w); BOOL_REP_XOR_SCAN(wp[j])
|
||||
if (we==el_i8 ) { SPARSE_REP(i8 ); }
|
||||
else if (we==el_i16) { SPARSE_REP(i16); }
|
||||
else { SPARSE_REP(i32); }
|
||||
@ -586,37 +642,15 @@ B slash_c2(B t, B w, B x) {
|
||||
} else {
|
||||
if (we < el_i32) w = taga(cpyI32Arr(w));
|
||||
i32* wp = i32any_ptr(w);
|
||||
u64 ri=0, rc=0, xc=0; usz j=0;
|
||||
for (usz i = 0; i < wia; i++) {
|
||||
u64 v = -(u64)bitp_get(xp,i);
|
||||
rc ^= (v^xc) << (ri%64);
|
||||
xc = v;
|
||||
ri += wp[i]; usz e = ri/64;
|
||||
if (j < e) {
|
||||
rp[j++] = rc;
|
||||
while (j < e) rp[j++] = v;
|
||||
rc = v;
|
||||
}
|
||||
}
|
||||
if (ri%64) rp[j] = rc;
|
||||
BOOL_REP_OVER(wp[i], wia)
|
||||
}
|
||||
} else {
|
||||
u8 xk = xl-3;
|
||||
void* rv = m_tyarrv(&r, 1<<xk, s, xt);
|
||||
if (rsh) { Arr* ra=a(r); SPRNK(ra,xr); PSH(ra) = rsh; PIA(ra) = s*arr_csz(x); }
|
||||
void* xv = tyany_ptr(x);
|
||||
if (s/32 <= wia) { // Sparse case: use both types
|
||||
#define CASE(L,XT) case L: { \
|
||||
XT* xp = xv; XT* rp = rv; \
|
||||
usz b = 1<<10; \
|
||||
XT js=xp[0], px=js; \
|
||||
for (usz k=0, j=0, ij=wp[0]; ; ) { \
|
||||
usz e = b<s-k? k+b : s; \
|
||||
for (usz i=k; i<e; i++) rp[i]=0; \
|
||||
while (ij<e) { j++; XT sx=px; rp[ij]^=sx^(px=xp[j]); ij+=wp[j]; } \
|
||||
for (usz i=k; i<e; i++) js=rp[i]^=js; \
|
||||
if (e==s) {break;} k=e; \
|
||||
} break; }
|
||||
if ((xk<3? s/64 : s/32) <= wia) { // Sparse case: use both types
|
||||
#define CASE(L,XT) case L: { REP_BY_SCAN(XT, wp[j]) break; }
|
||||
#define SPARSE_REP(WT) \
|
||||
WT* wp = WT##any_ptr(w); \
|
||||
switch (xk) { default: UD; CASE(0,u8) CASE(1,u16) CASE(2,u32) CASE(3,u64) }
|
||||
@ -640,33 +674,47 @@ B slash_c2(B t, B w, B x) {
|
||||
}
|
||||
}
|
||||
goto decWX_ret;
|
||||
}
|
||||
if (isArr(x) && RNK(x)==1 && q_i32(w)) {
|
||||
usz xia = IA(x);
|
||||
i32 wv = o2i(w);
|
||||
if (wv<=0) {
|
||||
if (wv<0) thrM("/: 𝕨 cannot be negative");
|
||||
return taga(arr_shVec(TI(x,slice)(x, 0, 0)));
|
||||
} else {
|
||||
if (wv <= 1) {
|
||||
if (wv < 0) thrM("/: 𝕨 cannot be negative");
|
||||
return wv ? x : taga(arr_shVec(TI(x,slice)(x, 0, 0)));
|
||||
}
|
||||
if (TI(x,elType)==el_i32) {
|
||||
i32* xp = i32any_ptr(x);
|
||||
i32* rp; r = m_i32arrv(&rp, xia*wv);
|
||||
for (usz i = 0; i < xia; i++) {
|
||||
for (i64 j = 0; j < wv; j++) *rp++ = xp[i];
|
||||
}
|
||||
goto decX_ret;
|
||||
} else {
|
||||
if (xlen == 0) return x;
|
||||
usz s = xlen * wv;
|
||||
if (xl>6 || (xl<3 && xl!=0) || TI(x,elType)==el_B) {
|
||||
if (xr != 1) goto base;
|
||||
SLOW2("𝕨/𝕩", w, x);
|
||||
B xf = getFillQ(x);
|
||||
HArr_p r0 = m_harrUv(xia*wv);
|
||||
HArr_p r0 = m_harrUv(s);
|
||||
SGetU(x)
|
||||
for (usz i = 0; i < xia; i++) {
|
||||
for (usz i = 0; i < xlen; i++) {
|
||||
B cx = incBy(GetU(x, i), wv);
|
||||
for (i64 j = 0; j < wv; j++) *r0.a++ = cx;
|
||||
}
|
||||
r = withFill(r0.b, xf);
|
||||
goto decX_ret;
|
||||
}
|
||||
if (xl == 0) {
|
||||
u64* xp = bitarr_ptr(x);
|
||||
u64* rp; r = m_bitarrv(&rp, s);
|
||||
if (wv <= 256) { BOOL_REP_XOR_SCAN(wv) }
|
||||
else { BOOL_REP_OVER(wv, xlen) }
|
||||
goto decX_ret;
|
||||
} else {
|
||||
u8 xk = xl-3;
|
||||
void* rv = m_tyarrv(&r, 1<<xk, s, xt);
|
||||
void* xv = tyany_ptr(x);
|
||||
#define CASE(L,T) case L: { REP_BY_SCAN(T, wv) break; }
|
||||
switch (xk) { default: UD; CASE(0,u8) CASE(1,u16) CASE(2,u32) CASE(3,u64) }
|
||||
#undef CASE
|
||||
}
|
||||
if (xr > 1) {
|
||||
usz* rsh = m_shArr(xr)->a;
|
||||
rsh[0] = s;
|
||||
shcpy(rsh+1, SH(x)+1, xr-1);
|
||||
Arr* ra=a(r); SPRNK(ra,xr); PSH(ra)=rsh; PIA(ra)=s*arr_csz(x);
|
||||
}
|
||||
goto decX_ret;
|
||||
}
|
||||
base:
|
||||
return c2(rt_slash, w, x);
|
||||
|
||||
@ -10,20 +10,42 @@ def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
|
||||
def base{b,l} = { if (0==tuplen{l}) 0; else tupsel{0,l}+b*base{b,slice{l,1}} }
|
||||
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,n}}
|
||||
|
||||
# Fill last 4 bytes with last element, in each lane
|
||||
def spread{a:VT} = {
|
||||
def w = width{eltype{VT}}
|
||||
def b = w/8
|
||||
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
|
||||
}
|
||||
|
||||
def scan_loop{T, init, x:*T, r:*T, len:u64, scan, scan_last} = {
|
||||
def step = 256/width{T}
|
||||
def V = [step]T
|
||||
p:= broadcast{V, init}
|
||||
xv:= *V ~~ x
|
||||
rv:= *V ~~ r
|
||||
e:= len/step
|
||||
@for (xv, rv over e) rv = scan{xv,p}
|
||||
q:= len & (step-1)
|
||||
if (q) maskstoreF{rv, maskOf{V, q}, e, scan_last{load{xv,e}, p}}
|
||||
}
|
||||
def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
|
||||
def last{v, p} = op{pre{v}, p}
|
||||
def scan{v, p} = {
|
||||
n:= last{v, p}
|
||||
p = sel{[8]i32, spread{n}, broadcast{[8]i32, 7}}
|
||||
n
|
||||
}
|
||||
scan_loop{T, init, x, r, len, scan, last}
|
||||
}
|
||||
|
||||
# Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈
|
||||
avx2_scan_idem{T, op, id}(x:*T, r:*T, len:u64) : void = {
|
||||
def w = width{T}
|
||||
|
||||
# Within each lane, scan using shifts by powers of 2. First k elements
|
||||
# when shifting by k don't need to change, so leave them alone.
|
||||
def w = width{T}
|
||||
def shift{k,l} = merge{iota{k},iota{l-k}}
|
||||
def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}}
|
||||
def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a)
|
||||
# Fill last 4 bytes with last element, in each lane
|
||||
def spread{a} = {
|
||||
def b = w/8
|
||||
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
|
||||
}
|
||||
# Prefix op on entire AVX register
|
||||
def pre{a} = {
|
||||
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
|
||||
@ -31,19 +53,7 @@ avx2_scan_idem{T, op, id}(x:*T, r:*T, len:u64) : void = {
|
||||
op{b, sel{[8]i32, spread{b}, make{[8]i32, 3*(3<iota{8})}}}
|
||||
}
|
||||
|
||||
def step = 256/w
|
||||
def V = [step]T
|
||||
p:= broadcast{V, id}
|
||||
xv:= *V ~~ x
|
||||
rv:= *V ~~ r
|
||||
e:= len/step
|
||||
@for (xv, rv over e) {
|
||||
n:= op{pre{xv}, p}
|
||||
p = sel{[8]i32, spread{n}, broadcast{[8]i32, 7}}
|
||||
rv = n
|
||||
}
|
||||
q:= len & (step-1)
|
||||
if (q) maskstoreF{rv, maskOf{V, q}, e, op{pre{load{xv,e}}, p}}
|
||||
scan_post{T, id, x, r, len, op, pre}
|
||||
}
|
||||
def avx2_scan_idem{T, op} = {
|
||||
def m = 1 << (width{T}-1)
|
||||
@ -56,6 +66,25 @@ def avx2_scan_idem{T, op} = {
|
||||
'avx2_scan_min32' = avx2_scan_idem{i32, min}
|
||||
'avx2_scan_max32' = avx2_scan_idem{i32, max}
|
||||
|
||||
# Associative scan
|
||||
avx2_scan_assoc_0{T, op}(x:*T, r:*T, len:u64, init:T) : void = {
|
||||
# Prefix op on entire AVX register
|
||||
def pre{a} = {
|
||||
# Within each lane, scan using shifts by powers of 2.
|
||||
# Assumes identity is 0.
|
||||
def w = width{T}
|
||||
def c32{k, a} = (if (w<=8*k) op{a, shl{[16]u8, a, k}}; else a)
|
||||
b:= c32{8, c32{4, c32{2, c32{1, a}}}}
|
||||
# After lanewise scan, broadcast end of lane 0 to entire lane 1
|
||||
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
|
||||
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
|
||||
}
|
||||
scan_post{T, init, x, r, len, op, pre}
|
||||
}
|
||||
'avx2_scan_pluswrap_u8' = avx2_scan_assoc_0{u8 , +}
|
||||
'avx2_scan_pluswrap_u16' = avx2_scan_assoc_0{u16, +}
|
||||
'avx2_scan_pluswrap_u32' = avx2_scan_assoc_0{u32, +}
|
||||
|
||||
# Boolean cumulative sum
|
||||
avx2_bcs32(x:*u64, r:*i32, l:u64) : void = {
|
||||
rv:= *[8]u32~~r
|
||||
|
||||
Loading…
Reference in New Issue
Block a user