Shape logic for Join of any list

This commit is contained in:
Marshall Lochbaum 2022-08-01 17:27:22 -04:00
parent 88f65850fa
commit b8e6996e7e

View File

@ -883,36 +883,61 @@ B join_c1(B t, B x) {
B x0 = GetU(x,0); B x0 = GetU(x,0);
B rf; if(SFNS_FILLS) rf = getFillQ(x0); B rf; if(SFNS_FILLS) rf = getFillQ(x0);
if (isAtm(x0)) goto base; // thrM("∾: Rank of items must be equal or greater than rank of argument"); ur r0 = isAtm(x0) ? 0 : rnk(x0); // Minimum element rank seen
usz ir = rnk(x0); ur r1 = r0; // Maximum
usz* x0sh = a(x0)->sh; ur rr = r0; // Result rank, or minimum possible so far
if (ir==0) goto base; // thrM("∾: Rank of items must be equal or greater than rank of argument"); usz* esh = NULL;
usz cam = 1; // Result length
usz csz = arr_csz(x0); if (r0) {
usz cam = x0sh[0]; esh = a(x0)->sh;
cam = esh[0];
} else {
rr++;
}
for (usz i = 1; i < xia; i++) { for (usz i = 1; i < xia; i++) {
B c = GetU(x, i); B c = GetU(x, i);
if (!isArr(c) || rnk(c)!=ir) goto base; // thrF("∾: All items in argument should have same rank (contained items with ranks %i and %i)", ir, isArr(c)? rnk(c) : 0); ur cr = isAtm(c) ? 0 : rnk(c);
usz* csh = a(c)->sh; if (cr == 0) {
if (ir>1) for (usz j = 1; j < ir; j++) if (csh[j]!=x0sh[j]) thrF("∾: Item trailing shapes must be equal (contained arrays with shapes %H and %H)", x0, c); if (r1 > 1) thrF("∾: Item ranks in a list can differ by at most one (contained ranks %i and %i)", r0, r1);
cam+= a(c)->sh[0]; r0=0; cam++;
} else {
usz* csh = a(c)->sh;
if (cr != r0) {
if (cr > r1) r1 = cr; else r0 = cr;
if (r1-r0 > 2) thrF("∾: Item ranks in a list can differ by at most one (contained ranks %i and %i)", r0, r1);
}
if (cr < rr) {
csh--; cam++;
} else {
if (cr>rr) { // Previous elements were cells
if (cam != i*esh[0]) thrM("∾: Item trailing shapes must be equal");
esh--; rr++; cam = i;
}
cam+= csh[0];
}
for (usz j = 1; j < cr; j++) if (csh[j]!=esh[j]) thrF("∾: Item trailing shapes must be equal (contained arrays with shapes %H and %H)", x0, c);
}
if (SFNS_FILLS && !noFill(rf)) rf = fill_or(rf, getFillQ(c)); if (SFNS_FILLS && !noFill(rf)) rf = fill_or(rf, getFillQ(c));
} }
if (r1==0) thrM("∾: Some item rank must be equal or greater than rank of argument");
usz csz = shProd(esh, 1, rr);
MAKE_MUT(r, cam*csz); MAKE_MUT(r, cam*csz);
usz ri = 0; usz ri = 0;
for (usz i = 0; i < xia; i++) { for (usz i = 0; i < xia; i++) {
B c = GetU(x, i); B c = GetU(x, i);
if (isAtm(c)) goto base;
usz cia = a(c)->ia; usz cia = a(c)->ia;
mut_copy(r, ri, c, 0, cia); mut_copy(r, ri, c, 0, cia);
ri+= cia; ri+= cia;
} }
assert(ri==cam*csz); assert(ri==cam*csz);
Arr* ra = mut_fp(r); Arr* ra = mut_fp(r);
usz* sh = arr_shAlloc(ra, ir); usz* sh = arr_shAlloc(ra, rr);
if (sh) { if (sh) {
sh[0] = cam; sh[0] = cam;
shcpy(sh+1, x0sh+1, ir-1); shcpy(sh+1, esh+1, rr-1);
} }
decG(x); decG(x);
return SFNS_FILLS? qWithFill(taga(ra), rf) : taga(ra); return SFNS_FILLS? qWithFill(taga(ra), rf) : taga(ra);