From c90674313b233bb1ce4ba30e7e9e9546f3060981 Mon Sep 17 00:00:00 2001 From: dzaima Date: Sun, 8 Jun 2025 05:49:55 +0300 Subject: [PATCH] =?UTF-8?q?reduce=20reshape+replicate-based=20Arith?= =?UTF-8?q?=E2=8C=9C=20constant=20overhead?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/builtins/md1.c | 24 ++++++++++++++++++------ src/builtins/sfns.c | 2 +- test/cases/prims.bqn | 2 ++ 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/builtins/md1.c b/src/builtins/md1.c index 6553a369..cc8ffd92 100644 --- a/src/builtins/md1.c +++ b/src/builtins/md1.c @@ -64,6 +64,8 @@ B tbl_c1(Md1D* d, B x) { B slash_c2(B t, B w, B x); B shape_c2(B t, B w, B x); +Arr* reshape_cycle(usz nia, usz xia, B x); // from sfns.c +static B replicate_by(usz rep, usz xia, B x) { return C2(slash, m_usz(rep), taga(arr_shVec(TI(x,slice)(incG(x), 0, xia)))); } B tbl_c2(Md1D* d, B w, B x) { B f = d->f; if (isAtm(w)) w = m_unit(w); if (isAtm(x)) x = m_unit(x); @@ -84,20 +86,30 @@ B tbl_c2(Md1D* d, B w, B x) { B f = d->f; rsh = arr_shAlloc(ra, rr); r = taga(ra); } else if (RTID(f) == n_ltack) { - Arr* wd = arr_shVec(TI(w,slice)(incG(w), 0, wia)); - r = C2(slash, m_i32(xia), taga(wd)); + r = replicate_by(xia, wia, w); goto arith_finish; } else if (RTID(f) == n_rtack) { r = C2(shape, m_f64(ria), incG(x)); goto arith_finish; - } else if (TI(w,arrD1) && isPervasiveDyExt(f)) { + } else if (isPervasiveDyExt(f)) { + if (ria == 0) goto arith_empty; + if (!TI(w,arrD1)) goto generic; if (TI(x,arrD1) && wia>=4 && xia<2560>>arrTypeBitsLog(TY(x))) { - Arr* wd = arr_shVec(TI(w,slice)(incG(w), 0, wia)); - r = fc2(f, C2(slash, m_i32(xia), taga(wd)), C2(shape, m_f64(ria), incG(x))); + B expW, expX; + if (0) { + arith_empty:; + expW = taga(emptyArr(w, 1)); + expX = taga(emptyArr(x, 1)); + } else { + assert(wia>1); // implies ria > xia, a requirement of reshape_cycle + expW = replicate_by(xia, wia, w); + expX = taga(arr_shVec(reshape_cycle(ria, xia, incG(x)))); + } + r = fc2(f, expW, expX); arith_finish:; if(RARE(!reusable(r))) r = taga(cpyWithShape(r)); arr_shErase(a(r), 1); - } else if (xia>7 && wia>0) { + } else if (xia>7) { SGet(w) M_APD_TOT(rm, ria) incByG(x, wia); diff --git a/src/builtins/sfns.c b/src/builtins/sfns.c index 444fafd3..a0dee4be 100644 --- a/src/builtins/sfns.c +++ b/src/builtins/sfns.c @@ -312,7 +312,7 @@ NOINLINE B shape_c2_listw(B t, B w, B x) { return taga(arr_shSetUO(reshape_unshaped(nia, x), nr, sh)); } -Arr* reshape_cycle(usz nia, usz xia, B x) { +Arr* reshape_cycle(usz nia, usz xia, B x) { // used directly by tbl_c2 assert(nia > xia); Arr* r; if (xia <= 1) { diff --git a/test/cases/prims.bqn b/test/cases/prims.bqn index da78a5cb..1841554c 100644 --- a/test/cases/prims.bqn +++ b/test/cases/prims.bqn @@ -606,6 +606,8 @@ b←1↓1∾a←"hello" ⋄ b ⌽⎊'e' ⥊⟜1⍟2 5 ⋄ a ≡○•Hash b %% 1 # ⌜ !"𝕨𝔽⌜𝕩: Result rank too large (200≡=𝕨, 200≡=𝕩)" % +⌜˜(200⥊1)⥊1 +%USE eqvar ⋄ (! + ○•internal.Keep⌜ ≡ + ⌜)○↕⌜˜ ↕6 # TODO _eqvar once it's consistent +%USE eqvar ⋄ (! -˜○•internal.Keep⌜ ≡ -˜⌜)○↕⌜˜ ↕6 # TODO _eqvar once it's consistent # ˜ 5˜˝ "ab" %% 5