High-rank join shape checking

This commit is contained in:
Marshall Lochbaum 2022-08-05 21:38:57 -04:00
parent c1d5ca5c29
commit 1d891388b5

View File

@ -894,7 +894,7 @@ B join_c1(B t, B x) {
usz cam = 1; // Result length
if (rm) {
esh = a(x0)->sh;
cam = esh[0];
cam = *esh++;
} else {
rr++;
}
@ -904,27 +904,28 @@ B join_c1(B t, B x) {
ur cr = isAtm(c) ? 0 : rnk(c);
if (cr == 0) {
if (rm > 1) thrF("∾: Item ranks in a list can differ by at most one (contained ranks %i and %i)", 0, rm);
rd=1; cam++;
rd=rm; cam++;
} else {
usz* csh = a(c)->sh;
ur cd = rm - cr;
if (RARE(cd > rd)) {
if ((ur)(cd+1-rd) > 2-rd) thrF("∾: Item ranks in a list can differ by at most one (contained ranks %i and %i)", rm-rd*(cr==rm), cr);
if (cr > rr) { // Previous elements were cells
if (cam != i*esh[0]) thrM("∾: Item trailing shapes must be equal");
esh--; rr=cr; cam=i;
esh--;
if (cam != i * *esh) thrM("∾: Item trailing shapes must be equal");
rr=cr; cam=i;
}
rm = cr>rm ? cr : rm;
rd = 1;
}
if (cr < rm) { csh--; cam++; } else { cam+=csh[0]; }
for (usz j = 1; j < cr; j++) if (csh[j]!=esh[j]) thrF("∾: Item trailing shapes must be equal (contained arrays with shapes %H and %H)", x0, c);
cam += cr < rm ? 1 : *csh++;
if (!eqShPart(csh, esh, cr-1)) thrF("∾: Item trailing shapes must be equal (contained arrays with shapes %H and %H)", x0, c);
}
if (SFNS_FILLS && !noFill(rf)) rf = fill_or(rf, getFillQ(c));
}
if (rm==0) thrM("∾: Some item rank must be equal or greater than rank of argument");
usz csz = shProd(esh, 1, rr);
usz csz = shProd(esh, 0, rr-1);
MAKE_MUT(r, cam*csz);
usz ri = 0;
for (usz i = 0; i < xia; i++) {
@ -943,13 +944,91 @@ B join_c1(B t, B x) {
usz* sh = arr_shAlloc(ra, rr);
if (sh) {
sh[0] = cam;
shcpy(sh+1, esh+1, rr-1);
shcpy(sh+1, esh, rr-1);
}
decG(x);
return SFNS_FILLS? qWithFill(taga(ra), rf) : taga(ra);
} else if (xr==0) {
return bqn_merge(x);
} else {
SGetU(x)
B x0 = GetU(x,0);
B rf; if(SFNS_FILLS) rf = getFillQ(x0);
ur r0 = isAtm(x0) ? 0 : rnk(x0);
usz ia = a(x)->ia;
usz* xsh = a(x)->sh;
usz tlen = 4*xr+2*r0; for (usz a=0; a<xr; a++) tlen+=xsh[a];
ShArr* sto = m_shArr(tlen); usz* st = sto->a; // Temp buffer
st[xr-1]=1; for (ur a=xr; a-->1; ) st[a-1] = st[a]*xsh[a]; // Stride
usz* tsh0 = st+xr; usz* tsh = tsh0+xr+r0; // Test shapes
// Length buffer i is lp+lp[i]
usz* lp = tsh+xr+r0; lp[0]=xr; for (usz a=1; a<xr; a++) lp[a] = lp[a-1]+xsh[a-1];
ur au = 0; // Number of root axes used so far
for (ur a = 0; a < xr; a++) {
// Check the axis starting at the root, getting axis lengths
usz n = xsh[a];
if (n == 1) {
if (++au > r0) thrM("∾: Ranks of argument items too small");
continue;
}
usz step = st[a];
usz *ll = lp+lp[a];
ll[0] = r0;
for (usz i=1; i<n; i++) {
B c = GetU(x, i*step);
ll[i] = LIKELY(isArr(c)) ? rnk(c) : 0;
}
usz r1s=r0; for (usz i=1; i<n; i++) if (ll[i]>r1s) r1s=ll[i];
ur r1 = r1s;
ur add = r1==r0;
if (au+add > r0) thrM("∾: Ranks of argument items too small");
for (usz i=0; i<n; i++) {
ur rd = r1 - ll[i];
if (rd) {
if (rd>1) thrF("∾: Item ranks along an axis can differ by at most one (contained ranks %i and %i along axis %i)", ll[i], r1, a);
ll[i] = -1;
} else {
B c = GetU(x, i*step);
ll[i] = a(c)->sh[au];
}
}
// Check shapes
for (usz j=0; j<ia; j+=n*step) {
B base = GetU(x, j);
ur r = isAtm(base) ? 0 : rnk(base);
ur r1 = r+1-add;
ur lr = 0;
if (r) {
usz* sh=a(base)->sh;
lr = r-(r0-au);
shcpy(tsh,sh,r); shcpy(tsh0,sh,r);
if (!add) shcpy(tsh +lr+1, tsh +lr , r-lr );
else shcpy(tsh0+lr , tsh0+lr+1, r-lr-1);
}
for (usz i=1; i<n; i++) {
B c = GetU(x, j+i*step);
bool rd = ll[i]==-1;
tsh[lr] = ll[i];
ur cr=0; usz* sh=NULL; if (!isAtm(c)) { cr=rnk(c); sh=a(c)->sh; }
if (cr != r1-rd) thrF("∾: Incompatible item ranks", base, c);
if (!eqShPart(rd?tsh0:tsh, sh, cr)) thrF("∾: Incompatible item shapes (contained arrays with shapes %H and %H along axis %i)", base, c, a);
if (SFNS_FILLS && !noFill(rf)) rf = fill_or(rf, getFillQ(c));
}
}
au += add;
// Transform to lengths by changing -1 to 1, and get total
usz len = 0;
for (usz i=0; i<xsh[a]; i++) {
len += ll[i] &= 1 | -(usz)(ll[i]!=-1);
}
st[a] = len;
}
decShObj(sto);
dec(rf);
return c1(rt_join, x);
}
}