Full simplification pass for Reorder Axes
This commit is contained in:
parent
d492cd0865
commit
597e25af4d
@ -172,19 +172,23 @@ B transp_c2(B t, B w, B x) {
|
||||
decG(w);
|
||||
}
|
||||
|
||||
B r;
|
||||
|
||||
// Compute shape for the given axes
|
||||
usz* xsh = SH(x);
|
||||
usz *rsh = (usz*)(p + xr); // Length xr
|
||||
usz dup = 0, max = 0;
|
||||
usz dup = 0, max = 0, id = 0;
|
||||
usz no_sh = -(usz)1;
|
||||
for (usz j=0; j<xr; j++) rsh[j] = no_sh;
|
||||
for (usz i=0; i<wia; i++) {
|
||||
ur j=p[i];
|
||||
usz xl=xsh[i], l=rsh[j];
|
||||
dup += l!=no_sh;
|
||||
id += i==j;
|
||||
max = j>max? j : max;
|
||||
if (xl<l) rsh[j]=xl;
|
||||
}
|
||||
if (id == wia) { r = x; goto ret; }
|
||||
|
||||
// Fill in remaining axes and check for missing ones
|
||||
ur rr = xr-dup;
|
||||
@ -201,8 +205,6 @@ B transp_c2(B t, B w, B x) {
|
||||
shcpy(sh->a, rsh, rr);
|
||||
}
|
||||
|
||||
B r;
|
||||
|
||||
// Empty result
|
||||
if (IA(x) == 0) {
|
||||
Arr* ra = m_fillarrpEmpty(getFillQ(x));
|
||||
@ -211,23 +213,33 @@ B transp_c2(B t, B w, B x) {
|
||||
r = taga(ra); goto ret;
|
||||
}
|
||||
|
||||
// Number of axes that move
|
||||
ur ar = max+1+dup;
|
||||
if (!dup) while (ar>1 && p[ar-1]==ar-1) ar--; // Unmoved trailing
|
||||
if (ar <= 1) { if (rr>1) ptr_dec(sh); r = x; goto ret; }
|
||||
ur na = ar - dup;
|
||||
// Add up stride for each axis
|
||||
usz* st = rsh + xr; // Length ar
|
||||
for (usz j=0; j<ar; j++) st[j] = 0;
|
||||
usz c = 1;
|
||||
for (usz i=ar; i--; ) { st[p[i]]+=c; c*=xsh[i]; }
|
||||
ur na = max + 1; // Number of result axes that moved
|
||||
usz* st = rsh + xr; // Length na
|
||||
for (usz j=0; j<na; j++) st[j] = 0;
|
||||
usz csz = shProd(xsh, na+dup, xr);
|
||||
for (usz i=na+dup, c=csz; i--; ) { st[p[i]]+=c; c*=xsh[i]; }
|
||||
|
||||
// Simplify axis structure
|
||||
// p is unused now; work only on csz, rsh, and st
|
||||
usz *lp = &csz; usz sz = csz;
|
||||
usz na0=na; usz* rsh0=rsh; usz* st0=st; rsh+=na0; st+=na0; na=0;
|
||||
for (usz i=na0; i--; ) {
|
||||
usz l = rsh0[i]; if (l==1) continue; // Ignore
|
||||
usz s = st0[i]; if (s==sz) { *lp*=l; sz*=l; continue; } // Combine with lower
|
||||
na++; *--rsh=l; *--st=s; lp=rsh; sz=l*s;
|
||||
}
|
||||
// Turned out trivial
|
||||
if (na == 0) {
|
||||
Arr* ra = TI(x,slice)(x, 0, csz);
|
||||
shSet(ra, rr, sh);
|
||||
r = taga(ra); goto ret;
|
||||
}
|
||||
|
||||
u8 xe = TI(x,elType);
|
||||
usz csz = shProd(xsh, ar, xr);
|
||||
#define AXIS_LOOP(N_AX, I_INC, DO_INNER) \
|
||||
ur a0 = N_AX - 1; \
|
||||
for (usz i=0; i<na; i++) st[i] *= csz; \
|
||||
usz* ri = st+ar; for (usz i=0; i<a0; i++) ri[i]=0; \
|
||||
usz* ri = st+na; for (usz i=0; i<a0; i++) ri[i]=0; \
|
||||
usz l = rsh[a0]; \
|
||||
for (usz i=0, j0=0;;) { \
|
||||
/* Hardcode one innermost loop: assume N_AX>=1 */ \
|
||||
@ -280,8 +292,9 @@ B transp_c2(B t, B w, B x) {
|
||||
usz i_skip = (w-1)*hs*csz;
|
||||
usz end = rf*csz - i_skip;
|
||||
ur a0 = na - 1;
|
||||
for (usz i=0; i<na; i++) st[i] *= csz;
|
||||
usz* ri = st+ar; for (usz i=0; i<a0; i++) ri[i]=0;
|
||||
if (xlw<3) for (usz i=0; i<na; i++) st[i] >>= 3-xlw;
|
||||
else if (xlw>3) for (usz i=0; i<na; i++) st[i] <<= xlw-3;
|
||||
usz* ri = st+na; for (usz i=0; i<a0; i++) ri[i]=0;
|
||||
for (usz i=0, j=0;;) {
|
||||
tran(rp+i, xp+j, w, h, ws, hs);
|
||||
i += h*csz;
|
||||
@ -304,7 +317,7 @@ B transp_c2(B t, B w, B x) {
|
||||
|
||||
// Reshape x for selection
|
||||
ShArr* zsh = m_shArr(2);
|
||||
zsh->a[0] = c;
|
||||
zsh->a[0] = IA(x)/csz;
|
||||
zsh->a[1] = csz;
|
||||
Arr* z = TI(x,slice)(x, 0, IA(x));
|
||||
arr_shSetU(z, 2, zsh);
|
||||
@ -312,7 +325,7 @@ B transp_c2(B t, B w, B x) {
|
||||
// (+⌜´st×⟜↕¨rsh)⊏⥊𝕩
|
||||
B ind = bi_N;
|
||||
for (ur k=na; k--; ) {
|
||||
B v = C2(mul, m_f64(st[k]), C1(ud, m_f64(rsh[k])));
|
||||
B v = C2(mul, m_f64(st[k]/csz), C1(ud, m_f64(rsh[k])));
|
||||
if (q_N(ind)) ind = v;
|
||||
else ind = M1C2(tbl, add, v, ind);
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user