From b10a87fe001151950ffe894775bc247125c0abbf Mon Sep 17 00:00:00 2001 From: dzaima Date: Sat, 22 Apr 2023 01:04:07 +0300 Subject: [PATCH] =?UTF-8?q?merge=20=CB=98=20&=20=E2=8E=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/builtins/cells.c | 221 +++++++++++++++++++++---------------------- 1 file changed, 110 insertions(+), 111 deletions(-) diff --git a/src/builtins/cells.c b/src/builtins/cells.c index b1a279b4..5cf93ee5 100644 --- a/src/builtins/cells.c +++ b/src/builtins/cells.c @@ -211,13 +211,13 @@ static B transp_cells(ur ax, B x) { // helpers -static NOINLINE B to_fill_cell_impl(B x, ur k, char* err) { // consumes x +static NOINLINE B to_fill_cell(B x, ur k, u32 chr) { // consumes x B xf = getFillQ(x); if (noFill(xf)) xf = m_f64(0); ur cr = RNK(x)-k; usz* sh = SH(x)+k; usz csz = 1; - for (usz i=0; iUR_MAX) thrF("%c: Result rank too large", chr); Arr* r = m_fillarrpEmpty(getFillQ(rc)); @@ -245,16 +239,16 @@ FORCE_INLINE B merge_fill_result_impl(u32 chr, B rc, ur k, usz* sh) { dec(rc); return taga(r); } -static NOINLINE B merge_fill_result_k(B rc, ur k, usz* sh) { - return merge_fill_result_impl(U'⎉', rc, k, sh); +static B merge_fill_result_k(B rc, ur k, usz* sh) { + return merge_fill_result_impl(rc, k, sh, U'⎉'); } -static NOINLINE B merge_fill_result_1(B rc) { - return merge_fill_result_impl(U'˘', rc, 1, (usz[]){0}); +static B merge_fill_result_1(B rc) { + return merge_fill_result_impl(rc, 1, (usz[]){0}, U'˘'); } static NOINLINE B cell2_empty(B f, B w, B x, ur wr, ur xr) { if (!isPureFn(f) || !CATCH_ERRORS) { dec(w); dec(x); return emptyHVec(); } - if (wr) w = to_fill_cell_1(w); - if (xr) x = to_fill_cell_1(x); + if (wr) w = to_fill_cell(w, 1, U'˘'); + if (xr) x = to_fill_cell(x, 1, U'˘'); if (CATCH) { freeThrown(); return emptyHVec(); } B rc = c2(f, w, x); popCatch(); @@ -281,115 +275,93 @@ static ur cell_rank(f64 r, f64 k) { // ⎉k over arg rank r // monadic ˘ & ⎉ -B cell_c1(Md1D* d, B x) { B f = d->f; - if (isAtm(x) || RNK(x)==0) { - B r = c1(f, x); - return isAtm(r)? m_atomUnit(r) : r; - } - +B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x, with 0≤cr≤xr, array x, and xr>0 + assert(isArr(x) && xr>0); + usz* xsh = SH(x); + usz cam = shProd(xsh, 0, k); if (isFun(f)) { - if (IA(x)!=0) { - u8 rtid = v(f)->flags-1; - ur xr = RNK(x); - if (rtid==n_lt && xr>1) return toCells(x); - if (rtid==n_select && xr>1) return select_cells(0, x, xr); - if (rtid==n_pick && xr>1 && TI(x,arrD1)) return select_cells(0, x, xr); - if (rtid==n_couple) { - if (xr==0) return C1(shape, x); - Arr* r = cpyWithShape(x); - usz* xsh = PSH(r); - if (xr==UR_MAX) thrF("≍˘: Result rank too large (%i≡=𝕩)", xr); + if (cam==0 || IA(x)==0) goto noSpecial; // TODO be more granular about this + u8 rtid = v(f)->flags-1; + switch(rtid) { + case n_lt: + return k==1 && RNK(x)>1? toCells(x) : k==0? m_atomUnit(x) : toKCells(x, k); + case n_select: + if (k!=1 || xr<=1) goto base; // TODO handle more ranks + selectCells:; + return select_cells(0, x, xr); + case n_pick: + if (k!=1 || cr!=1 || !TI(x,arrD1)) goto base; // TODO handle more ranks + goto selectCells; + case n_couple: { + Arr* r = cpyWithShape(x); xsh=PSH(r); + if (xr==UR_MAX) thrF("≍%c: Result rank too large (%i≡=𝕩)", chr, xr); ShArr* rsh = m_shArr(xr+1); - rsh->a[0] = xsh[0]; - rsh->a[1] = 1; - shcpy(rsh->a+2, xsh+1, xr-1); + shcpy(rsh->a, xsh, k); + rsh->a[k] = 1; + shcpy(rsh->a+k+1, xsh+k, xr-k); return taga(arr_shReplace(r, xr+1, rsh)); } - if (rtid==n_shape) { - if (xr==2) return x; - Arr* r = cpyWithShape(x); - usz cam = PSH(r)[0]; - usz csz = shProd(PSH(r), 1, xr); - ShArr* rsh = m_shArr(2); - rsh->a[0] = cam; - rsh->a[1] = csz; - return taga(arr_shReplace(r, 2, rsh)); + case n_shape: { + if (cr==1) return x; + if (k==0) return C1(shape, x); + Arr* r = cpyWithShape(x); xsh=PSH(r); + usz csz = shProd(xsh, k, xr); + ShArr* rsh = m_shArr(k+1); + shcpy(rsh->a, xsh, k); + rsh->a[k] = csz; + return taga(arr_shReplace(r, k+1, rsh)); } - if ((rtid==n_shifta || rtid==n_shiftb) && xr==2) { + case n_shifta: case n_shiftb: { + if (k!=1 || xr!=2) goto base; // TODO handle more ranks B xf = getFillR(x); - if (!noFill(xf)) return shift_cells(xf, x, TI(x,elType), rtid); + if (noFill(xf)) goto base; + return shift_cells(xf, x, TI(x,elType), rtid); } - if (rtid==n_transp) return xr<=2? x : transp_cells(xr-1, x); - if (TY(f) == t_md1D) { - Md1D* fd = c(Md1D,f); - u8 rtid = fd->m1->flags-1; - if (rtid==n_const) { f=fd->f; goto const_f; } - if ((rtid==n_fold || rtid==n_insert) && TI(x,elType)!=el_B && isPervasiveDyExt(fd->f) && RNK(x)==2) { - usz *sh = SH(x); usz m = sh[1]; - if (m == 1) return select_cells(0, x, 2); - if (m <= 64 && m < sh[0]) return fold_rows(fd, x); - } + case n_transp: { + if (k!=1) goto base; // TODO handle more ranks + return cr<=1? x : transp_cells(xr-1, x); + } + } + + noSpecial:; + if (TY(f) == t_md1D) { + Md1D* fd = c(Md1D,f); + u8 rtid = fd->m1->flags-1; + if (rtid==n_const) { f=fd->f; goto const_f; } + if ((rtid==n_fold || rtid==n_insert) && TI(x,elType)!=el_B && k==1 && xr==2 && isPervasiveDyExt(fd->f)) { + usz *sh = SH(x); usz m = sh[1]; + if (m == 1) return select_cells(0, x, 2); + if (m <= 64 && m < sh[0]) return fold_rows(fd, x); } } } else if (!isMd(f)) { - const_f:; - usz cam = SH(x)[0]; - decG(x); - B fv = inc(f); - if (isAtm(fv)) return C2(shape, m_f64(cam), fv); - usz vr = RNK(fv); - if (vr==UR_MAX) thrM("˘: Result rank too large"); - f64* shp; B sh = m_f64arrv(&shp, vr+1); - shp[0] = cam; - usz* fsh = SH(fv); - PLAINLOOP for (usz i = 0; i < vr; i++) shp[i+1] = fsh[i]; - return C2(shape, sh, fv); + const_f:; inc(f); + u32 fr; + if (isAtm(f) || RNK(f)==0) { + if (k!=1) { fr = 0; goto const_f_cont; } + usz cam = xsh[0]; + decG(x); + return C2(shape, m_usz(cam), f); + } else { + fr = RNK(f); + if (fr+k > UR_MAX) thrF("%c: Result rank too large", chr); + const_f_cont:; + f64* shp; B sh = m_f64arrv(&shp, fr+k); + PLAINLOOP for (usz i=0; if; B g = d->g; - f64 kf; - bool gf = isFun(g); - if (RARE(gf)) g = c1(g, inc(x)); - if (LIKELY(isNum(g))) { - kf = req_whole(o2fG(g)); - } else { - usz gia = check_rank_vec(g); - SGetU(g); kf = GetU(g, gia==2).f; - } - if (gf) dec(g); - - if (isAtm(x) || RNK(x)==0) { - B r = c1(f, x); - return isAtm(r)? m_atomUnit(r) : r; - } - i32 xr = RNK(x); - ur cr = cell_rank(xr, kf); - i32 k = xr - cr; - if (Q_BI(f,lt) && IA(x)!=0 && RNK(x)>1) return toKCells(x, k); - - usz* xsh = SH(x); - usz cam = shProd(xsh, 0, k); if (cam == 0) { usz s0=0; ShArr* s=NULL; if (xr<=1) { s0=xsh[0]; xsh=&s0; } else { s=ptr_inc(shObj(x)); } if (!isPureFn(f) || !CATCH_ERRORS) { decG(x); goto empty; } - B cf = to_fill_cell_k(x, k); + B cf = to_fill_cell(x, k, chr); B r; if (CATCH) { empty: freeThrown(); @@ -403,6 +375,7 @@ B rank_c1(Md2D* d, B x) { B f = d->f; B g = d->g; return r; } + base:; M_HARR(r, cam); S_KSLICES(x, xsh, k); for (usz i=0,p=0; if; B g = d->g; return bqn_merge(HARR_O(r).b); } - +B cell_c1(Md1D* d, B x) { B f = d->f; + ur xr; + if (isAtm(x) || (xr=RNK(x))==0) { + B r = c1(f, x); + return isAtm(r)? m_atomUnit(r) : r; + } + return for_cells_c1(f, xr, xr-1, 1, x, U'˘'); +} +B rank_c1(Md2D* d, B x) { B f = d->f; B g = d->g; + f64 kf; + bool gf = isFun(g); + if (RARE(gf)) g = c1(g, inc(x)); + if (LIKELY(isNum(g))) { + kf = req_whole(o2fG(g)); + } else { + usz gia = check_rank_vec(g); + SGetU(g); kf = GetU(g, gia==2).f; + } + if (gf) dec(g); + ur xr; + if (isAtm(x) || (xr=RNK(x))==0) { + B r = c1(f, x); + return isAtm(r)? m_atomUnit(r) : r; + } + ur cr = cell_rank(xr, kf); + return for_cells_c1(f, xr, cr, xr-cr, x, U'⎉'); +} @@ -481,8 +480,8 @@ static NOINLINE B rank2_empty(B f, B w, ur wk, B x, ur xk) { if (!sho) { s0=sh[0]; sh=&s0; } else { s=ptr_inc(shObj(fa)); } if (!isPureFn(f) || !CATCH_ERRORS) { dec(w); dec(x); goto empty; } B r; - if (wk) w = to_fill_cell_k(w, wk); - if (xk) x = to_fill_cell_k(x, xk); + if (wk) w = to_fill_cell(w, wk, U'⎉'); + if (xk) x = to_fill_cell(x, xk, U'⎉'); if (CATCH) { empty: freeThrown(); r = empty_frame(sh, k);