diff --git a/src/builtins/transpose.c b/src/builtins/transpose.c index 0390b2ce..8bb00681 100644 --- a/src/builtins/transpose.c +++ b/src/builtins/transpose.c @@ -145,14 +145,9 @@ B ud_c1(B,B); B tbl_c2(Md1D*,B,B); B select_c2(B,B,B); -static void shSet_alloc(Arr* ra, ur rr, usz* rsh) { - if (RARE(rr <= 1)) { - arr_shVec(ra); - } else { - ShArr* sh=m_shArr(rr); - shcpy(sh->a, rsh, rr); - arr_shSetU(ra, rr, sh); - } +static void shSet(Arr* ra, ur rr, ShArr* sh) { + if (RARE(rr <= 1)) arr_shVec(ra); + else arr_shSetU(ra, rr, sh); } B transp_c2(B t, B w, B x) { @@ -182,7 +177,7 @@ B transp_c2(B t, B w, B x) { decG(w); } - // compute shape for the given axes + // Compute shape for the given axes usz* xsh = SH(x); usz *rsh = (usz*)(p + xr); // Length xr usz dup = 0, max = 0; @@ -196,7 +191,7 @@ B transp_c2(B t, B w, B x) { if (xl= rr) thrF("⍉: Skipped result axis"); if (wia 1)) { // Not all duplicates + sh = m_shArr(rr); + shcpy(sh->a, rsh, rr); + } B r; // Empty result if (IA(x) == 0) { Arr* ra = m_fillarrpEmpty(getFillQ(x)); - shSet_alloc(ra, rr, rsh); + shSet(ra, rr, sh); decG(x); r = taga(ra); goto ret; } @@ -218,11 +219,11 @@ B transp_c2(B t, B w, B x) { // Number of axes that move ur ar = max+1+dup; if (!dup) while (ar>1 && p[ar-1]==ar-1) ar--; // Unmoved trailing - if (ar <= 1) { r = x; goto ret; } + if (ar <= 1) { if (rr>1) ptr_dec(sh); r = x; goto ret; } ur na = ar - dup; // Add up stride for each axis - usz* st = rsh + xr; // Length rr - for (usz j=0; j=1 */ \ @@ -257,12 +258,12 @@ B transp_c2(B t, B w, B x) { MAKE_MUT_INIT(rm, ria, xe); MUTG_INIT(rm); AXIS_LOOP(na, csz, mut_copyG(rm, i, x, j, csz)); Arr* ra = mut_fp(rm); - shSet_alloc(ra, rr, rsh); + shSet(ra, rr, sh); r = withFill(taga(ra), getFillQ(x)); decG(x); goto ret; } if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<=8 // CPU-sized cells - && !dup && na>=2 && p[na-1]==ar-2 && p[na-2]==ar-1 // Last two axes transposed + && na>=2 && st[na-1]==rsh[na-2] && st[na-2]==1 // Last two axes transposed && rsh[na-2]*rsh[na-1] >= (256*8) >> xlw // And large-ish && xe!=el_B) { TranspFn tran = transposeFns[CTZ(csz<1) { - zsh = m_shArr(zr); - zsh->a[0] = c; - shcpy(zsh->a+1, xsh+ar, xr-ar); - } - Arr* z = TI(x,slice)(x, 0, IA(x)); - if (zr>1) arr_shSetU(z, zr, zsh); - else arr_shVec(z); - x = taga(z); - } + // Reshape x for selection + ShArr* zsh = m_shArr(2); + zsh->a[0] = c; + zsh->a[1] = csz; + Arr* z = TI(x,slice)(x, 0, IA(x)); + arr_shSetU(z, 2, zsh); + x = taga(z); // (+⌜´st×⟜↕¨rsh)⊏⥊𝕩 B ind = bi_N; for (ur k=na; k--; ) { @@ -308,6 +302,8 @@ B transp_c2(B t, B w, B x) { else ind = M1C2(tbl, add, v, ind); } r = C2(select, ind, x); + if (rr>1) arr_shReplace(a(r), rr, sh); + else { decSh(v(r)); arr_shVec(a(r)); } ret:; TFREE(p);