From 1d891388b596c5ade82d76b616245447a292313f Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Fri, 5 Aug 2022 21:38:57 -0400 Subject: [PATCH] High-rank join shape checking --- src/builtins/sfns.c | 95 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 87 insertions(+), 8 deletions(-) diff --git a/src/builtins/sfns.c b/src/builtins/sfns.c index 9bd5256a..4572a247 100644 --- a/src/builtins/sfns.c +++ b/src/builtins/sfns.c @@ -894,7 +894,7 @@ B join_c1(B t, B x) { usz cam = 1; // Result length if (rm) { esh = a(x0)->sh; - cam = esh[0]; + cam = *esh++; } else { rr++; } @@ -904,27 +904,28 @@ B join_c1(B t, B x) { ur cr = isAtm(c) ? 0 : rnk(c); if (cr == 0) { if (rm > 1) thrF("∾: Item ranks in a list can differ by at most one (contained ranks %i and %i)", 0, rm); - rd=1; cam++; + rd=rm; cam++; } else { usz* csh = a(c)->sh; ur cd = rm - cr; if (RARE(cd > rd)) { if ((ur)(cd+1-rd) > 2-rd) thrF("∾: Item ranks in a list can differ by at most one (contained ranks %i and %i)", rm-rd*(cr==rm), cr); if (cr > rr) { // Previous elements were cells - if (cam != i*esh[0]) thrM("∾: Item trailing shapes must be equal"); - esh--; rr=cr; cam=i; + esh--; + if (cam != i * *esh) thrM("∾: Item trailing shapes must be equal"); + rr=cr; cam=i; } rm = cr>rm ? cr : rm; rd = 1; } - if (cr < rm) { csh--; cam++; } else { cam+=csh[0]; } - for (usz j = 1; j < cr; j++) if (csh[j]!=esh[j]) thrF("∾: Item trailing shapes must be equal (contained arrays with shapes %H and %H)", x0, c); + cam += cr < rm ? 1 : *csh++; + if (!eqShPart(csh, esh, cr-1)) thrF("∾: Item trailing shapes must be equal (contained arrays with shapes %H and %H)", x0, c); } if (SFNS_FILLS && !noFill(rf)) rf = fill_or(rf, getFillQ(c)); } if (rm==0) thrM("∾: Some item rank must be equal or greater than rank of argument"); - usz csz = shProd(esh, 1, rr); + usz csz = shProd(esh, 0, rr-1); MAKE_MUT(r, cam*csz); usz ri = 0; for (usz i = 0; i < xia; i++) { @@ -943,13 +944,91 @@ B join_c1(B t, B x) { usz* sh = arr_shAlloc(ra, rr); if (sh) { sh[0] = cam; - shcpy(sh+1, esh+1, rr-1); + shcpy(sh+1, esh, rr-1); } decG(x); return SFNS_FILLS? qWithFill(taga(ra), rf) : taga(ra); } else if (xr==0) { return bqn_merge(x); } else { + SGetU(x) + B x0 = GetU(x,0); + B rf; if(SFNS_FILLS) rf = getFillQ(x0); + ur r0 = isAtm(x0) ? 0 : rnk(x0); + + usz ia = a(x)->ia; + usz* xsh = a(x)->sh; + usz tlen = 4*xr+2*r0; for (usz a=0; aa; // Temp buffer + st[xr-1]=1; for (ur a=xr; a-->1; ) st[a-1] = st[a]*xsh[a]; // Stride + usz* tsh0 = st+xr; usz* tsh = tsh0+xr+r0; // Test shapes + // Length buffer i is lp+lp[i] + usz* lp = tsh+xr+r0; lp[0]=xr; for (usz a=1; a r0) thrM("∾: Ranks of argument items too small"); + continue; + } + usz step = st[a]; + usz *ll = lp+lp[a]; + ll[0] = r0; + for (usz i=1; ir1s) r1s=ll[i]; + ur r1 = r1s; + ur add = r1==r0; + if (au+add > r0) thrM("∾: Ranks of argument items too small"); + for (usz i=0; i1) thrF("∾: Item ranks along an axis can differ by at most one (contained ranks %i and %i along axis %i)", ll[i], r1, a); + ll[i] = -1; + } else { + B c = GetU(x, i*step); + ll[i] = a(c)->sh[au]; + } + } + + // Check shapes + for (usz j=0; jsh; + lr = r-(r0-au); + shcpy(tsh,sh,r); shcpy(tsh0,sh,r); + if (!add) shcpy(tsh +lr+1, tsh +lr , r-lr ); + else shcpy(tsh0+lr , tsh0+lr+1, r-lr-1); + } + for (usz i=1; ish; } + if (cr != r1-rd) thrF("∾: Incompatible item ranks", base, c); + if (!eqShPart(rd?tsh0:tsh, sh, cr)) thrF("∾: Incompatible item shapes (contained arrays with shapes %H and %H along axis %i)", base, c, a); + if (SFNS_FILLS && !noFill(rf)) rf = fill_or(rf, getFillQ(c)); + } + } + au += add; + + // Transform to lengths by changing -1 to 1, and get total + usz len = 0; + for (usz i=0; i