Save a shape object in Reorder Axes

This commit is contained in:
Marshall Lochbaum 2023-03-28 20:17:04 -04:00
parent 7fea4ca2ad
commit de18fb996b

View File

@ -145,14 +145,9 @@ B ud_c1(B,B);
B tbl_c2(Md1D*,B,B); B tbl_c2(Md1D*,B,B);
B select_c2(B,B,B); B select_c2(B,B,B);
static void shSet_alloc(Arr* ra, ur rr, usz* rsh) { static void shSet(Arr* ra, ur rr, ShArr* sh) {
if (RARE(rr <= 1)) { if (RARE(rr <= 1)) arr_shVec(ra);
arr_shVec(ra); else arr_shSetU(ra, rr, sh);
} else {
ShArr* sh=m_shArr(rr);
shcpy(sh->a, rsh, rr);
arr_shSetU(ra, rr, sh);
}
} }
B transp_c2(B t, B w, B x) { B transp_c2(B t, B w, B x) {
@ -182,7 +177,7 @@ B transp_c2(B t, B w, B x) {
decG(w); decG(w);
} }
// compute shape for the given axes // Compute shape for the given axes
usz* xsh = SH(x); usz* xsh = SH(x);
usz *rsh = (usz*)(p + xr); // Length xr usz *rsh = (usz*)(p + xr); // Length xr
usz dup = 0, max = 0; usz dup = 0, max = 0;
@ -196,7 +191,7 @@ B transp_c2(B t, B w, B x) {
if (xl<l) rsh[j]=xl; if (xl<l) rsh[j]=xl;
} }
// fill in remaining axes and check for missing ones // Fill in remaining axes and check for missing ones
ur rr = xr-dup; ur rr = xr-dup;
if (max >= rr) thrF("⍉: Skipped result axis"); if (max >= rr) thrF("⍉: Skipped result axis");
if (wia<xr) for (usz j=0, i=wia; j<rr; j++) if (rsh[j]==no_sh) { if (wia<xr) for (usz j=0, i=wia; j<rr; j++) if (rsh[j]==no_sh) {
@ -204,13 +199,19 @@ B transp_c2(B t, B w, B x) {
rsh[j] = xsh[i]; rsh[j] = xsh[i];
i++; i++;
} }
// Create shape object, saving unprocessed result shape
ShArr* sh;
if (LIKELY(rr > 1)) { // Not all duplicates
sh = m_shArr(rr);
shcpy(sh->a, rsh, rr);
}
B r; B r;
// Empty result // Empty result
if (IA(x) == 0) { if (IA(x) == 0) {
Arr* ra = m_fillarrpEmpty(getFillQ(x)); Arr* ra = m_fillarrpEmpty(getFillQ(x));
shSet_alloc(ra, rr, rsh); shSet(ra, rr, sh);
decG(x); decG(x);
r = taga(ra); goto ret; r = taga(ra); goto ret;
} }
@ -218,11 +219,11 @@ B transp_c2(B t, B w, B x) {
// Number of axes that move // Number of axes that move
ur ar = max+1+dup; ur ar = max+1+dup;
if (!dup) while (ar>1 && p[ar-1]==ar-1) ar--; // Unmoved trailing if (!dup) while (ar>1 && p[ar-1]==ar-1) ar--; // Unmoved trailing
if (ar <= 1) { r = x; goto ret; } if (ar <= 1) { if (rr>1) ptr_dec(sh); r = x; goto ret; }
ur na = ar - dup; ur na = ar - dup;
// Add up stride for each axis // Add up stride for each axis
usz* st = rsh + xr; // Length rr usz* st = rsh + xr; // Length ar
for (usz j=0; j<rr; j++) st[j] = 0; for (usz j=0; j<ar; j++) st[j] = 0;
usz c = 1; usz c = 1;
for (usz i=ar; i--; ) { st[p[i]]+=c; c*=xsh[i]; } for (usz i=ar; i--; ) { st[p[i]]+=c; c*=xsh[i]; }
@ -231,7 +232,7 @@ B transp_c2(B t, B w, B x) {
#define AXIS_LOOP(N_AX, I_INC, DO_INNER) \ #define AXIS_LOOP(N_AX, I_INC, DO_INNER) \
ur a0 = N_AX - 1; \ ur a0 = N_AX - 1; \
for (usz i=0; i<na; i++) st[i] *= csz; \ for (usz i=0; i<na; i++) st[i] *= csz; \
usz* ri = st+rr; for (usz i=0; i<a0; i++) ri[i]=0; \ usz* ri = st+ar; for (usz i=0; i<a0; i++) ri[i]=0; \
usz l = rsh[a0]; \ usz l = rsh[a0]; \
for (usz i=0, j0=0;;) { \ for (usz i=0, j0=0;;) { \
/* Hardcode one innermost loop: assume N_AX>=1 */ \ /* Hardcode one innermost loop: assume N_AX>=1 */ \
@ -257,12 +258,12 @@ B transp_c2(B t, B w, B x) {
MAKE_MUT_INIT(rm, ria, xe); MUTG_INIT(rm); MAKE_MUT_INIT(rm, ria, xe); MUTG_INIT(rm);
AXIS_LOOP(na, csz, mut_copyG(rm, i, x, j, csz)); AXIS_LOOP(na, csz, mut_copyG(rm, i, x, j, csz));
Arr* ra = mut_fp(rm); Arr* ra = mut_fp(rm);
shSet_alloc(ra, rr, rsh); shSet(ra, rr, sh);
r = withFill(taga(ra), getFillQ(x)); r = withFill(taga(ra), getFillQ(x));
decG(x); goto ret; decG(x); goto ret;
} }
if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<<xlw>=8 // CPU-sized cells if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<<xlw>=8 // CPU-sized cells
&& !dup && na>=2 && p[na-1]==ar-2 && p[na-2]==ar-1 // Last two axes transposed && na>=2 && st[na-1]==rsh[na-2] && st[na-2]==1 // Last two axes transposed
&& rsh[na-2]*rsh[na-1] >= (256*8) >> xlw // And large-ish && rsh[na-2]*rsh[na-1] >= (256*8) >> xlw // And large-ish
&& xe!=el_B) { && xe!=el_B) {
TranspFn tran = transposeFns[CTZ(csz<<xlw)-3]; TranspFn tran = transposeFns[CTZ(csz<<xlw)-3];
@ -280,26 +281,19 @@ B transp_c2(B t, B w, B x) {
usz hw = h*w*csz; usz hw = h*w*csz;
AXIS_LOOP(na-2, hw, tran(rp+i, xp+j, w, h)); AXIS_LOOP(na-2, hw, tran(rp+i, xp+j, w, h));
} }
shSet_alloc(ra, rr, rsh); shSet(ra, rr, sh);
r = taga(ra); r = taga(ra);
decG(x); goto ret; decG(x); goto ret;
} }
#undef AXIS_LOOP #undef AXIS_LOOP
// Reshape x for selection, collapsing ar axes // Reshape x for selection
if (ar != 1) { ShArr* zsh = m_shArr(2);
ur zr = xr-ar+1; zsh->a[0] = c;
ShArr* zsh; zsh->a[1] = csz;
if (zr>1) { Arr* z = TI(x,slice)(x, 0, IA(x));
zsh = m_shArr(zr); arr_shSetU(z, 2, zsh);
zsh->a[0] = c; x = taga(z);
shcpy(zsh->a+1, xsh+ar, xr-ar);
}
Arr* z = TI(x,slice)(x, 0, IA(x));
if (zr>1) arr_shSetU(z, zr, zsh);
else arr_shVec(z);
x = taga(z);
}
// (+⌜´st×⟜↕¨rsh)⊏⥊𝕩 // (+⌜´st×⟜↕¨rsh)⊏⥊𝕩
B ind = bi_N; B ind = bi_N;
for (ur k=na; k--; ) { for (ur k=na; k--; ) {
@ -308,6 +302,8 @@ B transp_c2(B t, B w, B x) {
else ind = M1C2(tbl, add, v, ind); else ind = M1C2(tbl, add, v, ind);
} }
r = C2(select, ind, x); r = C2(select, ind, x);
if (rr>1) arr_shReplace(a(r), rr, sh);
else { decSh(v(r)); arr_shVec(a(r)); }
ret:; ret:;
TFREE(p); TFREE(p);