diff --git a/src/builtins/fold.c b/src/builtins/fold.c index c5af2770..36ee882f 100644 --- a/src/builtins/fold.c +++ b/src/builtins/fold.c @@ -1,6 +1,6 @@ -// Fold (´) +// Fold (´) and Insert (˝) -// Optimized operands: +// Fold optimized operands: // ⊣⊢ on all types // +-∧∨=≠ and synonyms on booleans // ≤<>≥ on booleans, monadic only, with a search @@ -13,6 +13,7 @@ #include "../core.h" #include "../builtins.h" +#include "../utils/calls.h" #include "../utils/mut.h" #if SINGELI_SIMD @@ -321,8 +322,63 @@ u64 usum(B x) { // doesn't consume; will error on non-integers, or elements <0, neg: thrM("Didn't expect negative integer"); } +static B m1c1(B t, B f, B x) { // consumes x + B fn = m1_d(inc(t), inc(f)); + B r = c1(fn, x); + decG(fn); + return r; +} +extern B rt_insert; +// From md1.c +extern bool isPervasiveDyExt(B x); +extern B insert_base(B f, B x, usz xia, bool has_w, B w); + +B insert_c1(Md1D* d, B x) { B f = d->f; + if (isAtm(x) || RNK(x)==0) thrM("˝: 𝕩 must have rank at least 1"); + usz xia = IA(x); + if (xia==0) { SLOW2("!𝕎˝𝕩", f, x); return m1c1(rt_insert, f, x); } + if (isFun(f)) { + u8 rtid = v(f)->flags-1; + if (RNK(x)==1 && isPervasiveDyExt(f)) return m_atomUnit(fold_c1(d, x)); + if (rtid == n_join) { + ur xr = RNK(x); + if (xr==1) return x; + ShArr* rsh; + if (xr>2) { + rsh = m_shArr(xr-1); + usz* xsh = SH(x); + shcpy(rsh->a+1, xsh+2, xr-2); + rsh->a[0] = xsh[0] * xsh[1]; + } + Arr* r = TI(x,slice)(x, 0, IA(x)); + if (xr>2) arr_shSetU(r, xr-1, rsh); + else arr_shVec(r); + return taga(r); + } + } + return insert_base(f, x, xia, 0, bi_N); +} +B insert_c2(Md1D* d, B w, B x) { B f = d->f; + if (isAtm(x) || RNK(x)==0) thrM("˝: 𝕩 must have rank at least 1"); + usz xia = IA(x); + B r = w; + if (xia==0) { decG(x); return r; } + + if (isFun(f)) { + if (RNK(x)==1 && isPervasiveDyExt(f)) { + if (isAtm(w)) { + to_fold: return m_atomUnit(fold_c2(d, w, x)); + } + if (RNK(w)==0) { + B w0=w; w = IGet(w,0); decG(w0); + goto to_fold; + } + } + } + return insert_base(f, x, xia, 1, w); +} + // Arithmetic fold/insert on rows of flat rank-2 array x -B insert_c1(Md1D*, B); B transp_c1(B, B); B join_c2(B, B, B); B fold_rows(Md1D* fd, B x) { diff --git a/src/builtins/md1.c b/src/builtins/md1.c index 13f733ee..aac77bd0 100644 --- a/src/builtins/md1.c +++ b/src/builtins/md1.c @@ -161,19 +161,6 @@ B timed_c1(Md1D* d, B x) { B f = d->f; } -static B m1c1(B t, B f, B x) { // consumes x - B fn = m1_d(inc(t), inc(f)); - B r = c1(fn, x); - decG(fn); - return r; -} -static B m1c2(B t, B f, B w, B x) { // consumes w,x - B fn = m1_d(inc(t), inc(f)); - B r = c2(fn, w, x); - decG(fn); - return r; -} - #pragma GCC diagnostic push #ifdef __clang__ #pragma GCC diagnostic ignored "-Wsometimes-uninitialized" @@ -467,64 +454,15 @@ B cell_c2(Md1D* d, B w, B x) { B f = d->f; return bqn_merge(r); } -B fold_c1(Md1D* d, B x); -B fold_c2(Md1D* d, B w, B x); - -extern B rt_insert; -B insert_c1(Md1D* d, B x) { B f = d->f; - if (isAtm(x) || RNK(x)==0) thrM("˝: 𝕩 must have rank at least 1"); - usz xia = IA(x); - if (xia==0) { SLOW2("!𝕎˝𝕩", f, x); return m1c1(rt_insert, f, x); } - if (isFun(f)) { - u8 rtid = v(f)->flags-1; - if (RNK(x)==1 && isPervasiveDyExt(f)) return m_atomUnit(fold_c1(d, x)); - if (rtid == n_join) { - ur xr = RNK(x); - if (xr==1) return x; - ShArr* rsh; - if (xr>2) { - rsh = m_shArr(xr-1); - usz* xsh = SH(x); - shcpy(rsh->a+1, xsh+2, xr-2); - rsh->a[0] = xsh[0] * xsh[1]; - } - Arr* r = TI(x,slice)(x, 0, IA(x)); - if (xr>2) arr_shSetU(r, xr-1, rsh); - else arr_shVec(r); - return taga(r); - } - } - - S_SLICES(x) - usz p = xia-x_csz; - B r = SLICE(x, p); - while(p!=0) { - p-= x_csz; - r = c2(f, SLICE(x, p), r); - } - E_SLICES(x) - return r; -} -B insert_c2(Md1D* d, B w, B x) { B f = d->f; - if (isAtm(x) || RNK(x)==0) thrM("˝: 𝕩 must have rank at least 1"); - usz xia = IA(x); - B r = w; - if (xia==0) { decG(x); return r; } - - if (isFun(f)) { - if (RNK(x)==1 && isPervasiveDyExt(f)) { - if (isAtm(w)) { - to_fold: return m_atomUnit(fold_c2(d, w, x)); - } - if (RNK(w)==0) { - B w0=w; w = IGet(w,0); decG(w0); - goto to_fold; - } - } - } - +// Used by Insert in fold.c +B insert_base(B f, B x, usz xia, bool has_w, B w) { S_SLICES(x) usz p = xia; + B r = w; + if (!has_w) { + p -= x_csz; + r = SLICE(x, p); + } while(p!=0) { p-= x_csz; r = c2(f, SLICE(x, p), r);