Finish multidimensional join

This commit is contained in:
Marshall Lochbaum 2022-08-06 17:30:30 -04:00
parent 1d891388b5
commit 78f14be79d

View File

@ -853,7 +853,6 @@ B drop_c2(B t, B w, B x) {
return c2(rt_drop, w, x);
}
extern B rt_join;
B join_c1(B t, B x) {
if (isAtm(x)) thrM("∾: Argument must be an array");
@ -956,7 +955,7 @@ B join_c1(B t, B x) {
B rf; if(SFNS_FILLS) rf = getFillQ(x0);
ur r0 = isAtm(x0) ? 0 : rnk(x0);
usz ia = a(x)->ia;
usz xia = a(x)->ia;
usz* xsh = a(x)->sh;
usz tlen = 4*xr+2*r0; for (usz a=0; a<xr; a++) tlen+=xsh[a];
ShArr* sto = m_shArr(tlen); usz* st = sto->a; // Temp buffer
@ -965,16 +964,20 @@ B join_c1(B t, B x) {
// Length buffer i is lp+lp[i]
usz* lp = tsh+xr+r0; lp[0]=xr; for (usz a=1; a<xr; a++) lp[a] = lp[a-1]+xsh[a-1];
ur au = 0; // Number of root axes used so far
// Expand checked region from the root ⊑𝕩 along each axis in order,
// so that a non-root element is checked when the axis of the first
// nonzero in its index is reached.
ur tr = r0; // Number of root axes remaining
for (ur a = 0; a < xr; a++) {
// Check the axis starting at the root, getting axis lengths
usz n = xsh[a];
usz *ll = lp+lp[a];
if (n == 1) {
if (++au > r0) thrM("∾: Ranks of argument items too small");
continue;
if (!tr) thrM("∾: Ranks of argument items too small");
st[a] = ll[0] = a(x0)->sh[r0-tr];
tr--; continue;
}
usz step = st[a];
usz *ll = lp+lp[a];
ll[0] = r0;
for (usz i=1; i<n; i++) {
B c = GetU(x, i*step);
@ -982,8 +985,8 @@ B join_c1(B t, B x) {
}
usz r1s=r0; for (usz i=1; i<n; i++) if (ll[i]>r1s) r1s=ll[i];
ur r1 = r1s;
ur add = r1==r0;
if (au+add > r0) thrM("∾: Ranks of argument items too small");
ur a0 = r1==r0; // Root has axis a
if (tr < a0) thrM("∾: Ranks of argument items too small");
for (usz i=0; i<n; i++) {
ur rd = r1 - ll[i];
if (rd) {
@ -991,22 +994,22 @@ B join_c1(B t, B x) {
ll[i] = -1;
} else {
B c = GetU(x, i*step);
ll[i] = a(c)->sh[au];
ll[i] = a(c)->sh[r0-tr];
}
}
// Check shapes
for (usz j=0; j<ia; j+=n*step) {
for (usz j=0; j<xia; j+=n*step) {
B base = GetU(x, j);
ur r = isAtm(base) ? 0 : rnk(base);
ur r1 = r+1-add;
ur r1 = r+1-a0;
ur lr = 0;
if (r) {
usz* sh=a(base)->sh;
lr = r-(r0-au);
lr = r - tr;
shcpy(tsh,sh,r); shcpy(tsh0,sh,r);
if (!add) shcpy(tsh +lr+1, tsh +lr , r-lr );
else shcpy(tsh0+lr , tsh0+lr+1, r-lr-1);
if (!a0) shcpy(tsh +lr+1, tsh +lr , tr );
else shcpy(tsh0+lr , tsh0+lr+1, tr-1);
}
for (usz i=1; i<n; i++) {
B c = GetU(x, j+i*step);
@ -1018,18 +1021,67 @@ B join_c1(B t, B x) {
if (SFNS_FILLS && !noFill(rf)) rf = fill_or(rf, getFillQ(c));
}
}
au += add;
tr -= a0;
// Transform to lengths by changing -1 to 1, and get total
usz len = 0;
for (usz i=0; i<xsh[a]; i++) {
for (usz i=0; i<n; i++) {
len += ll[i] &= 1 | -(usz)(ll[i]!=-1);
}
st[a] = len;
}
// Move the data
usz* csh = tr ? a(x0)->sh + r0-tr : NULL; // Trailing shape
usz csz = shProd(csh, 0, tr);
MAKE_MUT(r, shProd(st, 0, xr)*csz);
// Element index and effective shape, updated progressively
usz *ei =tsh; for (usz i=0; i<xr; i++) ei [i]=0;
usz ri = 0;
usz *ll = lp+lp[xr-1];
for (usz i = 0;;) {
B e = GetU(x, i);
usz l = ll[ei[xr-1]] * csz;
if (RARE(isAtm(e))) {
assert(l==1);
mut_set(r, ri, inc(e));
} else {
usz eia = a(e)->ia;
if (eia) {
usz rj = ri;
usz *ii=tsh0; for (usz k=0; k<xr-1; k++) ii[k]=0;
usz str0 = st[xr-1]*csz;
for (usz j=0;;) {
mut_copy(r, rj, e, j, l);
j+=l; if (j==eia) break;
usz str = str0;
rj += str;
for (usz a = xr-2; RARE(++ii[a] == lp[lp[a]+ei[a]]); a--) {
rj -= ii[a]*str;
ii[a] = 0;
str *= st[a];
rj += str;
}
}
}
}
if (++i == xia) break;
ri += l;
usz str = csz;
for (usz a = xr-1; RARE(++ei[a] == xsh[a]); ) {
ei[a] = 0;
str *= st[a];
a--;
ri += (lp[lp[a]+ei[a]]-1) * str;
}
}
Arr* ra = mut_fp(r);
usz* sh = arr_shAlloc(ra, xr+tr);
shcpy(sh , st , xr);
shcpy(sh+xr, csh, tr);
decShObj(sto);
dec(rf);
return c1(rt_join, x);
decG(x);
return SFNS_FILLS? qWithFill(taga(ra), rf) : taga(ra);
}
}
B join_c2(B t, B w, B x) {