From c0e3a3711f281fa7a4155367ec54852709c65bb6 Mon Sep 17 00:00:00 2001 From: dzaima Date: Thu, 12 Sep 2024 05:35:42 +0300 Subject: [PATCH] =?UTF-8?q?fast=20scalar=E2=89=8D=CB=98arr=20&=20arr?= =?UTF-8?q?=E2=89=8D=CB=98scalar=20&=20different-type=20arr=E2=89=8D=CB=98?= =?UTF-8?q?arr?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/builtins/cells.c | 20 ++++++++++++++-- src/builtins/transpose.c | 51 ++++++++++++++++++++++++++++++++++------ test/cases/cells.bqn | 8 +++++++ 3 files changed, 70 insertions(+), 9 deletions(-) diff --git a/src/builtins/cells.c b/src/builtins/cells.c index b01c71f9..f3920454 100644 --- a/src/builtins/cells.c +++ b/src/builtins/cells.c @@ -695,8 +695,17 @@ NOINLINE B for_cells_AS(B f, B w, B x, ur wcr, ur wr, u32 chr) { // F⟜x⎉wcr if (cam==0) return rank2_empty(f, w, wk, x, 0, chr); if (isFun(f)) { u8 rtid = v(f)->flags-1; - if (rtid==n_ltack) { dec(x); return w; } - if (rtid==n_rtack) return const_cells(w, wk, wsh, x, chr); + switch (rtid) { + case n_ltack: dec(x); return w; + case n_rtack: return const_cells(w, wk, wsh, x, chr); + case n_couple: if (RNK(w)==1) { + x = taga(arr_shVec(reshape_one(IA(w), x))); + B r = try_interleave_cells(w, x, 1, 1, wsh); + assert(!q_N(r)); + decG(w); decG(x); + return r; + } break; + } if (IA(w)!=0 && isPervasiveDy(f)) { if (isAtm(x)) return c2(f, w, x); if (RNK(x)!=wcr || !eqShPart(SH(x), wsh+wk, wcr)) goto generic; @@ -743,6 +752,13 @@ NOINLINE B for_cells_SA(B f, B w, B x, ur xcr, ur xr, u32 chr) { // w⊸F⎉xcr return select_cells(WRAP(o2i64(w), l, thrF("⊏: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, l)), x, cam, xk, false); } break; + case n_couple: if (RNK(x)==1) { + w = taga(arr_shVec(reshape_one(IA(x), w))); + B r = try_interleave_cells(w, x, 1, 1, xsh); + assert(!q_N(r)); + decG(w); decG(x); + return r; + } break; case n_pick: if (isF64(w) && xcr==1 && TI(x,arrD1)) { usz l = xsh[xk]; return select_cells(WRAP(o2i64(w), l, thrF("⊑: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, l)), x, cam, xk, true); diff --git a/src/builtins/transpose.c b/src/builtins/transpose.c index 0d0ad7d3..fd0af17f 100644 --- a/src/builtins/transpose.c +++ b/src/builtins/transpose.c @@ -93,23 +93,59 @@ static void interleave_bits(u64* rp, void* x0v, void* x1v, usz n) { } } +B toBPtrAny(B x) { + if (arr_bptr(x)!=NULL) return x; + return taga(cpyHArr(x)); +} + +NOINLINE +B toElTypeArr(u8 re, B x) { // consumes x; returns an array with the given element type (re==el_B guarantees TO_BPTR working) + switch (re) { default: UD; + case el_bit: return toBitAny(x); + case el_i8: return toI8Any(x); + case el_i16: return toI16Any(x); + case el_i32: return toI32Any(x); + case el_f64: return toF64Any(x); + case el_c8: return toC8Any(x); + case el_c16: return toC16Any(x); + case el_c32: return toC32Any(x); + case el_B: return toBPtrAny(x); + } +} + + // Interleave arrays, 𝕨≍⎉(-xk)𝕩. Doesn't consume. -// Return bi_N if there isn't fast code. +// Assumes w and x have same shape. +// Return bi_N if there isn't fast code. Guaranteed to succeed on rank 1 w & x. B try_interleave_cells(B w, B x, ur xr, ur xk, usz* xsh) { assert(RNK(w)==xr && xr>=1); - u8 xe = TI(x,elType); if (xe!=TI(w,elType)) return bi_N; + u8 we = TI(w,elType); + u8 xe = TI(x,elType); usz csz = shProd(xsh, xk, xr); if (csz & (csz-1)) return bi_N; // Not power of 2 - u8 xlw = elwBitLog(xe); + + u8 re = we==xe? we : el_or(we, xe); + if (0) { to_equal_types:; + // delay doing this until it's known that there will be code that can utilize it + incG(w); B w2 = re==we? w : toElTypeArr(re, w); + incG(x); B x2 = re==xe? x : toElTypeArr(re, x); + B r = try_interleave_cells(w2, x2, xr, xk, SH(x)); + assert(!q_N(r)); + decG(w2); decG(x2); + return r; + } + + u8 xlw = elwBitLog(re); usz n = shProd(xsh, 0, xk); usz ia = 2*n*csz; Arr *r; - if (csz==1 && xlw==0) { + if (csz==1 && xlw==0) { // we & xe are trivially el_bit u64* rp; r=m_bitarrp(&rp, ia); interleave_bits(rp, bitany_ptr(w), bitany_ptr(x), ia); } #if SINGELI - else if (csz==1 && xe==el_B) { + else if (csz==1 && re==el_B) { + if (we!=xe) goto to_equal_types; B* wp = TO_BPTR(w); B* xp = TO_BPTR(x); HArr_p p = m_harrUv(ia); // Debug build complains with harrUp si_interleave[3](p.a, wp, xp, n); @@ -119,10 +155,11 @@ B try_interleave_cells(B w, B x, ur xr, ur xk, usz* xsh) { if (SFNS_FILLS) rb = qWithFill(rb, fill_both(w, x)); r = a(rb); } else if (csz<=64>>xlw && csz<=8) { // Require CPU-sized cells - assert(xe!=el_B); + if (we!=xe) goto to_equal_types; + assert(re!=el_B); void* rv; if (xlw==0) { u64* rp; r = m_bitarrp(&rp, ia); rv=rp; } - else rv = m_tyarrp(&r,elWidth(xe),ia,el2t(xe)); + else rv = m_tyarrp(&r,elWidth(re),ia,el2t(re)); si_interleave[CTZ(csz<