Full simplification pass for Reorder Axes

This commit is contained in:
Marshall Lochbaum 2023-03-29 20:08:36 -04:00
parent d492cd0865
commit 597e25af4d

View File

@ -172,19 +172,23 @@ B transp_c2(B t, B w, B x) {
decG(w);
}
B r;
// Compute shape for the given axes
usz* xsh = SH(x);
usz *rsh = (usz*)(p + xr); // Length xr
usz dup = 0, max = 0;
usz dup = 0, max = 0, id = 0;
usz no_sh = -(usz)1;
for (usz j=0; j<xr; j++) rsh[j] = no_sh;
for (usz i=0; i<wia; i++) {
ur j=p[i];
usz xl=xsh[i], l=rsh[j];
dup += l!=no_sh;
id += i==j;
max = j>max? j : max;
if (xl<l) rsh[j]=xl;
}
if (id == wia) { r = x; goto ret; }
// Fill in remaining axes and check for missing ones
ur rr = xr-dup;
@ -201,8 +205,6 @@ B transp_c2(B t, B w, B x) {
shcpy(sh->a, rsh, rr);
}
B r;
// Empty result
if (IA(x) == 0) {
Arr* ra = m_fillarrpEmpty(getFillQ(x));
@ -211,23 +213,33 @@ B transp_c2(B t, B w, B x) {
r = taga(ra); goto ret;
}
// Number of axes that move
ur ar = max+1+dup;
if (!dup) while (ar>1 && p[ar-1]==ar-1) ar--; // Unmoved trailing
if (ar <= 1) { if (rr>1) ptr_dec(sh); r = x; goto ret; }
ur na = ar - dup;
// Add up stride for each axis
usz* st = rsh + xr; // Length ar
for (usz j=0; j<ar; j++) st[j] = 0;
usz c = 1;
for (usz i=ar; i--; ) { st[p[i]]+=c; c*=xsh[i]; }
ur na = max + 1; // Number of result axes that moved
usz* st = rsh + xr; // Length na
for (usz j=0; j<na; j++) st[j] = 0;
usz csz = shProd(xsh, na+dup, xr);
for (usz i=na+dup, c=csz; i--; ) { st[p[i]]+=c; c*=xsh[i]; }
// Simplify axis structure
// p is unused now; work only on csz, rsh, and st
usz *lp = &csz; usz sz = csz;
usz na0=na; usz* rsh0=rsh; usz* st0=st; rsh+=na0; st+=na0; na=0;
for (usz i=na0; i--; ) {
usz l = rsh0[i]; if (l==1) continue; // Ignore
usz s = st0[i]; if (s==sz) { *lp*=l; sz*=l; continue; } // Combine with lower
na++; *--rsh=l; *--st=s; lp=rsh; sz=l*s;
}
// Turned out trivial
if (na == 0) {
Arr* ra = TI(x,slice)(x, 0, csz);
shSet(ra, rr, sh);
r = taga(ra); goto ret;
}
u8 xe = TI(x,elType);
usz csz = shProd(xsh, ar, xr);
#define AXIS_LOOP(N_AX, I_INC, DO_INNER) \
ur a0 = N_AX - 1; \
for (usz i=0; i<na; i++) st[i] *= csz; \
usz* ri = st+ar; for (usz i=0; i<a0; i++) ri[i]=0; \
usz* ri = st+na; for (usz i=0; i<a0; i++) ri[i]=0; \
usz l = rsh[a0]; \
for (usz i=0, j0=0;;) { \
/* Hardcode one innermost loop: assume N_AX>=1 */ \
@ -280,8 +292,9 @@ B transp_c2(B t, B w, B x) {
usz i_skip = (w-1)*hs*csz;
usz end = rf*csz - i_skip;
ur a0 = na - 1;
for (usz i=0; i<na; i++) st[i] *= csz;
usz* ri = st+ar; for (usz i=0; i<a0; i++) ri[i]=0;
if (xlw<3) for (usz i=0; i<na; i++) st[i] >>= 3-xlw;
else if (xlw>3) for (usz i=0; i<na; i++) st[i] <<= xlw-3;
usz* ri = st+na; for (usz i=0; i<a0; i++) ri[i]=0;
for (usz i=0, j=0;;) {
tran(rp+i, xp+j, w, h, ws, hs);
i += h*csz;
@ -304,7 +317,7 @@ B transp_c2(B t, B w, B x) {
// Reshape x for selection
ShArr* zsh = m_shArr(2);
zsh->a[0] = c;
zsh->a[0] = IA(x)/csz;
zsh->a[1] = csz;
Arr* z = TI(x,slice)(x, 0, IA(x));
arr_shSetU(z, 2, zsh);
@ -312,7 +325,7 @@ B transp_c2(B t, B w, B x) {
// (+⌜´st×⟜↕¨rsh)⊏⥊𝕩
B ind = bi_N;
for (ur k=na; k--; ) {
B v = C2(mul, m_f64(st[k]), C1(ud, m_f64(rsh[k])));
B v = C2(mul, m_f64(st[k]/csz), C1(ud, m_f64(rsh[k])));
if (q_N(ind)) ind = v;
else ind = M1C2(tbl, add, v, ind);
}