Reorder Axes translated from runtime version

This commit is contained in:
Marshall Lochbaum 2023-03-23 20:56:25 -04:00
parent 432b4eaaa6
commit b555e3c035
2 changed files with 114 additions and 7 deletions

View File

@ -13,7 +13,7 @@
// COULD use half-width or smaller kernels to improve odd sizes
// Scalar transpose or loop used for overhang of 1
// Reorder Axes: self-hosted runtime (based on +⌜ and ⊏, not that slow)
// Reorder Axes: generate indices and select with +⌜ and ⊏
// Transpose inverse ⍉⁼
// Same as ⍉ for a rank ≤2 argument
@ -25,6 +25,7 @@
#include "../utils/each.h"
#include "../utils/talloc.h"
#include "../builtins.h"
#include "../utils/calls.h"
#ifdef __BMI2__
#include <immintrin.h>
@ -51,7 +52,6 @@
#endif
extern B rt_transp;
B transp_c1(B t, B x) {
if (RARE(isAtm(x))) return m_atomUnit(x);
ur xr = RNK(x);
@ -169,7 +169,116 @@ B transp_c1(B t, B x) {
}
decG(x); return taga(toBit? (Arr*)cpyBitArr(taga(r)) : r);
}
B transp_c2(B t, B w, B x) { return c2rt(transp, w, x); }
B mul_c2(B,B,B);
B ud_c1(B,B);
B tbl_c2(Md1D*,B,B);
B select_c2(B,B,B);
B transp_c2(B t, B w, B x) {
usz wia=1;
if (isArr(w)) {
if (RNK(w)>1) thrM("⍉: 𝕨 must have rank at most 1");
wia = IA(w);
if (wia==0) { decG(w); return isArr(x)? x : m_atomUnit(x); }
}
ur xr;
if (isAtm(x) || (xr=RNK(x))<wia) thrM("⍉: Length of 𝕨 must be at most rank of 𝕩");
// Axis permutation
TALLOC(ur, p, xr);
if (isAtm(w)) {
usz a=o2s(w);
if (a>=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr);
p[0] = a;
} else {
SGetU(w)
for (usz i=0; i<wia; i++) {
usz a=o2s(GetU(w, i));
if (a>=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr);
p[i] = a;
}
decG(w);
}
// compute shape for the given axes
usz* xsh = SH(x);
TALLOC(usz, rsh, xr);
usz dup = 0, max = 0;
usz no_sh = -(usz)1;
for (usz j=0; j<xr; j++) rsh[j] = no_sh;
for (usz i=0; i<wia; i++) {
ur j=p[i];
usz xl=xsh[i], l=rsh[j];
dup += l!=no_sh;
max = j>max? j : max;
if (xl<l) rsh[j]=xl;
}
// fill in remaining axes and check for missing ones
ur rr = xr-dup;
if (max >= rr) thrF("⍉: Skipped result axis");
if (wia<xr) for (usz j=0, i=wia; j<rr; j++) if (rsh[j]==no_sh) {
p[i] = j;
rsh[j] = xsh[i];
i++;
}
B r;
// Empty result
if (IA(x) == 0) {
Arr* ra = m_fillarrpEmpty(getFillQ(x));
if (RARE(rr <= 1)) {
arr_shVec(ra);
} else {
ShArr* sh=m_shArr(rr);
shcpy(sh->a, rsh, rr);
arr_shSetU(ra, rr, sh);
}
decG(x);
r = taga(ra); goto ret;
}
// Number of axes that move
ur ar = max+1+dup;
if (ar == 1) { r = x; goto ret; }
// Add up stride for each axis
TALLOC(u64, st, rr);
for (usz j=0; j<rr; j++) st[j] = 0;
usz c = 1;
for (usz i=ar; i--; ) { st[p[i]]+=c; c*=xsh[i]; }
// Reshape x for selection, collapsing ar axes
if (ar != 1) {
ur zr = xr-ar+1;
ShArr* zsh;
if (zr>1) {
zsh = m_shArr(zr);
zsh->a[0] = c;
shcpy(zsh->a+1, xsh+ar, xr-ar);
}
Arr* z = TI(x,slice)(x, 0, IA(x));
if (zr>1) arr_shSetU(z, zr, zsh);
else arr_shVec(z);
x = taga(z);
}
// (+⌜´st×⟜↕¨rsh)⊏⥊𝕩
B ind = bi_N;
for (ur k=ar-dup; k--; ) {
B v = C2(mul, m_f64(st[k]), C1(ud, m_f64(rsh[k])));
if (q_N(ind)) ind = v;
else ind = M1C2(tbl, add, v, ind);
}
TFREE(st);
r = C2(select, ind, x);
ret:;
TFREE(rsh);
TFREE(p);
return r;
}
B transp_im(B t, B x) {
if (isAtm(x)) thrM("⍉⁼: 𝕩 must not be an atom");

View File

@ -107,7 +107,7 @@ B comp_currSrc;
B comp_currRe; // ⟨REPL mode ⋄ scope ⋄ compiler ⋄ runtime ⋄ glyphs ⋄ sysval names ⋄ sysval values⟩
B rt_undo, rt_select, rt_slash, rt_insert, rt_depth,
rt_group, rt_under, rt_find, rt_transp;
rt_group, rt_under, rt_find;
Block* load_compObj(B x, B src, B path, Scope* sc) { // consumes x,src
SGet(x)
usz xia = IA(x);
@ -442,7 +442,6 @@ void load_init() { // very last init function
rt_group = Get(rtObjRaw, n_group );
rt_under = Get(rtObjRaw, n_under );
rt_find = Get(rtObjRaw, n_find );
rt_transp = Get(rtObjRaw, n_transp );
rt_depth = Get(rtObjRaw, n_depth );
rt_insert = Get(rtObjRaw, n_insert );
@ -485,7 +484,7 @@ void load_init() { // very last init function
}
load_rtObj = frtObj;
load_compArg = m_hVec2(load_rtObj, incG(bi_sys));
rt_select=rt_slash=rt_group=rt_find=rt_transp=rt_invFnReg=rt_invFnSwap = incByG(bi_invalidFn, 7);
rt_select=rt_slash=rt_group=rt_find=rt_invFnReg=rt_invFnSwap = incByG(bi_invalidFn, 7);
rt_undo=rt_insert = incByG(bi_invalidMd1, 2);
rt_under=rt_depth = incByG(bi_invalidMd2, 2);
rt_invFnRegFn=rt_invFnSwapFn = invalidFn_c1;
@ -497,7 +496,6 @@ void load_init() { // very last init function
gc_add(rt_group);
gc_add(rt_under);
gc_add(rt_find);
gc_add(rt_transp);
gc_add(rt_depth);
gc_add(rt_insert);