Combine transp_c2 temp allocations into one TALLOC
This commit is contained in:
parent
5b17ee44d6
commit
7fea4ca2ad
@ -166,7 +166,7 @@ B transp_c2(B t, B w, B x) {
|
|||||||
if (isAtm(x) || (xr=RNK(x))<wia) thrM("⍉: Length of 𝕨 must be at most rank of 𝕩");
|
if (isAtm(x) || (xr=RNK(x))<wia) thrM("⍉: Length of 𝕨 must be at most rank of 𝕩");
|
||||||
|
|
||||||
// Axis permutation
|
// Axis permutation
|
||||||
TALLOC(ur, p, xr);
|
TALLOC(ur, p, xr*(1+3*sizeof(usz))); // Also rsh, st, ri
|
||||||
if (isAtm(w)) {
|
if (isAtm(w)) {
|
||||||
usz a=o2s(w);
|
usz a=o2s(w);
|
||||||
if (a>=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr);
|
if (a>=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr);
|
||||||
@ -184,7 +184,7 @@ B transp_c2(B t, B w, B x) {
|
|||||||
|
|
||||||
// compute shape for the given axes
|
// compute shape for the given axes
|
||||||
usz* xsh = SH(x);
|
usz* xsh = SH(x);
|
||||||
TALLOC(usz, rsh, xr);
|
usz *rsh = (usz*)(p + xr); // Length xr
|
||||||
usz dup = 0, max = 0;
|
usz dup = 0, max = 0;
|
||||||
usz no_sh = -(usz)1;
|
usz no_sh = -(usz)1;
|
||||||
for (usz j=0; j<xr; j++) rsh[j] = no_sh;
|
for (usz j=0; j<xr; j++) rsh[j] = no_sh;
|
||||||
@ -221,23 +221,22 @@ B transp_c2(B t, B w, B x) {
|
|||||||
if (ar <= 1) { r = x; goto ret; }
|
if (ar <= 1) { r = x; goto ret; }
|
||||||
ur na = ar - dup;
|
ur na = ar - dup;
|
||||||
// Add up stride for each axis
|
// Add up stride for each axis
|
||||||
TALLOC(u64, st, rr);
|
usz* st = rsh + xr; // Length rr
|
||||||
for (usz j=0; j<rr; j++) st[j] = 0;
|
for (usz j=0; j<rr; j++) st[j] = 0;
|
||||||
usz c = 1;
|
usz c = 1;
|
||||||
for (usz i=ar; i--; ) { st[p[i]]+=c; c*=xsh[i]; }
|
for (usz i=ar; i--; ) { st[p[i]]+=c; c*=xsh[i]; }
|
||||||
|
|
||||||
u8 xe = TI(x,elType);
|
u8 xe = TI(x,elType);
|
||||||
usz csz = shProd(xsh, ar, xr);
|
usz csz = shProd(xsh, ar, xr);
|
||||||
#define AXIS_LOOP(INIT, N_AX, I_INC, DO_INNER) \
|
#define AXIS_LOOP(N_AX, I_INC, DO_INNER) \
|
||||||
ur a0 = N_AX - 1; \
|
ur a0 = N_AX - 1; \
|
||||||
for (usz i=0; i<na; i++) st[i] *= csz; \
|
for (usz i=0; i<na; i++) st[i] *= csz; \
|
||||||
TALLOC(usz, ri, a0); for (usz i=0; i<a0; i++) ri[i]=0; \
|
usz* ri = st+rr; for (usz i=0; i<a0; i++) ri[i]=0; \
|
||||||
INIT \
|
usz l = rsh[a0]; \
|
||||||
for (usz i=0, j0=0;;) { \
|
for (usz i=0, j0=0;;) { \
|
||||||
/* Hardcode one innermost loop: assume N_AX>=1 */ \
|
/* Hardcode one innermost loop: assume N_AX>=1 */ \
|
||||||
ur a = a0; \
|
ur a = a0; \
|
||||||
usz str = st[a]; \
|
usz str = st[a]; \
|
||||||
usz l = rsh[a]; \
|
|
||||||
for (usz k=0; k<l; k++) { \
|
for (usz k=0; k<l; k++) { \
|
||||||
usz j=j0+k*str; DO_INNER; \
|
usz j=j0+k*str; DO_INNER; \
|
||||||
i += I_INC; \
|
i += I_INC; \
|
||||||
@ -251,17 +250,16 @@ B transp_c2(B t, B w, B x) {
|
|||||||
ri[a] = 0; \
|
ri[a] = 0; \
|
||||||
j0 -= rsh[a] * str; \
|
j0 -= rsh[a] * str; \
|
||||||
} \
|
} \
|
||||||
} \
|
}
|
||||||
TFREE(ri);
|
|
||||||
u8 xlw = elWidthLogBits(xe);
|
u8 xlw = elWidthLogBits(xe);
|
||||||
if (csz >= (32*8) >> xlw) { // cell >= 32 bytes
|
if (csz >= (32*8) >> xlw) { // cell >= 32 bytes
|
||||||
usz ria = csz * shProd(rsh, 0, na);
|
usz ria = csz * shProd(rsh, 0, na);
|
||||||
AXIS_LOOP(MAKE_MUT_INIT(rm, ria, xe); MUTG_INIT(rm); ,
|
MAKE_MUT_INIT(rm, ria, xe); MUTG_INIT(rm);
|
||||||
na, csz, mut_copyG(rm, i, x, j, csz));
|
AXIS_LOOP(na, csz, mut_copyG(rm, i, x, j, csz));
|
||||||
Arr* ra = mut_fp(rm);
|
Arr* ra = mut_fp(rm);
|
||||||
shSet_alloc(ra, rr, rsh);
|
shSet_alloc(ra, rr, rsh);
|
||||||
r = withFill(taga(ra), getFillQ(x));
|
r = withFill(taga(ra), getFillQ(x));
|
||||||
decG(x); goto ret_decst;
|
decG(x); goto ret;
|
||||||
}
|
}
|
||||||
if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<<xlw>=8 // CPU-sized cells
|
if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<<xlw>=8 // CPU-sized cells
|
||||||
&& !dup && na>=2 && p[na-1]==ar-2 && p[na-2]==ar-1 // Last two axes transposed
|
&& !dup && na>=2 && p[na-1]==ar-2 && p[na-2]==ar-1 // Last two axes transposed
|
||||||
@ -280,11 +278,11 @@ B transp_c2(B t, B w, B x) {
|
|||||||
csz = (csz<<xlw) / 8; // Convert to bytes
|
csz = (csz<<xlw) / 8; // Convert to bytes
|
||||||
usz ria = rf*csz;
|
usz ria = rf*csz;
|
||||||
usz hw = h*w*csz;
|
usz hw = h*w*csz;
|
||||||
AXIS_LOOP(, na-2, hw, tran(rp+i, xp+j, w, h));
|
AXIS_LOOP(na-2, hw, tran(rp+i, xp+j, w, h));
|
||||||
}
|
}
|
||||||
shSet_alloc(ra, rr, rsh);
|
shSet_alloc(ra, rr, rsh);
|
||||||
r = taga(ra);
|
r = taga(ra);
|
||||||
decG(x); goto ret_decst;
|
decG(x); goto ret;
|
||||||
}
|
}
|
||||||
#undef AXIS_LOOP
|
#undef AXIS_LOOP
|
||||||
|
|
||||||
@ -311,10 +309,7 @@ B transp_c2(B t, B w, B x) {
|
|||||||
}
|
}
|
||||||
r = C2(select, ind, x);
|
r = C2(select, ind, x);
|
||||||
|
|
||||||
ret_decst:;
|
|
||||||
TFREE(st);
|
|
||||||
ret:;
|
ret:;
|
||||||
TFREE(rsh);
|
|
||||||
TFREE(p);
|
TFREE(p);
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user