From 8766810db81290ae07db1d612d113e2dcd198273 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Sat, 22 Jun 2024 09:39:02 -0400 Subject: [PATCH] Rank-agnostic select_cells --- src/builtins/cells.c | 66 ++++++++++++++++++++++++-------------------- src/core/arrFns.h | 14 ++++++---- 2 files changed, 45 insertions(+), 35 deletions(-) diff --git a/src/builtins/cells.c b/src/builtins/cells.c index 68d36f30..2884cae8 100644 --- a/src/builtins/cells.c +++ b/src/builtins/cells.c @@ -178,52 +178,58 @@ NOINLINE B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr) { // // fast special-case implementations extern void (*const si_select_cells_bit_lt64)(uint64_t*,uint64_t*,uint32_t,uint32_t,uint32_t); // from fold.c (fold.singeli) -static NOINLINE B select_cells(usz n, B x, usz cam, usz k, bool leaf) { // n {leaf? <∘⊑; ⊏}⎉¯k x; TODO probably can share some parts with takedrop_highrank and/or call ⊏? +static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind {leaf? <∘⊑; ⊏}⎉¯k x; TODO probably can share some parts with takedrop_highrank and/or call ⊏? ur xr = RNK(x); assert(xr>1 && k=7 || (xl<3 && xl>0)) { // generic case + MAKE_MUT_INIT(rm, ria, TI(x,elType)); MUTG_INIT(rm); + usz jump = l * csz; + usz xi = take*ind; + usz ri = 0; + for (usz i = 0; i < cam; i++) { + mut_copyG(rm, ri, x, xi, take); + xi+= jump; + ri+= take; + } + ra = mut_fp(rm); + } else if (xe==el_B) { + assert(take == 1); SGet(x) - HArr_p rp = m_harrUv(cam); - for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*jump+n); + HArr_p rp = m_harrUv(ria); + for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*l+ind); NOGC_E; ra = (Arr*)rp.c; } else { - void* rp = m_tyarrlbp(&ra, elwBitLog(xe), cam, el2t(xe)); + void* rp = m_tyarrlbp(&ra, ewl, ria, el2t(xe)); void* xp = tyany_ptr(x); - switch(xe) { - case el_bit: + switch(xl) { + case 0: #if SINGELI - if (jump < 64) si_select_cells_bit_lt64(xp, rp, cam, jump, n); + if (l < 64) si_select_cells_bit_lt64(xp, rp, cam, l, ind); else #endif - for (usz i=0; if; goto const_f; } usz *sh = SH(x); if (((rtid==n_fold && cr==1) || rtid==n_insert) && TI(x,elType)!=el_B - && isFun(fd->f) && 1==shProd(sh, k+1, xr) && sh[k] > 0) { + && isFun(fd->f) && sh[k] > 0) { usz m = sh[k]; u8 frtid = v(fd->f)->flags-1; if (m==1 || frtid==n_ltack) return select_cells(0 , x, cam, k, false); if ( frtid==n_rtack) return select_cells(m-1, x, cam, k, false); - if (isPervasiveDyExt(fd->f)) { + if (isPervasiveDyExt(fd->f) && 1==shProd(sh, k+1, xr)) { if (TI(x,elType)==el_bit) { incG(x); // keep shape alive B r = fold_rows_bit(fd, x, shProd(sh, 0, k), m); diff --git a/src/core/arrFns.h b/src/core/arrFns.h index 1257c9e0..b020be2c 100644 --- a/src/core/arrFns.h +++ b/src/core/arrFns.h @@ -73,11 +73,15 @@ SHOULD_INLINE void arr_check_size(u64 sz, u8 type, usz ia) { #endif } // Log of width in bits: max of 7, and also return 7 if not power of 2 -SHOULD_INLINE u8 cellWidthLog(B x) { +SHOULD_INLINE u8 multWidthLog(usz n, u8 lw) { // Of n elements, 1<>lw); // Max of 7; also handle n==0 +} +SHOULD_INLINE u8 kCellWidthLog(B x, ur k) { assert(isArr(x) && RNK(x)>=1); u8 lw = arrTypeBitsLog(TY(x)); - if (LIKELY(RNK(x)==1)) return lw; - usz csz = arr_csz(x); - if (csz & (csz-1)) return 7; // Not power of 2 - return lw + CTZ(csz | 128>>lw); // Max of 7; also handle csz==0 + ur xr = RNK(x); + if (LIKELY(xr <= k)) return lw; + return multWidthLog(shProd(SH(x), k, xr), lw); } +SHOULD_INLINE u8 cellWidthLog(B x) { return kCellWidthLog(x, 1); }