Reorder Axes translated from runtime version
This commit is contained in:
parent
432b4eaaa6
commit
b555e3c035
@ -13,7 +13,7 @@
|
|||||||
// COULD use half-width or smaller kernels to improve odd sizes
|
// COULD use half-width or smaller kernels to improve odd sizes
|
||||||
// Scalar transpose or loop used for overhang of 1
|
// Scalar transpose or loop used for overhang of 1
|
||||||
|
|
||||||
// Reorder Axes: self-hosted runtime (based on +⌜ and ⊏, not that slow)
|
// Reorder Axes: generate indices and select with +⌜ and ⊏
|
||||||
|
|
||||||
// Transpose inverse ⍉⁼
|
// Transpose inverse ⍉⁼
|
||||||
// Same as ⍉ for a rank ≤2 argument
|
// Same as ⍉ for a rank ≤2 argument
|
||||||
@ -25,6 +25,7 @@
|
|||||||
#include "../utils/each.h"
|
#include "../utils/each.h"
|
||||||
#include "../utils/talloc.h"
|
#include "../utils/talloc.h"
|
||||||
#include "../builtins.h"
|
#include "../builtins.h"
|
||||||
|
#include "../utils/calls.h"
|
||||||
|
|
||||||
#ifdef __BMI2__
|
#ifdef __BMI2__
|
||||||
#include <immintrin.h>
|
#include <immintrin.h>
|
||||||
@ -51,7 +52,6 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
extern B rt_transp;
|
|
||||||
B transp_c1(B t, B x) {
|
B transp_c1(B t, B x) {
|
||||||
if (RARE(isAtm(x))) return m_atomUnit(x);
|
if (RARE(isAtm(x))) return m_atomUnit(x);
|
||||||
ur xr = RNK(x);
|
ur xr = RNK(x);
|
||||||
@ -169,7 +169,116 @@ B transp_c1(B t, B x) {
|
|||||||
}
|
}
|
||||||
decG(x); return taga(toBit? (Arr*)cpyBitArr(taga(r)) : r);
|
decG(x); return taga(toBit? (Arr*)cpyBitArr(taga(r)) : r);
|
||||||
}
|
}
|
||||||
B transp_c2(B t, B w, B x) { return c2rt(transp, w, x); }
|
|
||||||
|
B mul_c2(B,B,B);
|
||||||
|
B ud_c1(B,B);
|
||||||
|
B tbl_c2(Md1D*,B,B);
|
||||||
|
B select_c2(B,B,B);
|
||||||
|
|
||||||
|
B transp_c2(B t, B w, B x) {
|
||||||
|
usz wia=1;
|
||||||
|
if (isArr(w)) {
|
||||||
|
if (RNK(w)>1) thrM("⍉: 𝕨 must have rank at most 1");
|
||||||
|
wia = IA(w);
|
||||||
|
if (wia==0) { decG(w); return isArr(x)? x : m_atomUnit(x); }
|
||||||
|
}
|
||||||
|
ur xr;
|
||||||
|
if (isAtm(x) || (xr=RNK(x))<wia) thrM("⍉: Length of 𝕨 must be at most rank of 𝕩");
|
||||||
|
|
||||||
|
// Axis permutation
|
||||||
|
TALLOC(ur, p, xr);
|
||||||
|
if (isAtm(w)) {
|
||||||
|
usz a=o2s(w);
|
||||||
|
if (a>=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr);
|
||||||
|
p[0] = a;
|
||||||
|
} else {
|
||||||
|
SGetU(w)
|
||||||
|
for (usz i=0; i<wia; i++) {
|
||||||
|
usz a=o2s(GetU(w, i));
|
||||||
|
if (a>=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr);
|
||||||
|
p[i] = a;
|
||||||
|
}
|
||||||
|
decG(w);
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute shape for the given axes
|
||||||
|
usz* xsh = SH(x);
|
||||||
|
TALLOC(usz, rsh, xr);
|
||||||
|
usz dup = 0, max = 0;
|
||||||
|
usz no_sh = -(usz)1;
|
||||||
|
for (usz j=0; j<xr; j++) rsh[j] = no_sh;
|
||||||
|
for (usz i=0; i<wia; i++) {
|
||||||
|
ur j=p[i];
|
||||||
|
usz xl=xsh[i], l=rsh[j];
|
||||||
|
dup += l!=no_sh;
|
||||||
|
max = j>max? j : max;
|
||||||
|
if (xl<l) rsh[j]=xl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fill in remaining axes and check for missing ones
|
||||||
|
ur rr = xr-dup;
|
||||||
|
if (max >= rr) thrF("⍉: Skipped result axis");
|
||||||
|
if (wia<xr) for (usz j=0, i=wia; j<rr; j++) if (rsh[j]==no_sh) {
|
||||||
|
p[i] = j;
|
||||||
|
rsh[j] = xsh[i];
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
B r;
|
||||||
|
|
||||||
|
// Empty result
|
||||||
|
if (IA(x) == 0) {
|
||||||
|
Arr* ra = m_fillarrpEmpty(getFillQ(x));
|
||||||
|
if (RARE(rr <= 1)) {
|
||||||
|
arr_shVec(ra);
|
||||||
|
} else {
|
||||||
|
ShArr* sh=m_shArr(rr);
|
||||||
|
shcpy(sh->a, rsh, rr);
|
||||||
|
arr_shSetU(ra, rr, sh);
|
||||||
|
}
|
||||||
|
decG(x);
|
||||||
|
r = taga(ra); goto ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number of axes that move
|
||||||
|
ur ar = max+1+dup;
|
||||||
|
if (ar == 1) { r = x; goto ret; }
|
||||||
|
// Add up stride for each axis
|
||||||
|
TALLOC(u64, st, rr);
|
||||||
|
for (usz j=0; j<rr; j++) st[j] = 0;
|
||||||
|
usz c = 1;
|
||||||
|
for (usz i=ar; i--; ) { st[p[i]]+=c; c*=xsh[i]; }
|
||||||
|
|
||||||
|
// Reshape x for selection, collapsing ar axes
|
||||||
|
if (ar != 1) {
|
||||||
|
ur zr = xr-ar+1;
|
||||||
|
ShArr* zsh;
|
||||||
|
if (zr>1) {
|
||||||
|
zsh = m_shArr(zr);
|
||||||
|
zsh->a[0] = c;
|
||||||
|
shcpy(zsh->a+1, xsh+ar, xr-ar);
|
||||||
|
}
|
||||||
|
Arr* z = TI(x,slice)(x, 0, IA(x));
|
||||||
|
if (zr>1) arr_shSetU(z, zr, zsh);
|
||||||
|
else arr_shVec(z);
|
||||||
|
x = taga(z);
|
||||||
|
}
|
||||||
|
// (+⌜´st×⟜↕¨rsh)⊏⥊𝕩
|
||||||
|
B ind = bi_N;
|
||||||
|
for (ur k=ar-dup; k--; ) {
|
||||||
|
B v = C2(mul, m_f64(st[k]), C1(ud, m_f64(rsh[k])));
|
||||||
|
if (q_N(ind)) ind = v;
|
||||||
|
else ind = M1C2(tbl, add, v, ind);
|
||||||
|
}
|
||||||
|
TFREE(st);
|
||||||
|
r = C2(select, ind, x);
|
||||||
|
|
||||||
|
ret:;
|
||||||
|
TFREE(rsh);
|
||||||
|
TFREE(p);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
B transp_im(B t, B x) {
|
B transp_im(B t, B x) {
|
||||||
if (isAtm(x)) thrM("⍉⁼: 𝕩 must not be an atom");
|
if (isAtm(x)) thrM("⍉⁼: 𝕩 must not be an atom");
|
||||||
|
|||||||
@ -107,7 +107,7 @@ B comp_currSrc;
|
|||||||
B comp_currRe; // ⟨REPL mode ⋄ scope ⋄ compiler ⋄ runtime ⋄ glyphs ⋄ sysval names ⋄ sysval values⟩
|
B comp_currRe; // ⟨REPL mode ⋄ scope ⋄ compiler ⋄ runtime ⋄ glyphs ⋄ sysval names ⋄ sysval values⟩
|
||||||
|
|
||||||
B rt_undo, rt_select, rt_slash, rt_insert, rt_depth,
|
B rt_undo, rt_select, rt_slash, rt_insert, rt_depth,
|
||||||
rt_group, rt_under, rt_find, rt_transp;
|
rt_group, rt_under, rt_find;
|
||||||
Block* load_compObj(B x, B src, B path, Scope* sc) { // consumes x,src
|
Block* load_compObj(B x, B src, B path, Scope* sc) { // consumes x,src
|
||||||
SGet(x)
|
SGet(x)
|
||||||
usz xia = IA(x);
|
usz xia = IA(x);
|
||||||
@ -442,7 +442,6 @@ void load_init() { // very last init function
|
|||||||
rt_group = Get(rtObjRaw, n_group );
|
rt_group = Get(rtObjRaw, n_group );
|
||||||
rt_under = Get(rtObjRaw, n_under );
|
rt_under = Get(rtObjRaw, n_under );
|
||||||
rt_find = Get(rtObjRaw, n_find );
|
rt_find = Get(rtObjRaw, n_find );
|
||||||
rt_transp = Get(rtObjRaw, n_transp );
|
|
||||||
rt_depth = Get(rtObjRaw, n_depth );
|
rt_depth = Get(rtObjRaw, n_depth );
|
||||||
rt_insert = Get(rtObjRaw, n_insert );
|
rt_insert = Get(rtObjRaw, n_insert );
|
||||||
|
|
||||||
@ -485,7 +484,7 @@ void load_init() { // very last init function
|
|||||||
}
|
}
|
||||||
load_rtObj = frtObj;
|
load_rtObj = frtObj;
|
||||||
load_compArg = m_hVec2(load_rtObj, incG(bi_sys));
|
load_compArg = m_hVec2(load_rtObj, incG(bi_sys));
|
||||||
rt_select=rt_slash=rt_group=rt_find=rt_transp=rt_invFnReg=rt_invFnSwap = incByG(bi_invalidFn, 7);
|
rt_select=rt_slash=rt_group=rt_find=rt_invFnReg=rt_invFnSwap = incByG(bi_invalidFn, 7);
|
||||||
rt_undo=rt_insert = incByG(bi_invalidMd1, 2);
|
rt_undo=rt_insert = incByG(bi_invalidMd1, 2);
|
||||||
rt_under=rt_depth = incByG(bi_invalidMd2, 2);
|
rt_under=rt_depth = incByG(bi_invalidMd2, 2);
|
||||||
rt_invFnRegFn=rt_invFnSwapFn = invalidFn_c1;
|
rt_invFnRegFn=rt_invFnSwapFn = invalidFn_c1;
|
||||||
@ -497,7 +496,6 @@ void load_init() { // very last init function
|
|||||||
gc_add(rt_group);
|
gc_add(rt_group);
|
||||||
gc_add(rt_under);
|
gc_add(rt_under);
|
||||||
gc_add(rt_find);
|
gc_add(rt_find);
|
||||||
gc_add(rt_transp);
|
|
||||||
gc_add(rt_depth);
|
gc_add(rt_depth);
|
||||||
gc_add(rt_insert);
|
gc_add(rt_insert);
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user