diff --git a/src/builtins/fold.c b/src/builtins/fold.c index 251b0683..d1049cca 100644 --- a/src/builtins/fold.c +++ b/src/builtins/fold.c @@ -376,6 +376,7 @@ u64 usum(B x) { // doesn't consume; will error on non-integers, or elements <0, B select_c1(B, B); B select_c2(B, B, B); +B shape_c1(B, B); static B m1c1(B t, B f, B x) { // consumes x B fn = m1_d(inc(t), inc(f)); B r = c1(fn, x); @@ -409,7 +410,19 @@ B insert_c1(Md1D* d, B x) { B f = d->f; } if (len==1) return C1(select, x); if (RARE(!isFun(f))) { decG(x); if (isMd(f)) thrM("Calling a modifier"); return inc(f); } - if (xr==1 && isPervasiveDyExt(f)) return m_unit(fold_c1(d, x)); + if (isPervasiveDyExt(f)) { + if (xr==1) return m_unit(fold_c1(d, x)); + if (len==IA(x)) { + B r = m_vec1(fold_c1(d, C1(shape, x))); + ur rr = xr - 1; + if (rr > 1) { + ShArr* rsh = m_shArr(rr); + PLAINLOOP for (ur i=0; ia[i] = 1; + arr_shReplace(a(r), rr, rsh); + } + return r; + } + } if (v(f)->flags) { u8 rtid = v(f)->flags-1; if (rtid==n_ltack) return C1(select, x); @@ -432,17 +445,32 @@ B insert_c1(Md1D* d, B x) { B f = d->f; return insert_base(f, x, 0, m_f64(0)); } B insert_c2(Md1D* d, B w, B x) { B f = d->f; - if (isAtm(x) || RNK(x)==0) thrM("˝: 𝕩 must have rank at least 1"); - if (*SH(x)==0) { decG(x); return w; } + ur xr; + if (isAtm(x) || (xr=RNK(x))==0) thrM("˝: 𝕩 must have rank at least 1"); + usz len = *SH(x); + if (len==0) { decG(x); return w; } if (RARE(!isFun(f))) { dec(w); decG(x); if (isMd(f)) thrM("Calling a modifier"); return inc(f); } - if (RNK(x)==1 && isPervasiveDyExt(f)) { - if (isAtm(w)) { - to_fold: return m_unit(fold_c2(d, w, x)); - } - if (RNK(w)==0) { + if (isPervasiveDyExt(f) && len==IA(x)) { + // 1-element arrays are always conformable + // final rank is higher of w, cell rank of x + ur rr = xr - 1; + if (isArr(w)) { + if (IA(w) != 1) goto skip; + ur wr = RNK(w); if (wr>rr) rr = wr; B w0=w; w = IGet(w,0); decG(w0); - goto to_fold; } + if (xr > 1) x = C1(shape, x); + B r = m_unit(fold_c2(d, w, x)); + if (rr > 0) { + if (rr == 1) arr_shVec(a(r)); + else { + ShArr* rsh = m_shArr(rr); + PLAINLOOP for (ur i=0; ia[i] = 1; + arr_shReplace(a(r), rr, rsh); + } + } + return r; + skip:; } if (v(f)->flags) { u8 rtid = v(f)->flags-1;