diff --git a/src/builtins/md2.c b/src/builtins/md2.c index 191e6f1e..b640c3e4 100644 --- a/src/builtins/md2.c +++ b/src/builtins/md2.c @@ -210,40 +210,45 @@ static B m2c2(B t, B f, B g, B w, B x) { // consumes w,x return r; } +static f64 req_whole(f64 f) { + if (floor(f)!=f) thrM("⎉: 𝕘 was a fractional number"); + return f; +} +static usz check_rank_vec(B g) { + if (!isArr(g)) thrM("⎉: Invalid 𝔾 result"); + usz gia = a(g)->ia; + if (!(gia>=1 && gia<=3)) thrM("⎉: 𝔾 result must have 1 to 3 elements"); + SGetU(g) + if (TI(g,elType)>=el_f64) for (i32 i = 0; i < gia; i++) req_whole(o2f(GetU(g,i))); + return gia; +} +static ur cell_rank(f64 r, f64 k) { // ⎉k over arg rank r + return k<0? (k+r<0? 0 : k+r) : (k>r? r : k); +} B rank_c1(Md2D* d, B x) { B f = d->f; B g = d->g; f64 kf; bool gf = isFun(g); if (RARE(gf)) g = c1(g, inc(x)); if (LIKELY(isNum(g))) { - kf = o2fu(g); - } else if (isArr(g)) { - usz gia = a(g)->ia; - if (!(gia>=1 && gia<=3)) thrM("⎉: 𝔾 result must have 1 to 3 elements"); - SGetU(g) - if (!elNum(TI(g,elType))) for (i32 i = 0; i < gia; i++) o2f(GetU(g,i)); - kf = GetU(g, gia==2).f; - } else thrM("⎉: Invalid 𝔾 result"); + kf = req_whole(o2fu(g)); + } else { + usz gia = check_rank_vec(g); + SGetU(g); kf = GetU(g, gia==2).f; + } if (gf) dec(g); - i32 k = kf; if (isAtm(x) || rnk(x)==0) { - if (floor(kf)!=kf) thrM("⎉: 𝕘 was a fractional number"); B r = c1(f, x); return isAtm(r)? m_atomUnit(r) : r; } i32 xr = rnk(x); - usz* xsh = a(x)->sh; - if (k!=kf) { - if (floor(kf)!=kf) thrM("⎉: 𝕘 was a fractional number"); - k = kf>0? 0 : xr; - } else { - k = k<0? (k+xr<0? xr : xr-(k+xr)) : (k>xr? 0 : xr-k); - } + ur cr = cell_rank(xr, kf); + i32 k = xr - cr; if (Q_BI(f,lt) && a(x)->ia!=0 && rnk(x)>1) return toKCells(x, k); + usz* xsh = a(x)->sh; usz cam = 1; for (usz i = 0; i < k; i++) cam*= xsh[i]; usz csz = 1; for (usz i = k; i < xr; i++) csz*= xsh[i]; - ur cr = xr-k; ShArr* csh; if (cr>1) { csh = m_shArr(cr); @@ -268,8 +273,133 @@ B rank_c1(Md2D* d, B x) { B f = d->f; B g = d->g; return bqn_merge(HARR_O(r).b); } extern B rt_rank; -B rank_c2(Md2D* d, B w, B x) { B f = d->f; B g = d->g; // TODO - return m2c2(rt_rank, f, g, w, x); +B rank_c2(Md2D* d, B w, B x) { B f = d->f; B g = d->g; + f64 wf, xf; + bool gf = isFun(g); + if (RARE(gf)) g = c2(g, inc(w), inc(x)); + if (LIKELY(isNum(g))) { + wf = xf = req_whole(o2fu(g)); + } else { + usz gia = check_rank_vec(g); + SGetU(g); + wf = GetU(g, gia<2?0:gia-2).f; + xf = GetU(g, gia-1).f; + } + + ur wr = isAtm(w) ? 0 : rnk(w); ur wc = cell_rank(wr, wf); + ur xr = isAtm(x) ? 0 : rnk(x); ur xc = cell_rank(xr, xf); + + B r; + if (wr == wc) { + if (xr == xc) { + if (gf) dec(g); + r = c2(f, w, x); + return isAtm(r)? m_atomUnit(r) : r; + } else { + i32 k = xr - xc; + usz* xsh = a(x)->sh; + usz cam = 1; for (usz i = 0; i < k; i++) cam*= xsh[i]; + usz csz = 1; for (usz i = k; i < xr; i++) csz*= xsh[i]; + if (cam == 0) { return m2c2(rt_rank, f, g, w, x); } // TODO + ShArr* csh; + if (xc>1) { csh=m_shArr(xc); shcpy(csh->a, xsh+k, xc); } + + BSS2A slice = TI(x,slice); + M_HARR(r, cam); + usz p = 0; + for (usz i = 0; i < cam; i++) { + Arr* s = slice(inc(x), p, csz); arr_shSetI(s, xc, csh); + HARR_ADD(r, i, c2(f, inc(w), taga(s))); + p+= csz; + } + + if (xc>1) ptr_dec(csh); + usz* rsh = HARR_FA(r, k); + if (k>1) shcpy(rsh, xsh, k); + + dec(w); decG(x); r = HARR_O(r).b; + } + } else if (xr == xc) { + i32 k = wr - wc; + usz* wsh = a(w)->sh; + usz cam = 1; for (usz i = 0; i < k; i++) cam*= wsh[i]; + usz csz = 1; for (usz i = k; i < wr; i++) csz*= wsh[i]; + if (cam == 0) { return m2c2(rt_rank, f, g, w, x); } // TODO + ShArr* csh; + if (wc>1) { csh=m_shArr(wc); shcpy(csh->a, wsh+k, wc); } + + BSS2A slice = TI(w,slice); + M_HARR(r, cam); + usz p = 0; + for (usz i = 0; i < cam; i++) { + Arr* s = slice(inc(w), p, csz); arr_shSetI(s, wc, csh); + HARR_ADD(r, i, c2(f, taga(s), inc(x))); + p+= csz; + } + + if (wc>1) ptr_dec(csh); + usz* rsh = HARR_FA(r, k); + if (k>1) shcpy(rsh, wsh, k); + + decG(w); dec(x); r = HARR_O(r).b; + } else { + i32 wk = wr - wc; usz* wsh = a(w)->sh; + i32 xk = xr - xc; usz* xsh = a(x)->sh; + i32 k=wk, zk=xk; if (k>zk) { i32 t=k; k=zk; zk=t; } + usz* zsh = wk>xk? wsh : xsh; + + usz cam = 1; for (usz i = 0; i < k; i++) { + usz wl = wsh[i], xl = xsh[i]; + if (wl != xl) thrF("⎉: Argument frames don't agree (%H ≡ ≢𝕨, %H ≡ ≢𝕩, common frame of %s axes)", w, x, k); + cam*= wsh[i]; + } + usz ext = 1; for (usz i = k; i < zk; i++) ext*= zsh[i]; + usz wsz = 1; for (usz i = wk; i < wr; i++) wsz*= wsh[i]; + usz xsz = 1; for (usz i = xk; i < xr; i++) xsz*= xsh[i]; + cam *= ext; + if (cam == 0) { return m2c2(rt_rank, f, g, w, x); } // TODO + + ShArr* wcs; if (wc>1) { wcs=m_shArr(wc); shcpy(wcs->a, wsh+wk, wc); } + ShArr* xcs; if (xc>1) { xcs=m_shArr(xc); shcpy(xcs->a, xsh+xk, xc); } + + BSS2A wslice = TI(w,slice); + BSS2A xslice = TI(x,slice); + M_HARR(r, cam); + usz wp = 0, xp = 0; + #define CELL(wx) \ + Arr* wx##s = wx##slice(inc(wx), wx##p, wx##sz); \ + arr_shSetI(wx##s, wx##c, wx##cs); \ + wx##p+= wx##sz + #define F(W,X) HARR_ADD(r, i, c2(f, W, X)) + if (ext == 1) { + for (usz i = 0; i < cam; i++) { + CELL(w); CELL(x); F(taga(ws), taga(xs)); + } + } else if (wk < xk) { + for (usz i = 0; i < cam; ) { + CELL(w); B wb=taga(ws); + for (usz e = i+ext; i < e; i++) { CELL(x); F(inc(wb), taga(xs)); } + dec(wb); + } + } else { + for (usz i = 0; i < cam; ) { + CELL(x); B xb=taga(xs); + for (usz e = i+ext; i < e; i++) { CELL(w); F(taga(ws), inc(xb)); } + dec(xb); + } + } + #undef CELL + #undef F + + if (wc>1) ptr_dec(wcs); + if (xc>1) ptr_dec(xcs); + usz* rsh = HARR_FA(r, zk); + if (zk>1) shcpy(rsh, zsh, zk); + + decG(w); decG(x); r = HARR_O(r).b; + } + if (gf) dec(g); + return bqn_merge(r); } diff --git a/src/core/harr.c b/src/core/harr.c index 150ed7d9..662291b7 100644 --- a/src/core/harr.c +++ b/src/core/harr.c @@ -8,8 +8,8 @@ B toCells(B x) { usz cam = a(x)->sh[0]; usz csz = arr_csz(x); BSS2A slice = TI(x,slice); - usz p = 0; M_HARR(r, cam) + usz p = 0; if (rnk(x)==2) { for (usz i = 0; i < cam; i++) { Arr* s = slice(inc(x), p, csz); arr_shVec(s); @@ -20,7 +20,7 @@ B toCells(B x) { usz cr = rnk(x)-1; ShArr* csh = m_shArr(cr); usz* xsh = a(x)->sh; - for (u64 i = 0; i < cr; i++) csh->a[i] = xsh[i+1]; + shcpy(csh->a, xsh+1, cr); for (usz i = 0; i < cam; i++) { Arr* s = slice(inc(x), p, csz); arr_shSetI(s, cr, csh); HARR_ADD(r, i, taga(s)); @@ -41,12 +41,12 @@ B toKCells(B x, ur k) { ShArr* csh; if (cr>1) { csh = m_shArr(cr); - for (i32 i = 0; i < cr; i++) csh->a[i] = xsh[i+k]; + shcpy(csh->a, xsh+k, cr); } - usz p = 0; - M_HARR(r, cam) BSS2A slice = TI(x,slice); + M_HARR(r, cam); + usz p = 0; for (usz i = 0; i < cam; i++) { Arr* s = slice(inc(x), p, csz); arr_shSetI(s, cr, csh); HARR_ADD(r, i, taga(s)); @@ -54,7 +54,7 @@ B toKCells(B x, ur k) { } if (cr>1) ptr_dec(csh); usz* rsh = HARR_FA(r, k); - if (rsh) for (i32 i = 0; i < k; i++) rsh[i] = xsh[i]; + if (rsh) shcpy(rsh, xsh, k); decG(x); return HARR_O(r).b; }