Under for invertible Reorder Axes

This commit is contained in:
Marshall Lochbaum 2023-03-30 19:03:24 -04:00
parent d53f3dbd72
commit 9fe6a2e7b2

View File

@ -25,7 +25,10 @@
// COULD convert boolean to integer for some axis reorderings
// SHOULD have a small-subarray transposer using one or a few shuffles
// Transpose inverse ⍉⁼𝕩: data movement of ⍉ with different shape logic
// ⍉⁼𝕩: data movement of ⍉ with different shape logic
// 𝕨⍉⁼𝕩: compute inverse 𝕨, length 1+⌈´𝕨
// Under Transpose supports invertible cases
// SHOULD implement Under with duplicate axes, maybe as Under Select
// COULD implement fast ⍉⍟n
// SHOULD convert ⍉ with rank to a Reorder Axes call
@ -379,10 +382,8 @@ B transp_uc1(B t, B o, B x) {
return transp_im(m_f64(0), c1(o, transp_c1(t, x)));
}
B transp_ix(B t, B w, B x) {
if (isAtm(x)) thrM("⍉⁼: 𝕩 must not be an atom");
ur xr=RNK(x);
// Consumes w; return bi_N if w contained duplicates
static B invert_transp_w(B w, ur xr) {
if (isAtm(w)) {
if (xr<1) thrM("⍉⁼: Length of 𝕨 must be at most rank of 𝕩");
usz a=o2s(w);
@ -392,7 +393,7 @@ B transp_ix(B t, B w, B x) {
} else {
if (RNK(w)>1) thrM("⍉⁼: 𝕨 must have rank at most 1");
usz wia = IA(w);
if (wia==0) { decG(w); return x; }
if (wia==0) return w;
if (xr<wia) thrM("⍉⁼: Length of 𝕨 must be at most rank of 𝕩");
SGetU(w)
TALLOC(ur, p, xr);
@ -401,7 +402,7 @@ B transp_ix(B t, B w, B x) {
for (usz i=0; i<wia; i++) {
usz a=o2s(GetU(w, i));
if (a>=xr) thrF("⍉⁼: Axis %s does not exist (%i≡=𝕩)", a, xr);
if (p[i]!=xr) thrM("⍉⁼: Duplicate axes");
if (p[a]!=xr) { TFREE(p); decG(w); return bi_N; } // Handled by caller
max = a>max? a : max;
p[a] = i;
}
@ -411,11 +412,25 @@ B transp_ix(B t, B w, B x) {
for (usz i=0, j=wia; i<n; i++) wp[i] = p[i]<xr? p[i] : j++;
TFREE(p);
}
return w;
}
B transp_ix(B t, B w, B x) {
if (isAtm(x)) thrM("⍉⁼: 𝕩 must not be an atom");
w = invert_transp_w(w, RNK(x));
if (q_N(w)) thrM("⍉⁼: Duplicate axes");
return C2(transp, w, x);
}
B transp_ucw(B t, B o, B w, B x) {
B wi = invert_transp_w(inc(w), isAtm(x)? 0 : RNK(x));
if (q_N(wi)) return def_fn_ucw(t, o, w, x); // Duplicate axes
return C2(transp, wi, c1(o, C2(transp, w, x)));
}
void transp_init(void) {
c(BFn,bi_transp)->im = transp_im;
c(BFn,bi_transp)->ix = transp_ix;
c(BFn,bi_transp)->uc1 = transp_uc1;
c(BFn,bi_transp)->ucw = transp_ucw;
}