From 201fe1e57c227859956a80df67e3ffa339fdb457 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Thu, 4 Jul 2024 15:23:41 -0400 Subject: [PATCH] Extend fold_rows to any rank, as long as stride is 1 --- src/builtins/cells.c | 32 ++++++++++++++++++-------------- src/builtins/fold.c | 15 ++++++++++----- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/builtins/cells.c b/src/builtins/cells.c index 88866a44..f7a0cd35 100644 --- a/src/builtins/cells.c +++ b/src/builtins/cells.c @@ -11,7 +11,7 @@ B take_c2(B, B, B); B join_c2(B, B, B); // from fold.c: -B fold_rows(Md1D* d, B x); +B fold_rows(Md1D* d, B x, usz n, usz m); B fold_rows_bit(Md1D* d, B x, usz n, usz m); B insert_cells_join(B x, usz* xsh, ur cr, ur k); B insert_cells_identity(B x, B f, usz* xsh, ur xr, ur k, u8 rtid); @@ -506,21 +506,25 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x, with arr if (m==1 || frtid==n_ltack) return select_cells(0 , x, cam, k, false); if ( frtid==n_rtack) return select_cells(m-1, x, cam, k, false); if (isPervasiveDyExt(fd->f) && 1==shProd(xsh, k+1, xr)) { + B r; + // special cases always return rank 1 + // incG(x) preserves the shape to restore afterwards if needed if (TI(x,elType)==el_bit) { - incG(x); // keep shape alive - B r = fold_rows_bit(fd, x, cam, m); - if (!q_N(r)) { - if (xr > 2) { - usz* rsh = arr_shAlloc(a(r), xr-1); - shcpy(rsh, xsh, k); - shcpy(rsh+k, xsh+k+1, xr-1-k); - } - decG(x); return r; - } - decG(x); + incG(x); r = fold_rows_bit(fd, x, cam, m); + if (q_N(r)) decG(x); // will try fold_rows + else goto finish_fold; } - // TODO extend to any rank - if (xr==2 && k==1 && m<=64 && m 2) { + usz* rsh = arr_shAlloc(a(r), xr-1); + shcpy(rsh, xsh, k); + shcpy(rsh+k, xsh+k+1, xr-1-k); + } + decG(x); return r; } } break; case n_scan: { diff --git a/src/builtins/fold.c b/src/builtins/fold.c index d1049cca..294215b2 100644 --- a/src/builtins/fold.c +++ b/src/builtins/fold.c @@ -512,19 +512,24 @@ B insert_cells_identity(B x, B f, usz* xsh, ur xr, ur k, u8 rtid) { decG(x); return taga(r); } -// Arithmetic fold/insert on rows of flat rank-2 array x +// Arithmetic fold/insert on -k-cells of flat array x +// Return a vector regardless of argument shape B transp_c1(B, B); -B join_c2(B, B, B); -B fold_rows(Md1D* fd, B x) { - assert(isArr(x) && RNK(x)==2); +B fold_rows(Md1D* fd, B x, usz n, usz m) { + assert(isArr(x) && IA(x)==n*m); // Target block size trying to avoid power-of-two lengths, from: // {𝕩/˜⌊´⊸= +˝˘ +˝¬∨`2|>⌊∘÷⟜2⍟(↕12) ⌊0.5+32÷˜𝕩÷⌜1+↕64} +⟜↕2⋆16 u64 block = (116053*8) >> arrTypeBitsLog(TY(x)); if (TI(x,elType)==el_bit || IA(x)/2 <= block) { + if (RNK(x) > 2) { + Arr* xc = cpyWithShape(x); + ShArr* sh = m_shArr(2); + sh->a[0] = n; sh->a[1] = m; + x = taga(arr_shReplace(xc, 2, sh)); + } x = C1(transp, x); return insert_c1(fd, x); } else { - usz *sh = SH(x); usz n = sh[0]; usz m = sh[1]; usz b = (block + m - 1) / m; // Normal block length usz b_max = b + b/4; // Last block max length MAKE_MUT(r, n); MUT_APPEND_INIT(r);