Reorder Axes translated from runtime version
This commit is contained in:
parent
432b4eaaa6
commit
b555e3c035
@ -13,7 +13,7 @@
|
||||
// COULD use half-width or smaller kernels to improve odd sizes
|
||||
// Scalar transpose or loop used for overhang of 1
|
||||
|
||||
// Reorder Axes: self-hosted runtime (based on +⌜ and ⊏, not that slow)
|
||||
// Reorder Axes: generate indices and select with +⌜ and ⊏
|
||||
|
||||
// Transpose inverse ⍉⁼
|
||||
// Same as ⍉ for a rank ≤2 argument
|
||||
@ -25,6 +25,7 @@
|
||||
#include "../utils/each.h"
|
||||
#include "../utils/talloc.h"
|
||||
#include "../builtins.h"
|
||||
#include "../utils/calls.h"
|
||||
|
||||
#ifdef __BMI2__
|
||||
#include <immintrin.h>
|
||||
@ -51,7 +52,6 @@
|
||||
#endif
|
||||
|
||||
|
||||
extern B rt_transp;
|
||||
B transp_c1(B t, B x) {
|
||||
if (RARE(isAtm(x))) return m_atomUnit(x);
|
||||
ur xr = RNK(x);
|
||||
@ -169,7 +169,116 @@ B transp_c1(B t, B x) {
|
||||
}
|
||||
decG(x); return taga(toBit? (Arr*)cpyBitArr(taga(r)) : r);
|
||||
}
|
||||
B transp_c2(B t, B w, B x) { return c2rt(transp, w, x); }
|
||||
|
||||
B mul_c2(B,B,B);
|
||||
B ud_c1(B,B);
|
||||
B tbl_c2(Md1D*,B,B);
|
||||
B select_c2(B,B,B);
|
||||
|
||||
B transp_c2(B t, B w, B x) {
|
||||
usz wia=1;
|
||||
if (isArr(w)) {
|
||||
if (RNK(w)>1) thrM("⍉: 𝕨 must have rank at most 1");
|
||||
wia = IA(w);
|
||||
if (wia==0) { decG(w); return isArr(x)? x : m_atomUnit(x); }
|
||||
}
|
||||
ur xr;
|
||||
if (isAtm(x) || (xr=RNK(x))<wia) thrM("⍉: Length of 𝕨 must be at most rank of 𝕩");
|
||||
|
||||
// Axis permutation
|
||||
TALLOC(ur, p, xr);
|
||||
if (isAtm(w)) {
|
||||
usz a=o2s(w);
|
||||
if (a>=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr);
|
||||
p[0] = a;
|
||||
} else {
|
||||
SGetU(w)
|
||||
for (usz i=0; i<wia; i++) {
|
||||
usz a=o2s(GetU(w, i));
|
||||
if (a>=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr);
|
||||
p[i] = a;
|
||||
}
|
||||
decG(w);
|
||||
}
|
||||
|
||||
// compute shape for the given axes
|
||||
usz* xsh = SH(x);
|
||||
TALLOC(usz, rsh, xr);
|
||||
usz dup = 0, max = 0;
|
||||
usz no_sh = -(usz)1;
|
||||
for (usz j=0; j<xr; j++) rsh[j] = no_sh;
|
||||
for (usz i=0; i<wia; i++) {
|
||||
ur j=p[i];
|
||||
usz xl=xsh[i], l=rsh[j];
|
||||
dup += l!=no_sh;
|
||||
max = j>max? j : max;
|
||||
if (xl<l) rsh[j]=xl;
|
||||
}
|
||||
|
||||
// fill in remaining axes and check for missing ones
|
||||
ur rr = xr-dup;
|
||||
if (max >= rr) thrF("⍉: Skipped result axis");
|
||||
if (wia<xr) for (usz j=0, i=wia; j<rr; j++) if (rsh[j]==no_sh) {
|
||||
p[i] = j;
|
||||
rsh[j] = xsh[i];
|
||||
i++;
|
||||
}
|
||||
|
||||
B r;
|
||||
|
||||
// Empty result
|
||||
if (IA(x) == 0) {
|
||||
Arr* ra = m_fillarrpEmpty(getFillQ(x));
|
||||
if (RARE(rr <= 1)) {
|
||||
arr_shVec(ra);
|
||||
} else {
|
||||
ShArr* sh=m_shArr(rr);
|
||||
shcpy(sh->a, rsh, rr);
|
||||
arr_shSetU(ra, rr, sh);
|
||||
}
|
||||
decG(x);
|
||||
r = taga(ra); goto ret;
|
||||
}
|
||||
|
||||
// Number of axes that move
|
||||
ur ar = max+1+dup;
|
||||
if (ar == 1) { r = x; goto ret; }
|
||||
// Add up stride for each axis
|
||||
TALLOC(u64, st, rr);
|
||||
for (usz j=0; j<rr; j++) st[j] = 0;
|
||||
usz c = 1;
|
||||
for (usz i=ar; i--; ) { st[p[i]]+=c; c*=xsh[i]; }
|
||||
|
||||
// Reshape x for selection, collapsing ar axes
|
||||
if (ar != 1) {
|
||||
ur zr = xr-ar+1;
|
||||
ShArr* zsh;
|
||||
if (zr>1) {
|
||||
zsh = m_shArr(zr);
|
||||
zsh->a[0] = c;
|
||||
shcpy(zsh->a+1, xsh+ar, xr-ar);
|
||||
}
|
||||
Arr* z = TI(x,slice)(x, 0, IA(x));
|
||||
if (zr>1) arr_shSetU(z, zr, zsh);
|
||||
else arr_shVec(z);
|
||||
x = taga(z);
|
||||
}
|
||||
// (+⌜´st×⟜↕¨rsh)⊏⥊𝕩
|
||||
B ind = bi_N;
|
||||
for (ur k=ar-dup; k--; ) {
|
||||
B v = C2(mul, m_f64(st[k]), C1(ud, m_f64(rsh[k])));
|
||||
if (q_N(ind)) ind = v;
|
||||
else ind = M1C2(tbl, add, v, ind);
|
||||
}
|
||||
TFREE(st);
|
||||
r = C2(select, ind, x);
|
||||
|
||||
ret:;
|
||||
TFREE(rsh);
|
||||
TFREE(p);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
B transp_im(B t, B x) {
|
||||
if (isAtm(x)) thrM("⍉⁼: 𝕩 must not be an atom");
|
||||
|
||||
@ -107,7 +107,7 @@ B comp_currSrc;
|
||||
B comp_currRe; // ⟨REPL mode ⋄ scope ⋄ compiler ⋄ runtime ⋄ glyphs ⋄ sysval names ⋄ sysval values⟩
|
||||
|
||||
B rt_undo, rt_select, rt_slash, rt_insert, rt_depth,
|
||||
rt_group, rt_under, rt_find, rt_transp;
|
||||
rt_group, rt_under, rt_find;
|
||||
Block* load_compObj(B x, B src, B path, Scope* sc) { // consumes x,src
|
||||
SGet(x)
|
||||
usz xia = IA(x);
|
||||
@ -442,7 +442,6 @@ void load_init() { // very last init function
|
||||
rt_group = Get(rtObjRaw, n_group );
|
||||
rt_under = Get(rtObjRaw, n_under );
|
||||
rt_find = Get(rtObjRaw, n_find );
|
||||
rt_transp = Get(rtObjRaw, n_transp );
|
||||
rt_depth = Get(rtObjRaw, n_depth );
|
||||
rt_insert = Get(rtObjRaw, n_insert );
|
||||
|
||||
@ -485,7 +484,7 @@ void load_init() { // very last init function
|
||||
}
|
||||
load_rtObj = frtObj;
|
||||
load_compArg = m_hVec2(load_rtObj, incG(bi_sys));
|
||||
rt_select=rt_slash=rt_group=rt_find=rt_transp=rt_invFnReg=rt_invFnSwap = incByG(bi_invalidFn, 7);
|
||||
rt_select=rt_slash=rt_group=rt_find=rt_invFnReg=rt_invFnSwap = incByG(bi_invalidFn, 7);
|
||||
rt_undo=rt_insert = incByG(bi_invalidMd1, 2);
|
||||
rt_under=rt_depth = incByG(bi_invalidMd2, 2);
|
||||
rt_invFnRegFn=rt_invFnSwapFn = invalidFn_c1;
|
||||
@ -497,7 +496,6 @@ void load_init() { // very last init function
|
||||
gc_add(rt_group);
|
||||
gc_add(rt_under);
|
||||
gc_add(rt_find);
|
||||
gc_add(rt_transp);
|
||||
gc_add(rt_depth);
|
||||
gc_add(rt_insert);
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user