From cb1b72fbb239752962e95cfd00c13c9f6fea53bc Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Tue, 18 Jun 2024 07:46:18 -0400 Subject: [PATCH] =?UTF-8?q?Extend=20boolean=20F=CB=9D=CB=98=20special=20co?= =?UTF-8?q?de=20to=20any=20ranks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/builtins/cells.c | 34 +++++++++++++++++++++++++--------- src/builtins/fold.c | 12 ++++++------ 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/src/builtins/cells.c b/src/builtins/cells.c index cff6890e..244f5e0e 100644 --- a/src/builtins/cells.c +++ b/src/builtins/cells.c @@ -7,10 +7,10 @@ B fne_c1(B, B); B shape_c2(B, B, B); B transp_c2(B, B, B); -B fold_rows(Md1D* d, B x); // from fold.c -B fold_rows_bit(Md1D* d, B x); // from fold.c -B scan_rows_bit(u8, B x, usz m); // from scan.c -B takedrop_highrank(bool take, B w, B x); // from sfns.c +B fold_rows(Md1D* d, B x); // from fold.c +B fold_rows_bit(Md1D* d, B x, usz n, usz m); // from fold.c +B scan_rows_bit(u8, B x, usz m); // from scan.c +B takedrop_highrank(bool take, B w, B x); // from sfns.c B try_interleave_cells(B w, B x, ur xr, ur xk, usz* xsh); // from transpose.c // X - variable name; XSH - its shape; K - number of leading axes that get iterated over; SLN - number of slices that will be made; DX - additional refcount count to add to x @@ -452,14 +452,29 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x, with arr Md1D* fd = c(Md1D,f); u8 rtid = fd->m1->flags-1; if (rtid==n_const) { f=fd->f; goto const_f; } - if ((rtid==n_fold || rtid==n_insert) && TI(x,elType)!=el_B && k==1 && xr==2 && isFun(fd->f)) { // TODO extend to any rank x with cr==1 - usz *sh = SH(x); usz m = sh[1]; + usz *sh = SH(x); + if ((rtid==n_fold || rtid==n_insert) && TI(x,elType)!=el_B + && isFun(fd->f) && 1==shProd(sh, k+1, xr)) { + usz m = sh[k]; u8 frtid = v(fd->f)->flags-1; if (m==1 || frtid==n_ltack) return select_cells(0 , x, cam, k, false); if ( frtid==n_rtack) return select_cells(m-1, x, cam, k, false); if (isPervasiveDyExt(fd->f)) { - if (TI(x,elType)==el_bit) { B r = fold_rows_bit(fd, x); if (!q_N(r)) return r; } - if (m <= 64 && m < sh[0]) return fold_rows(fd, x); + if (TI(x,elType)==el_bit) { + incG(x); // keep shape alive + B r = fold_rows_bit(fd, x, shProd(sh, 0, k), m); + if (!q_N(r)) { + if (xr > 2) { + usz* rsh = arr_shAlloc(a(r), xr-1); + shcpy(rsh, sh, k); + shcpy(rsh+k, sh+k+1, xr-1-k); + } + decG(x); return r; + } + decG(x); + } + // TODO extend to any rank + if (xr==2 && k==1 && m<=64 && mf)) goto base; u8 frtid = v(fd->f)->flags-1; if (frtid==n_rtack) return x; - if (TI(x,elType)==el_bit && (isPervasiveDyExt(fd->f)||frtid==n_ltack) && 1==shProd(sh, k+1, xr)) { + if (TI(x,elType)==el_bit && (isPervasiveDyExt(fd->f)||frtid==n_ltack) + && 1==shProd(sh, k+1, xr)) { B r = scan_rows_bit(frtid, x, m); if (!q_N(r)) return r; } } diff --git a/src/builtins/fold.c b/src/builtins/fold.c index 8a41e8d1..a02d953d 100644 --- a/src/builtins/fold.c +++ b/src/builtins/fold.c @@ -440,8 +440,7 @@ B fold_rows(Md1D* fd, B x) { } } -B sum_rows_bit(B x) { - usz *sh = SH(x); usz n = sh[0]; usz m = sh[1]; +B sum_rows_bit(B x, usz n, usz m) { u64* xp = bitarr_ptr(x); if (m < 128) { if (m == 2) return bi_N; // Transpose is faster @@ -508,15 +507,16 @@ B sum_rows_bit(B x) { } } -B fold_rows_bit(Md1D* fd, B x) { - assert(isArr(x) && RNK(x)==2 && TI(x,elType)==el_bit); +// Fold n cells of size m, stride 1 +// Return a vector regardless of argument shape, or bi_N if not handled +B fold_rows_bit(Md1D* fd, B x, usz n, usz m) { + assert(isArr(x) && TI(x,elType)==el_bit && IA(x)==n*m); if (!v(fd->f)->flags) return bi_N; u8 rtid = v(fd->f)->flags-1; - if (rtid==n_add) return sum_rows_bit(x); + if (rtid==n_add) return sum_rows_bit(x, n, m); #if SINGELI if (rtid==n_ne|rtid==n_eq|rtid==n_or|rtid==n_and) { bool andor = rtid==n_or|rtid==n_and; - usz *sh = SH(x); usz n = sh[0]; usz m = sh[1]; if (andor && m < 256) while (m%8 == 0) { usz f = CTZ(m|32); m >>= f; usz c = m*n;