Extend boolean F˝˘ special code to any ranks

This commit is contained in:
Marshall Lochbaum 2024-06-18 07:46:18 -04:00
parent 1e6c7057e8
commit cb1b72fbb2
2 changed files with 31 additions and 15 deletions

View File

@ -7,10 +7,10 @@
B fne_c1(B, B); B fne_c1(B, B);
B shape_c2(B, B, B); B shape_c2(B, B, B);
B transp_c2(B, B, B); B transp_c2(B, B, B);
B fold_rows(Md1D* d, B x); // from fold.c B fold_rows(Md1D* d, B x); // from fold.c
B fold_rows_bit(Md1D* d, B x); // from fold.c B fold_rows_bit(Md1D* d, B x, usz n, usz m); // from fold.c
B scan_rows_bit(u8, B x, usz m); // from scan.c B scan_rows_bit(u8, B x, usz m); // from scan.c
B takedrop_highrank(bool take, B w, B x); // from sfns.c B takedrop_highrank(bool take, B w, B x); // from sfns.c
B try_interleave_cells(B w, B x, ur xr, ur xk, usz* xsh); // from transpose.c B try_interleave_cells(B w, B x, ur xr, ur xk, usz* xsh); // from transpose.c
// X - variable name; XSH - its shape; K - number of leading axes that get iterated over; SLN - number of slices that will be made; DX - additional refcount count to add to x // X - variable name; XSH - its shape; K - number of leading axes that get iterated over; SLN - number of slices that will be made; DX - additional refcount count to add to x
@ -452,14 +452,29 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x, with arr
Md1D* fd = c(Md1D,f); Md1D* fd = c(Md1D,f);
u8 rtid = fd->m1->flags-1; u8 rtid = fd->m1->flags-1;
if (rtid==n_const) { f=fd->f; goto const_f; } if (rtid==n_const) { f=fd->f; goto const_f; }
if ((rtid==n_fold || rtid==n_insert) && TI(x,elType)!=el_B && k==1 && xr==2 && isFun(fd->f)) { // TODO extend to any rank x with cr==1 usz *sh = SH(x);
usz *sh = SH(x); usz m = sh[1]; if ((rtid==n_fold || rtid==n_insert) && TI(x,elType)!=el_B
&& isFun(fd->f) && 1==shProd(sh, k+1, xr)) {
usz m = sh[k];
u8 frtid = v(fd->f)->flags-1; u8 frtid = v(fd->f)->flags-1;
if (m==1 || frtid==n_ltack) return select_cells(0 , x, cam, k, false); if (m==1 || frtid==n_ltack) return select_cells(0 , x, cam, k, false);
if ( frtid==n_rtack) return select_cells(m-1, x, cam, k, false); if ( frtid==n_rtack) return select_cells(m-1, x, cam, k, false);
if (isPervasiveDyExt(fd->f)) { if (isPervasiveDyExt(fd->f)) {
if (TI(x,elType)==el_bit) { B r = fold_rows_bit(fd, x); if (!q_N(r)) return r; } if (TI(x,elType)==el_bit) {
if (m <= 64 && m < sh[0]) return fold_rows(fd, x); incG(x); // keep shape alive
B r = fold_rows_bit(fd, x, shProd(sh, 0, k), m);
if (!q_N(r)) {
if (xr > 2) {
usz* rsh = arr_shAlloc(a(r), xr-1);
shcpy(rsh, sh, k);
shcpy(rsh+k, sh+k+1, xr-1-k);
}
decG(x); return r;
}
decG(x);
}
// TODO extend to any rank
if (xr==2 && k==1 && m<=64 && m<sh[0]) return fold_rows(fd, x);
} }
} }
if (rtid==n_scan) { if (rtid==n_scan) {
@ -469,7 +484,8 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x, with arr
if (!isFun(fd->f)) goto base; if (!isFun(fd->f)) goto base;
u8 frtid = v(fd->f)->flags-1; u8 frtid = v(fd->f)->flags-1;
if (frtid==n_rtack) return x; if (frtid==n_rtack) return x;
if (TI(x,elType)==el_bit && (isPervasiveDyExt(fd->f)||frtid==n_ltack) && 1==shProd(sh, k+1, xr)) { if (TI(x,elType)==el_bit && (isPervasiveDyExt(fd->f)||frtid==n_ltack)
&& 1==shProd(sh, k+1, xr)) {
B r = scan_rows_bit(frtid, x, m); if (!q_N(r)) return r; B r = scan_rows_bit(frtid, x, m); if (!q_N(r)) return r;
} }
} }

View File

@ -440,8 +440,7 @@ B fold_rows(Md1D* fd, B x) {
} }
} }
B sum_rows_bit(B x) { B sum_rows_bit(B x, usz n, usz m) {
usz *sh = SH(x); usz n = sh[0]; usz m = sh[1];
u64* xp = bitarr_ptr(x); u64* xp = bitarr_ptr(x);
if (m < 128) { if (m < 128) {
if (m == 2) return bi_N; // Transpose is faster if (m == 2) return bi_N; // Transpose is faster
@ -508,15 +507,16 @@ B sum_rows_bit(B x) {
} }
} }
B fold_rows_bit(Md1D* fd, B x) { // Fold n cells of size m, stride 1
assert(isArr(x) && RNK(x)==2 && TI(x,elType)==el_bit); // Return a vector regardless of argument shape, or bi_N if not handled
B fold_rows_bit(Md1D* fd, B x, usz n, usz m) {
assert(isArr(x) && TI(x,elType)==el_bit && IA(x)==n*m);
if (!v(fd->f)->flags) return bi_N; if (!v(fd->f)->flags) return bi_N;
u8 rtid = v(fd->f)->flags-1; u8 rtid = v(fd->f)->flags-1;
if (rtid==n_add) return sum_rows_bit(x); if (rtid==n_add) return sum_rows_bit(x, n, m);
#if SINGELI #if SINGELI
if (rtid==n_ne|rtid==n_eq|rtid==n_or|rtid==n_and) { if (rtid==n_ne|rtid==n_eq|rtid==n_or|rtid==n_and) {
bool andor = rtid==n_or|rtid==n_and; bool andor = rtid==n_or|rtid==n_and;
usz *sh = SH(x); usz n = sh[0]; usz m = sh[1];
if (andor && m < 256) while (m%8 == 0) { if (andor && m < 256) while (m%8 == 0) {
usz f = CTZ(m|32); usz f = CTZ(m|32);
m >>= f; usz c = m*n; m >>= f; usz c = m*n;