Fast array reshape

This commit is contained in:
Marshall Lochbaum 2022-09-23 16:57:41 -04:00
parent 218a8b069e
commit e8e10790f8

View File

@ -188,6 +188,7 @@ B shape_c2(B t, B w, B x) {
decG(w);
}
Arr* r;
if (isArr(x)) {
if (nia <= xia) {
return truncReshape(x, xia, nia, nr, sh);
@ -201,21 +202,77 @@ B shape_c2(B t, B w, B x) {
}
if (xia <= nia/2) x = any_squeeze(x);
B xf = getFillQ(x);
MAKE_MUT(m, nia); mut_init(m, TI(x,elType));
MUTG_INIT(m);
i64 div = nia/xia;
i64 mod = nia%xia;
for (i64 i = 0; i < div; i++) mut_copyG(m, i*xia, x, 0, xia);
mut_copyG(m, div*xia, x, 0, mod);
u8 xl = arrTypeBitsLog(TY(x));
u8 xt = arrNewType(TY(x));
u8* rp;
u64 bi, bf; // Bytes present, bytes wanted
if (xl == 0) { // Bits
u64* rq; r = m_bitarrp(&rq, nia);
rp = (u8*)rq;
usz nw = BIT_N(nia);
u64* xp = bitarr_ptr(x);
u64 b = xia;
if (b % 8) {
if (b < 64) {
// Need to avoid calling bit_cpy with arguments <64 bits apart
u64 v = xp[0] & (~(u64)0 >> (64-b));
do { v |= v<<b; b*=2; } while (b%8 && b<64);
rq[0] = v;
if (b>64 && nia>64) rq[1] = v>>(64-b/2);
} else {
memcpy(rq, xp, (b+7)/8);
}
for (; b%8; b*=2) {
if (b>nw*32) {
if (b<nia) bit_cpy(rq, b, rq, 0, nia-b);
b = 64*nw; // Ensure bi>=bf since bf is rounded up
break;
}
bit_cpy(rq, b, rq, 0, b);
}
} else {
memcpy(rp, xp, b/8);
}
bi = b/8;
bf = 8*nw;
if (bi == 1) { memset(rp, rp[0], bf); bi=bf; }
} else {
if (TI(x,elType) == el_B) {
B xf = getFillQ(x);
MAKE_MUT(m, nia); mut_init(m, el_B);
MUTG_INIT(m);
i64 div = nia/xia;
i64 mod = nia%xia;
for (i64 i = 0; i < div; i++) mut_copyG(m, i*xia, x, 0, xia);
mut_copyG(m, div*xia, x, 0, mod);
decG(x);
Arr* ra = mut_fp(m);
arr_shSetU(ra, nr, sh);
return withFill(taga(ra), xf);
}
u8 xk = xl - 3;
rp = m_tyarrp(&r, 1<<xk, nia, xt);
bi = (u64)xia<<xk;
bf = (u64)nia<<xk;
memcpy(rp, tyany_ptr(x), bi);
}
decG(x);
Arr* ra = mut_fp(m);
arr_shSetU(ra, nr, sh);
return withFill(taga(ra), xf);
if (bi<=8 && !(bi & (bi-1))) {
// Divisor of 8: write words
usz b = bi*8;
u64 v = *(u64*)rp & (~(u64)0 >> (64-b));
while (b<64) { v |= v<<b; b*=2; }
fill_words(rp, v, bf);
} else {
// Double up to length l, then copy in blocks
u64 l = 1<<15; if (l>bf) l=bf;
for (; bi<=l/2; bi+=bi) memcpy(rp+bi, rp, bi);
u64 e=bi; for (; e+bi<=bf; e+=bi) memcpy(rp+e, rp, bi);
if (e<bf) memcpy(rp+e, rp, bf-e);
}
}
} else {
unit:
Arr* r;
#define FILL(E,T,V) T* rp; r = m_##E##arrp(&rp,nia); fill_words(rp, V, (u64)nia*sizeof(T));
if (isF64(x)) {
i32 n = (i32)x.f;
@ -245,9 +302,9 @@ B shape_c2(B t, B w, B x) {
fill_words(fillarr_ptr(r), x.u, (u64)nia*8);
fillarr_setFill(r, xf);
}
arr_shSetU(r,nr,sh);
return taga(r);
}
arr_shSetU(r,nr,sh);
return taga(r);
}
B pick_c1(B t, B x) {