Apply fast Replicate code whenever the cell size is right

This commit is contained in:
Marshall Lochbaum 2022-09-17 20:36:43 -04:00
parent ea1367e639
commit cae65947cd

View File

@ -261,13 +261,19 @@ static B where(B x, usz xia, u64 s) {
return r;
}
static B compress(B w, B x, usz wia) {
extern B take_c2(B, B, B);
static B compress(B w, B x, usz wia, u8 xl, u8 xt) {
u64* wp = bitarr_ptr(w);
u64 we = 0;
usz ie = wia/64;
usz q=wia%64; if (q) we = wp[ie] &= ((u64)1<<q) - 1;
while (!we) {
if (RARE(ie==0)) { B xf = getFillQ(x); return q_N(xf)? emptyHVec() : isF64(xf)? emptyIVec() : isC32(xf)? emptyCVec() : m_emptyFVec(xf); }
if (RARE(ie==0)) {
if (RNK(x)>1) return take_c2(m_f64(0), m_f64(0), inc(x));
u8 xe = TI(x,elType);
if (xe != el_B) return elNum(xe)? emptyIVec() : emptyCVec();
B xf = getFillQ(x); return q_N(xf)? emptyHVec() : m_emptyFVec(xf);
}
we = wp[--ie];
}
usz wia0 = wia;
@ -276,8 +282,6 @@ static B compress(B w, B x, usz wia) {
if (wsum == wia0) return inc(x);
B r;
u8 xl = arrTypeBitsLog(TY(x));
u8 xt = arrNewType(TY(x));
switch(xl) { default: UD;
case 0: {
u64* xp = bitarr_ptr(x); u64* rp;
@ -330,7 +334,7 @@ static B compress(B w, B x, usz wia) {
#endif
#undef WITH_SPARSE
case 5: { i32* xp= tyany_ptr(x); i32* rp=m_tyarrv(&r,4,wsum,xt); COMPRESS_BLOCK(i32); break; }
case 6: if (TI(x,elType)!=el_B) { f64* xp=f64any_ptr(x); f64* rp; r = m_f64arrv(&rp,wsum); COMPRESS_BLOCK(f64); break; }
case 6: if (TI(x,elType)!=el_B) { u64* xp=tyany_ptr(x); u64* rp=m_tyarrv(&r,8,wsum,xt); COMPRESS_BLOCK(u64); break; }
else {
B xf = getFillQ(x);
B* xp = arr_bptr(x);
@ -349,6 +353,13 @@ static B compress(B w, B x, usz wia) {
}
#undef COMPRESS_BLOCK
}
ur xr = RNK(x);
if (xr > 1) {
Arr* ra=a(r); SPRNK(ra,xr);
usz* sh = PSH(ra) = m_shArr(xr)->a;
sh[0] = PIA(ra); PIA(ra) *= arr_csz(x);
shcpy(sh+1, SH(x)+1, xr-1);
}
return r;
}
@ -417,19 +428,23 @@ B slash_c1(B t, B x) {
}
B slash_c2(B t, B w, B x) {
if (isArr(x) && RNK(x)==1 && isArr(w) && RNK(w)==1 && depth(w)==1) {
if (isArr(w) && RNK(w)==1 && depth(w)==1) {
usz wia = IA(w);
usz xia = IA(x);
if (RARE(wia!=xia)) {
if (wia==0) { decG(w); return x; }
thrF("/: Lengths of components of 𝕨 must match 𝕩 (%s ≠ %s)", wia, xia);
}
if (wia==0) { decG(w); return isArr(x)? x : m_atomUnit(x); }
if (isAtm(x) || RNK(x)==0) thrM("/: 𝕩 must have rank at least 1 for simple 𝕨");
ur xr = RNK(x);
usz xlen = *SH(x);
if (RARE(wia!=xlen)) thrF("/: Lengths of components of 𝕨 must match 𝕩 (%s ≠ %s)", wia, xlen);
u8 xl = cellWidthLog(x);
u8 xt = arrNewType(TY(x));
if (xl > 6 || (xl < 3 && xl != 0)) goto base;
u8 we = TI(w,elType);
if (we > el_i32) { w = any_squeeze(w); we = TI(w,elType); }
if (we==el_bit) {
wbool:
B r = compress(w, x, wia);
B r = compress(w, x, wia, xl, xt);
decG(w); decG(x); return r;
}
u64 s = usum(w);
@ -437,9 +452,9 @@ B slash_c2(B t, B w, B x) {
w = num_squeeze(w); we = TI(w,elType);
if (we==el_bit) goto wbool;
}
B r;
u8 xe = TI(x,elType);
if (RARE(we>el_i32 || xe==el_B)) { // Slow case
if (RARE(we>el_i32 || TI(x,elType)==el_B)) { // Slow case
if (xr > 1) goto base;
SLOW2("𝕨/𝕩", w, x);
B xf = getFillQ(x);
u64 ria = usum(w);
@ -454,9 +469,20 @@ B slash_c2(B t, B w, B x) {
}
decG(w); decG(x);
return withFill(HARR_FV(r), xf);
} else if (xe == el_bit) {
}
B r;
// Make shape if needed; all cases below use it
usz* rsh = NULL;
if (xr > 1) {
usz* sh = rsh = m_shArr(xr)->a;
sh[0] = s;
shcpy(sh+1, SH(x)+1, xr-1);
}
if (xl == 0) {
u64* xp = bitarr_ptr(x);
u64* rp; r = m_bitarrv(&rp, s);
u64* rp; r = m_bitarrv(&rp, s); if (rsh) { SPRNK(a(r),xr); SH(r) = rsh; }
if (s/256 <= wia) {
#define SPARSE_REP(T) \
T* wp = T##any_ptr(w); \
@ -495,9 +521,9 @@ B slash_c2(B t, B w, B x) {
if (ri%64) rp[j] = rc;
}
} else {
u8 xt = TY(x);
u8 xl = arrTypeBitsLog(xt)-3;
void* rv = m_tyarrv(&r, 1<<xl, s, arrNewType(xt));
u8 xk = xl-3;
void* rv = m_tyarrv(&r, 1<<xk, s, xt);
if (rsh) { Arr* ra=a(r); SPRNK(ra,xr); PSH(ra) = rsh; PIA(ra) = s*arr_csz(x); }
void* xv = tyany_ptr(x);
if (s/32 <= wia) { // Sparse case: use both types
#define CASE(L,XT) case L: { \
@ -513,7 +539,7 @@ B slash_c2(B t, B w, B x) {
} break; }
#define SPARSE_REP(WT) \
WT* wp = WT##any_ptr(w); \
switch (xl) { default: UD; CASE(0,u8) CASE(1,u16) CASE(2,u32) CASE(3,u64) }
switch (xk) { default: UD; CASE(0,u8) CASE(1,u16) CASE(2,u32) CASE(3,u64) }
if (we == el_i8 ) { SPARSE_REP(i8 ); }
else if (we == el_i16) { SPARSE_REP(i16); }
else { SPARSE_REP(i32); }
@ -529,7 +555,7 @@ B slash_c2(B t, B w, B x) {
if (we < el_i32) w = taga(cpyI32Arr(w));
i32* wp = i32any_ptr(w);
while (wia>0 && !wp[wia-1]) wia--;
switch (xl) { default: UD; CASE(0,u8) CASE(1,u16) CASE(2,u32) CASE(3,u64) }
switch (xk) { default: UD; CASE(0,u8) CASE(1,u16) CASE(2,u32) CASE(3,u64) }
#undef CASE
}
}
@ -563,6 +589,7 @@ B slash_c2(B t, B w, B x) {
return withFill(r.b, xf);
}
}
base:
return c2(rt_slash, w, x);
}