From b555e3c0350aa1aee7eebfc2c3b3ccdbc04426a6 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Thu, 23 Mar 2023 20:56:25 -0400 Subject: [PATCH] Reorder Axes translated from runtime version --- src/builtins/transpose.c | 115 ++++++++++++++++++++++++++++++++++++++- src/load.c | 6 +- 2 files changed, 114 insertions(+), 7 deletions(-) diff --git a/src/builtins/transpose.c b/src/builtins/transpose.c index 3cadb5be..f6f8ce50 100644 --- a/src/builtins/transpose.c +++ b/src/builtins/transpose.c @@ -13,7 +13,7 @@ // COULD use half-width or smaller kernels to improve odd sizes // Scalar transpose or loop used for overhang of 1 -// Reorder Axes: self-hosted runtime (based on +⌜ and ⊏, not that slow) +// Reorder Axes: generate indices and select with +⌜ and ⊏ // Transpose inverse ⍉⁼ // Same as ⍉ for a rank ≤2 argument @@ -25,6 +25,7 @@ #include "../utils/each.h" #include "../utils/talloc.h" #include "../builtins.h" +#include "../utils/calls.h" #ifdef __BMI2__ #include @@ -51,7 +52,6 @@ #endif -extern B rt_transp; B transp_c1(B t, B x) { if (RARE(isAtm(x))) return m_atomUnit(x); ur xr = RNK(x); @@ -169,7 +169,116 @@ B transp_c1(B t, B x) { } decG(x); return taga(toBit? (Arr*)cpyBitArr(taga(r)) : r); } -B transp_c2(B t, B w, B x) { return c2rt(transp, w, x); } + +B mul_c2(B,B,B); +B ud_c1(B,B); +B tbl_c2(Md1D*,B,B); +B select_c2(B,B,B); + +B transp_c2(B t, B w, B x) { + usz wia=1; + if (isArr(w)) { + if (RNK(w)>1) thrM("⍉: 𝕨 must have rank at most 1"); + wia = IA(w); + if (wia==0) { decG(w); return isArr(x)? x : m_atomUnit(x); } + } + ur xr; + if (isAtm(x) || (xr=RNK(x))=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr); + p[0] = a; + } else { + SGetU(w) + for (usz i=0; i=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr); + p[i] = a; + } + decG(w); + } + + // compute shape for the given axes + usz* xsh = SH(x); + TALLOC(usz, rsh, xr); + usz dup = 0, max = 0; + usz no_sh = -(usz)1; + for (usz j=0; jmax? j : max; + if (xl= rr) thrF("⍉: Skipped result axis"); + if (wiaa, rsh, rr); + arr_shSetU(ra, rr, sh); + } + decG(x); + r = taga(ra); goto ret; + } + + // Number of axes that move + ur ar = max+1+dup; + if (ar == 1) { r = x; goto ret; } + // Add up stride for each axis + TALLOC(u64, st, rr); + for (usz j=0; j1) { + zsh = m_shArr(zr); + zsh->a[0] = c; + shcpy(zsh->a+1, xsh+ar, xr-ar); + } + Arr* z = TI(x,slice)(x, 0, IA(x)); + if (zr>1) arr_shSetU(z, zr, zsh); + else arr_shVec(z); + x = taga(z); + } + // (+⌜´st×⟜↕¨rsh)⊏⥊𝕩 + B ind = bi_N; + for (ur k=ar-dup; k--; ) { + B v = C2(mul, m_f64(st[k]), C1(ud, m_f64(rsh[k]))); + if (q_N(ind)) ind = v; + else ind = M1C2(tbl, add, v, ind); + } + TFREE(st); + r = C2(select, ind, x); + + ret:; + TFREE(rsh); + TFREE(p); + return r; +} + B transp_im(B t, B x) { if (isAtm(x)) thrM("⍉⁼: 𝕩 must not be an atom"); diff --git a/src/load.c b/src/load.c index f6485267..7fbdf0be 100644 --- a/src/load.c +++ b/src/load.c @@ -107,7 +107,7 @@ B comp_currSrc; B comp_currRe; // ⟨REPL mode ⋄ scope ⋄ compiler ⋄ runtime ⋄ glyphs ⋄ sysval names ⋄ sysval values⟩ B rt_undo, rt_select, rt_slash, rt_insert, rt_depth, - rt_group, rt_under, rt_find, rt_transp; + rt_group, rt_under, rt_find; Block* load_compObj(B x, B src, B path, Scope* sc) { // consumes x,src SGet(x) usz xia = IA(x); @@ -442,7 +442,6 @@ void load_init() { // very last init function rt_group = Get(rtObjRaw, n_group ); rt_under = Get(rtObjRaw, n_under ); rt_find = Get(rtObjRaw, n_find ); - rt_transp = Get(rtObjRaw, n_transp ); rt_depth = Get(rtObjRaw, n_depth ); rt_insert = Get(rtObjRaw, n_insert ); @@ -485,7 +484,7 @@ void load_init() { // very last init function } load_rtObj = frtObj; load_compArg = m_hVec2(load_rtObj, incG(bi_sys)); - rt_select=rt_slash=rt_group=rt_find=rt_transp=rt_invFnReg=rt_invFnSwap = incByG(bi_invalidFn, 7); + rt_select=rt_slash=rt_group=rt_find=rt_invFnReg=rt_invFnSwap = incByG(bi_invalidFn, 7); rt_undo=rt_insert = incByG(bi_invalidMd1, 2); rt_under=rt_depth = incByG(bi_invalidMd2, 2); rt_invFnRegFn=rt_invFnSwapFn = invalidFn_c1; @@ -497,7 +496,6 @@ void load_init() { // very last init function gc_add(rt_group); gc_add(rt_under); gc_add(rt_find); - gc_add(rt_transp); gc_add(rt_depth); gc_add(rt_insert);