From 13906efe449f30eb49f912ea5c2ee82345e75613 Mon Sep 17 00:00:00 2001 From: dzaima Date: Fri, 2 May 2025 01:11:38 +0300 Subject: [PATCH] =?UTF-8?q?fix=20=E2=8A=91=CB=98=20on=20rank>2=20inputs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/builtins/cells.c | 44 +++++++++++++++++++++++++------------------ src/builtins/select.c | 19 +++++++++---------- src/builtins/sfns.c | 2 +- test/cases/cells.bqn | 13 ++++++++++++- 4 files changed, 48 insertions(+), 30 deletions(-) diff --git a/src/builtins/cells.c b/src/builtins/cells.c index 1378e36c..9c1ffc20 100644 --- a/src/builtins/cells.c +++ b/src/builtins/cells.c @@ -234,25 +234,33 @@ NOINLINE B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr) { // // fast special-case implementations -B select_cells_single(usz ind, B x, usz cam, usz l, usz csz, bool leaf); // from select.c -static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind {leaf? <∘⊑; ⊏}⎉¯k x - ur xr = RNK(x); - assert(xr>1 && k1 && k1 && k<=xr); + usz* xsh = SH(x); + usz l = shProd(xsh, k, xr); + assert(ind < (k==xr? 1 : xsh[k]) && cam*l == IA(x)); + B r = select_cells_single(ind, x, cam, l, 1); + usz* rsh = arr_shAlloc(a(r), k); + if (rsh) shcpy(rsh, xsh, k); + decG(x); + return r; +} static void set_column_typed(void* rp, B v, u8 e, ux p, ux stride, ux n) { // may write to all elements 0 ≤ i < stride×n, and after that too for masked stores assert(p < stride); @@ -441,11 +449,11 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x; array x, case n_select: if (IA(x)==0) goto noSpecial; if (cr==0) goto base; - return select_cells(0, x, cam, k, false); + return select_cells(0, x, xr, cam, k); case n_pick: if (IA(x)==0) goto noSpecial; - if (cr==0 || !TI(x,arrD1)) goto base; - return select_cells(0, x, cam, k, true); + if (!TI(x,arrD1)) goto base; + return pick_cells(0, x, xr, cam, k); case n_couple: { Arr* r = cpyWithShape(x); xsh=PSH(r); if (xr==UR_MAX) thrF("≍%U 𝕩: Result rank too large (%i≡=𝕩)", chr==U'˘'? "˘" : "⎉𝕘", xr); @@ -524,8 +532,8 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x; array x, usz m = xsh[k]; if (m==0) return insert_cells_identity(x, fd->f, xsh, xr, k, rtid); if (TI(x,elType)==el_B) break; - if (m==1 || frtid==n_ltack) return select_cells(0 , x, cam, k, false); - if ( frtid==n_rtack) return select_cells(m-1, x, cam, k, false); + if (m==1 || frtid==n_ltack) return select_cells(0 , x, xr, cam, k); + if ( frtid==n_rtack) return select_cells(m-1, x, xr, cam, k); if (isPervasiveDyExt(fd->f) && 1==shProd(xsh, k+1, xr)) { B r; // special cases always return rank 1 @@ -729,7 +737,7 @@ NOINLINE B for_cells_SA(B f, B w, B x, ur xcr, ur xr, u32 chr) { // w⊸F⎉xcr if (wr == 0) { usz ind = WRAP(o2i64(IGetU(w,0)), xsh[xk], break); decG(w); - return select_cells(ind, x, cam, xk, false); + return select_cells(ind, x, xr, cam, xk); } ur rr = xk+wr; ShArr* rsh = m_shArr(rr); @@ -740,7 +748,7 @@ NOINLINE B for_cells_SA(B f, B w, B x, ur xcr, ur xr, u32 chr) { // w⊸F⎉xcr } if (isF64(w) && xcr>=1) { usz l = xsh[xk]; - return select_cells(WRAP(o2i64(w), l, thrF("𝕨⊏𝕩: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, l)), x, cam, xk, false); + return select_cells(WRAP(o2i64(w), l, thrF("𝕨⊏𝕩: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, l)), x, xr, cam, xk); } break; case n_couple: if (RNK(x)==1) { @@ -750,7 +758,7 @@ NOINLINE B for_cells_SA(B f, B w, B x, ur xcr, ur xr, u32 chr) { // w⊸F⎉xcr } break; case n_pick: if (isF64(w) && xcr==1 && TI(x,arrD1)) { usz l = xsh[xk]; - return select_cells(WRAP(o2i64(w), l, thrF("𝕨⊑𝕩: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, l)), x, cam, xk, true); + return pick_cells(WRAP(o2i64(w), l, thrF("𝕨⊑𝕩: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, l)), x, xr, cam, xk); } break; case n_shifta: case n_shiftb: if (isAtm(w)) { if (IA(x)==0) return x; diff --git a/src/builtins/select.c b/src/builtins/select.c index db080f66..08ea259c 100644 --- a/src/builtins/select.c +++ b/src/builtins/select.c @@ -616,30 +616,29 @@ B select_cells_base(B inds, B x0, ux csz, ux cam); extern void (*const si_select_cells_bit_lt64)(u64*,u64*,usz,usz,usz); // from fold.c (fold.singeli) extern usz (*const si_select_cells_byte)(void*,void*,usz,usz,u8); -B select_cells_single(usz ind, B x, usz cam, usz l, usz csz, bool leaf) { // ⥊ ind {leaf? <∘⊑; ⊏}˘ cam‿l‿csz ⥊ x - usz take = leaf? 1 : csz; +B select_cells_single(usz ind, B x, usz cam, usz l, usz csz) { // ⥊ ind ⊏˘ cam‿l‿csz ⥊ x Arr* ra; - if (l==1 && take==csz) { + if (l==1) { ra = cpyWithShape(incG(x)); arr_shErase(ra, 1); } else { u8 xe = TI(x,elType); u8 ewl= elwBitLog(xe); - u8 xl = leaf? ewl : multWidthLog(csz, ewl); - usz ria = cam*take; + u8 xl = multWidthLog(csz, ewl); + usz ria = cam*csz; if (xl>=7 || (xl<3 && xl>0)) { // generic case MAKE_MUT_INIT(rm, ria, TI(x,elType)); MUTG_INIT(rm); usz jump = l * csz; - usz xi = take*ind; + usz xi = csz*ind; usz ri = 0; for (usz i = 0; i < cam; i++) { - mut_copyG(rm, ri, x, xi, take); + mut_copyG(rm, ri, x, xi, csz); xi+= jump; - ri+= take; + ri+= csz; } ra = mut_fp(rm); } else if (xe==el_B) { - assert(take == 1); + assert(csz == 1); SGet(x) HArr_p rp = m_harrUv(ria); for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*l+ind); @@ -948,7 +947,7 @@ B select_rows_B(B x, ux csz, ux cam, B inds) { // consumes inds,x; ⥊ inds⊸ if (in == 0) return taga(emptyArr(x, 1)); if (in == 1) { B w = IGetU(inds,0); if (!isF64(w)) goto generic; - B r = select_cells_single(WRAP_SELECT_ONE(o2i64(w), csz, "%R", w), x, cam, csz, 1, false); + B r = select_cells_single(WRAP_SELECT_ONE(o2i64(w), csz, "%R", w), x, cam, csz, 1); decG(x); decG(inds); return r; } u8 ie = TI(inds,elType); diff --git a/src/builtins/sfns.c b/src/builtins/sfns.c index a5871c49..26237733 100644 --- a/src/builtins/sfns.c +++ b/src/builtins/sfns.c @@ -605,7 +605,7 @@ NOINLINE B takedrop_highrank(bool take, B w, B x) { if (ri!=pri) mut_fillG(rm, pri, xf, ri-pri); pri = ri+cellWrite; } - mut_copyG(rm, ri, x, xi, cellWrite); + mut_copyG(rm, ri, x, xi, cellWrite); // TODO could use cf_ usz cr = cellStart-1; if (0 == --lcv[cr]) { do { diff --git a/test/cases/cells.bqn b/test/cases/cells.bqn index 1e2c62ea..c2b04e53 100644 --- a/test/cases/cells.bqn +++ b/test/cases/cells.bqn @@ -41,7 +41,9 @@ !"𝕨⊏𝕩: 𝕩 cannot be a unit" % 0⊏˘5⥊<"a" !"𝕨⊏𝕩: 𝕩 cannot be a unit" % (3‿0⥊⟨⟩)⊏˘↕3 !"Expected integer, got 0.1" % 0.1⊑˘3‿5⥊↕15 -!"𝕨⊑𝕩: 𝕩 must be a list when 𝕨 is a number (3‿4 ≡ ≢𝕩)" % 5⊑˘↕2‿3‿4 +!"𝕨⊑𝕩: 𝕩 must be a list when 𝕨 is a number (3‿4 ≡ ≢𝕩)" % 5 ⊑˘ ↕2‿3‿4 +!"𝕨⊑𝕩: 𝕩 must be a list when 𝕨 is a number (3‿4 ≡ ≢𝕩)" % 5 ⊑˘ 2‿3‿4⥊2 +!"𝕨⊑𝕩: 𝕩 must be a list when 𝕨 is a number (3‿4 ≡ ≢𝕩)" % 0 ⊑⎉2 2‿3‿4⥊2 !">𝕩: Result rank too large (80 ≡ =𝕩, 205 ≡ =⊑𝕩)" % >⎉80 (200⥊1)⥊<(205⥊1)⥊1 !"𝔽⎉𝕘: Result rank too large (195 ≡ =𝕩, 210 ≡ =𝔽v)" % >⎉5 (200⥊1)⥊<(205⥊1)⥊1 !"𝕨∾𝕩: Lengths not matchable (⟨6⟩ ≡ ≢𝕨, 1‿1 ≡ ≢𝕩)" % ("abc"∾"def")∾˘○(3/≍)≍"a" @@ -82,6 +84,15 @@ %USE tcc ⋄ ⊏_tcc_ ¯1 4‿2⥊↕8 ⋄ 1⊏_tcc_ ¯1 4‿2⥊↕8 ⋄ ⊏_tcc_ ¯1 ↕4‿2 ⋄ 1⊏_tcc_ ¯1 ↕4‿2 %USE tcc ⋄ ⊑_tcc_ ¯1 4‿2⥊↕8 ⋄ 1⊑_tcc_ ¯1 4‿2⥊↕8 ⋄ ⊑_tcc_ ¯1 ↕4‿2 ⋄ 1⊑_tcc_ ¯1 ↕4‿2 +1⊸⊏˘ 0‿2‿3‿4⥊2 %% 0‿3‿4⥊0 %!HEAPVERIFY +2⊸⊏˘ 0‿2‿2‿2⥊2 %% ↕0 +⊑⎉¯1 2‿3‿4⥊↕24 %% 2⥊3×4×↕2 +⊑⎉¯2 2‿3‿4⥊↕24 %% 2‿3⥊4×↕6 +⊑⎉¯3 2‿3‿4⥊↕24 %% 2‿3‿4⥊↕24 +0⊸⊑⎉1 2‿3‿4⥊↕24 %% 2‿3⥊4×↕6 +1⊸⊑⎉1 2‿3‿4⥊↕24 %% 1+2‿3⥊4×↕6 +⊑⎉0 10‿10⥊↕100 %% 10‿10⥊↕100 + %USE eqvar ⋄ 0‿0‿0‿0 {𝕨⊸⊏˘𝕩}_eqvar ≍˘↕5 %% 5‿4⥊⌊4÷˜↕20 %USE eqvar ⋄ 0‿¯1‿0‿¯1 {𝕨⊸⊏˘𝕩}_eqvar ≍˘↕5 %% 5‿4⥊⌊4÷˜↕20 %USE eqvar ⋄ 100‿¯3 {𝕨⊸⊏˘𝕩}_eqvar 10‿200⥊↕2000 %% 10‿2⥊100‿197‿300‿397‿500‿597‿700‿797‿900‿997‿1100‿1197‿1300‿1397‿1500‿1597‿1700‿1797‿1900‿1997