// Transpose and Reorder Axes (⍉) // Transpose // Boolean 𝕩: convert to integer // pdep or emulation for height 2; pext for width 2 // SHOULD switch to shuffles and modular permutation // SHOULD have bit matrix transpose kernel // CPU sizes: native or SIMD code // Large SIMD kernels used when they fit, overlapping for odd sizes // SSE, NEON i8: 8×8; i16: 8×8; i32: 4×4; f64: scalar // AVX2 i8: 16×16; i16: 16×8; i32: 8×8; f64: 4×4 // COULD use half-width or smaller kernels to improve odd sizes // Scalar transpose or loop used for overhang of 1 // Partial kernels for at least half-kernel height/width // Short width: over-read, then skip some writes // Short height, i8 AVX2 only: overlap input rows and output words // Smaller heights and widths: dedicated kernels // Zipping, unzipping, shuffling for powers of 2 // Modular permutation for odd numbers, possibly times 2 // Reorder Axes // If 𝕨 indicates the identity permutation, return 𝕩 // Simplify: remove length-1 axes; coalesce adjacent and trailing axes // Empty result or trivial reordering: reshape 𝕩 // Large cells: slow outer loop plus mut_copy // CPU-sized cells, large last 𝕩 and result axes: strided 2D transposes // Otherwise, generate indices and select with +⌜ and ⊏ // SHOULD generate for a cell and virtualize the rest to save space // COULD decompose axis permutations to use 2D transpose when possible // COULD convert boolean to integer for some axis reorderings // SHOULD have a small-subarray transposer using one or a few shuffles // ⍉⁼𝕩: data movement of ⍉ with different shape logic // 𝕨⍉⁼𝕩: compute inverse 𝕨, length 1+⌈´𝕨 // Under Transpose supports invertible cases // SHOULD implement Under with duplicate axes, maybe as Under Select // ≍˘𝕩 and k≍˘𝕩 for number k: convert to 0‿a⍉𝕩 // SHOULD convert ⍉ with rank to a Reorder Axes call // COULD implement fast ≍⍟n #include "../core.h" #include "../utils/mut.h" #include "../utils/talloc.h" #include "../builtins.h" #include 
"../utils/calls.h" #if __BMI2__ && __x86_64__ #if !SLOW_PDEP #define FAST_PDEP 1 #endif #include #if USE_VALGRIND #define _pdep_u64 vg_pdep_u64 #endif #endif #define DECL_BASE(T) \ static NOINLINE void transpose_##T(void* rv, void* xv, u64 bw, u64 bh, u64 w, u64 h) { \ T* rp=rv; T* xp=xv; \ if (bh>xlw && csz<=8) { // Require CPU-sized cells if (we!=xe) goto to_equal_types; assert(re!=el_B); void* rv; if (xlw==0) { u64* rp; r = m_bitarrp(&rp, ia); rv=rp; } else rv = m_tyarrp(&r,elWidth(re),ia,el2t(re)); interleave_fns[CTZ(csz<a, xsh+1, xr-1); sh->a[xr-1] = h; arr_shReplace(r, xr, sh); return taga(r); } usz w = xsh[1] * shProd(xsh, 2, xr); Arr* r = transpose_noshape(&x, ia, w, h); usz* rsh = arr_shAlloc(r, xr); if (xr==2) rsh[0] = w; else shcpy(rsh, SH(x)+1, xr-1); rsh[xr-1] = h; decG(x); return taga(r); } B mul_c2(B,B,B); B ud_c1(B,B); B tbl_c2(Md1D*,B,B); B select_c2(B,B,B); static void shSet(Arr* ra, ur rr, ShArr* sh) { if (RARE(rr <= 1)) arr_shVec(ra); else arr_shSetUO(ra, rr, sh); } B transp_c2(B t, B w, B x) { usz wia=1; if (isArr(w)) { if (RNK(w)>1) thrM("𝕨⍉𝕩: 𝕨 must have rank at most 1"); wia = IA(w); if (wia==0) { decG(w); return isArr(x)? x : m_unit(x); } } ur xr; if (isAtm(x) || (xr=RNK(x))=xr) thrF("𝕨⍉𝕩: Axis %s does not exist (%i≑=𝕩)", a, xr); if (a==xr-1) { TFREE(alloc); return C1(transp, x); } p[0] = a; } else { SGetU(w) for (usz i=0; i=xr) thrF("𝕨⍉𝕩: Axis %s does not exist (%i≑=𝕩)", a, xr); p[i] = a; } decG(w); } B r; // Compute shape for the given axes usz* xsh = SH(x); usz* rsh = ptr_roundUpToEl((usz*)(p + xr)); // Length xr usz dup = 0, max = 0, id = 0; usz no_sh = -(usz)1; for (usz j=0; jmax? 
j : max; if (xl= rr) thrF("𝕨⍉𝕩: Skipped result axis"); if (wia 1)) { // Not all duplicates sh = m_shArr(rr); shcpy(sh->a, rsh, rr); } // Empty result if (IA(x) == 0) { Arr* ra = emptyWithFill(getFillR(x)); shSet(ra, rr, sh); decG(x); r = taga(ra); goto ret; } // Add up stride for each axis ur na = max + 1; // Number of result axes that moved usz* st = rsh + xr; // Length na PLAINLOOP for (usz j=0; j=1 */ \ ur a = a0; \ usz str = st[a]; \ for (usz k=0; k= (32*8) >> xlw) { // cell >= 32 bytes usz ria = csz * shProd(rsh, 0, na); MAKE_MUT_INIT(rm, ria, xe); MUTG_INIT(rm); AXIS_LOOP(na, csz, mut_copyG(rm, i, x, j, csz)); Arr* ra = mut_fp(rm); shSet(ra, rr, sh); r = withFill(taga(ra), getFillR(x)); decG(x); goto ret; } #undef AXIS_LOOP if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<=8 // CPU-sized cells && xe!=el_B && na>=2) { // If some result axis has stride 1 (guaranteed if dup==0), then it // corresponds to the last argument axis and we have a strided // transpose swapping that with the last result axis usz rai = na-1; usz xai=rai; while (st[--xai]!=1) if (xai==0) goto skip_2d; if (rsh[xai]*rsh[rai] < (256*8) >> xlw) goto skip_2d; TranspFn tran = transposeFns[CTZ(csz<>= 3-xlw; else if (xlw>3) PLAINLOOP for (usz i=0; ia[0] = IA(x)/csz; zsh->a[1] = csz; x = taga(arr_shSetUG(customizeShape(x), 2, zsh)); // (+⌜´stΓ—βŸœβ†•Β¨rsh)⊏β₯Šπ•© B ind = bi_N; for (ur k=na; k--; ) { B v = C2(mul, m_usz(st[k]/csz), C1(ud, m_f64(rsh[k]))); if (q_N(ind)) ind = v; else ind = M1C2(tbl, add, v, ind); } r = C2(select, ind, x); Arr* ra = cpyWithShape(r); r = taga(ra); if (rr>1) arr_shReplace(ra, rr, sh); else { decSh((Value*)ra); arr_shVec(ra); } ret:; TFREE(alloc); return r; } B transp_im(B t, B x) { if (isAtm(x)) thrM("⍉⁼𝕩: 𝕩 must not be an atom"); ur xr = RNK(x); if (xr<=1) return x; usz ia = IA(x); usz* xsh = SH(x); usz w = xsh[xr-1]; if (ia==0 || w==1 || w==ia /*h==1*/) { Arr* r = cpyWithShape(x); ShArr* sh = m_shArr(xr); sh->a[0] = w; shcpy(sh->a+1, xsh, xr-1); arr_shReplace(r, xr, 
sh); return taga(r); } usz h = xsh[0] * shProd(xsh, 1, xr-1); Arr* r = transpose_noshape(&x, ia, w, h); usz* rsh = arr_shAlloc(r, xr); rsh[0] = w; if (xr==2) rsh[1] = h; else shcpy(rsh+1, SH(x), xr-1); decG(x); return taga(r); } B transp_uc1(B t, B o, B x) { return transp_im(m_f64(0), c1(o, transp_c1(t, x))); } // Consumes w; return bi_N if w contained duplicates static B invert_transp_w(B w, ur xr) { if (isAtm(w)) { if (xr<1) thrM("𝕨⍉⁼𝕩: Length of 𝕨 must be at most rank of 𝕩"); usz a=o2s(w); if (a>=xr) thrF("𝕨⍉⁼𝕩: Axis %s does not exist (%i≑=𝕩)", a, xr); i32* wp; w = m_i32arrv(&wp, a); PLAINLOOP for (usz i=0; i1) thrM("𝕨⍉⁼𝕩: 𝕨 must have rank at most 1"); usz wia = IA(w); if (wia==0) return w; if (xr=xr) thrF("𝕨⍉⁼𝕩: Axis %s does not exist (%i≑=𝕩)", a, xr); if (p[a]!=xr) { TFREE(p); decG(w); return bi_N; } // Handled by caller max = a>max? a : max; p[a] = i; } decG(w); usz n = max+1; i32* wp; w = m_i32arrv(&wp, n); for (usz i=0, j=wia; iim = transp_im; c(BFn,bi_transp)->ix = transp_ix; c(BFn,bi_transp)->uc1 = transp_uc1; c(BFn,bi_transp)->ucw = transp_ucw; }