Extend fold_rows to any rank, as long as stride is 1

This commit is contained in:
Marshall Lochbaum 2024-07-04 15:23:41 -04:00
parent 48d77e722f
commit 201fe1e57c
2 changed files with 28 additions and 19 deletions

View File

@ -11,7 +11,7 @@ B take_c2(B, B, B);
B join_c2(B, B, B);
// from fold.c:
B fold_rows(Md1D* d, B x);
B fold_rows(Md1D* d, B x, usz n, usz m);
B fold_rows_bit(Md1D* d, B x, usz n, usz m);
B insert_cells_join(B x, usz* xsh, ur cr, ur k);
B insert_cells_identity(B x, B f, usz* xsh, ur xr, ur k, u8 rtid);
@ -506,21 +506,25 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x, with arr
if (m==1 || frtid==n_ltack) return select_cells(0 , x, cam, k, false);
if ( frtid==n_rtack) return select_cells(m-1, x, cam, k, false);
if (isPervasiveDyExt(fd->f) && 1==shProd(xsh, k+1, xr)) {
B r;
// special cases always return rank 1
// incG(x) preserves the shape to restore afterwards if needed
if (TI(x,elType)==el_bit) {
incG(x); // keep shape alive
B r = fold_rows_bit(fd, x, cam, m);
if (!q_N(r)) {
if (xr > 2) {
usz* rsh = arr_shAlloc(a(r), xr-1);
shcpy(rsh, xsh, k);
shcpy(rsh+k, xsh+k+1, xr-1-k);
}
decG(x); return r;
}
decG(x);
incG(x); r = fold_rows_bit(fd, x, cam, m);
if (q_N(r)) decG(x); // will try fold_rows
else goto finish_fold;
}
// TODO extend to any rank
if (xr==2 && k==1 && m<=64 && m<xsh[0]) return fold_rows(fd, x);
if (m<=64 && m<cam) {
incG(x); r = fold_rows (fd, x, cam, m);
}
else break;
finish_fold:
if (xr > 2) {
usz* rsh = arr_shAlloc(a(r), xr-1);
shcpy(rsh, xsh, k);
shcpy(rsh+k, xsh+k+1, xr-1-k);
}
decG(x); return r;
}
} break;
case n_scan: {

View File

@ -512,19 +512,24 @@ B insert_cells_identity(B x, B f, usz* xsh, ur xr, ur k, u8 rtid) {
decG(x); return taga(r);
}
// Arithmetic fold/insert on rows of flat rank-2 array x
// Arithmetic fold/insert on -k-cells of flat array x
// Return a vector regardless of argument shape
B transp_c1(B, B);
B join_c2(B, B, B);
B fold_rows(Md1D* fd, B x) {
assert(isArr(x) && RNK(x)==2);
B fold_rows(Md1D* fd, B x, usz n, usz m) {
assert(isArr(x) && IA(x)==n*m);
// Target block size trying to avoid power-of-two lengths, from:
// {𝕩/˜⌊´⊸= +˝˘ +˝¬∨`2|>⌊∘÷⟜2⍟(↕12) ⌊0.5+32÷˜𝕩÷⌜1+↕64} +⟜↕2⋆16
u64 block = (116053*8) >> arrTypeBitsLog(TY(x));
if (TI(x,elType)==el_bit || IA(x)/2 <= block) {
if (RNK(x) > 2) {
Arr* xc = cpyWithShape(x);
ShArr* sh = m_shArr(2);
sh->a[0] = n; sh->a[1] = m;
x = taga(arr_shReplace(xc, 2, sh));
}
x = C1(transp, x);
return insert_c1(fd, x);
} else {
usz *sh = SH(x); usz n = sh[0]; usz m = sh[1];
usz b = (block + m - 1) / m; // Normal block length
usz b_max = b + b/4; // Last block max length
MAKE_MUT(r, n); MUT_APPEND_INIT(r);