From 1e6c7057e8fe2ce1a94bf0ba411792c4817e42b6 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Mon, 17 Jun 2024 21:51:32 -0400 Subject: [PATCH] =?UTF-8?q?Extend=20boolean=20F`=CB=98=20special=20code=20?= =?UTF-8?q?to=20any=20frame=20and=20cell=20rank?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/builtins/cells.c | 10 +++++----- src/builtins/scan.c | 21 +++++++++++---------- src/singeli/src/scan.singeli | 14 ++++++-------- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/src/builtins/cells.c b/src/builtins/cells.c index 0fb83ed6..cff6890e 100644 --- a/src/builtins/cells.c +++ b/src/builtins/cells.c @@ -7,9 +7,9 @@ B fne_c1(B, B); B shape_c2(B, B, B); B transp_c2(B, B, B); -B fold_rows(Md1D* d, B x); // from fold.c -B fold_rows_bit(Md1D* d, B x); // from fold.c -B scan_rows_bit(u8, B x); // from scan.c +B fold_rows(Md1D* d, B x); // from fold.c +B fold_rows_bit(Md1D* d, B x); // from fold.c +B scan_rows_bit(u8, B x, usz m); // from scan.c B takedrop_highrank(bool take, B w, B x); // from sfns.c B try_interleave_cells(B w, B x, ur xr, ur xk, usz* xsh); // from transpose.c @@ -469,8 +469,8 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x, with arr if (!isFun(fd->f)) goto base; u8 frtid = v(fd->f)->flags-1; if (frtid==n_rtack) return x; - if (k==1 && xr==2 && (isPervasiveDyExt(fd->f)||frtid==n_ltack) && TI(x,elType)==el_bit) { - B r = scan_rows_bit(frtid, x); if (!q_N(r)) return r; + if (TI(x,elType)==el_bit && (isPervasiveDyExt(fd->f)||frtid==n_ltack) && 1==shProd(sh, k+1, xr)) { + B r = scan_rows_bit(frtid, x, m); if (!q_N(r)) return r; } } } else if (TY(f) == t_md2D) { diff --git a/src/builtins/scan.c b/src/builtins/scan.c index c4fb2885..0ebd813e 100644 --- a/src/builtins/scan.c +++ b/src/builtins/scan.c @@ -343,25 +343,26 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f; return withFill(r.b, wf); } -B scan_rows_bit(u8 rtid, B x) { - assert(isArr(x) && RNK(x)==2 && TI(x,elType)==el_bit); +// scan cells of size m, stride 1 +B scan_rows_bit(u8 rtid, B x, usz m) { + assert(isArr(x) && TI(x,elType)==el_bit); #if SINGELI switch (rtid) { default: return bi_N; - case n_eq: return bit_negate(scan_rows_bit(n_ne, bit_negate(x))); + case n_eq: return bit_negate(scan_rows_bit(n_ne, bit_negate(x), m)); case n_and: case n_or: case n_ne: case n_ltack: { - usz *sh = SH(x); usz n = sh[0]; usz m = sh[1]; + usz ia = IA(x); u64* xp = bitarr_ptr(x); u64* rp; B r = m_bitarrc(&rp, x); switch (rtid) { default:UD; - case n_and: si_scan_rows_and (xp, rp, n, m); break; - case n_or: si_scan_rows_or (xp, rp, n, m); break; - case n_ne: si_scan_rows_ne (xp, rp, n, m); break; - case n_ltack: si_scan_rows_ltack(xp, rp, n, m); break; + case n_and: si_scan_rows_and (xp, rp, ia, m); break; + case n_or: si_scan_rows_or (xp, rp, ia, m); break; + case n_ne: si_scan_rows_ne (xp, rp, ia, m); break; + case n_ltack: si_scan_rows_ltack(xp, rp, ia, m); break; } decG(x); return r; } case n_add: case n_sub: { - usz ia = IA(x); usz m = SH(x)[1]; + usz ia = IA(x); if (m >= 128) return bi_N; usz bl = 128; // block size i8 buf[bl]; i8 c = 0; @@ -383,7 +384,7 @@ B scan_rows_bit(u8 rtid, B x) { if (j == e) { j += m; c = 0; } else c = rp[e-1]; } if (rtid!=n_sub) { decG(x); return r; } - return C2(sub, C2(mul, m_f64(2), scan_rows_bit(n_ltack, x)), r); + return C2(sub, C2(mul, m_f64(2), scan_rows_bit(n_ltack, x, m)), r); } } #else diff --git a/src/singeli/src/scan.singeli b/src/singeli/src/scan.singeli index 0f8417f5..3d15a81b 100644 --- a/src/singeli/src/scan.singeli +++ b/src/singeli/src/scan.singeli @@ -327,10 +327,10 @@ def avx2_loop_with_unaligned_mask{xp, rp, nw, l, scan_words, apply_carry} = { } } -fn scan_rows_andor{id}(src:*u64, dst:*u64, n:usz, l:usz) : void = { +fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = { def qand = not id assert{l > 0} - nw := cdiv{n*l, 64} + nw := cdiv{nl, 64} def res_m1{x,c,m} = { # result word with carry c, popc{m}<=1 if (qand) x &~ ((x+c) & (x+m)) else x | ((-x-c) &~ (x-m)) @@ -385,7 +385,7 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, n:usz, l:usz) : void = { wn:usz = 0 # starting word of next row c:u64 = id # carry def word{bit} = bit * ((1<<64) - 1) - @for (n) { + we:= nl/64; while (wn < we) { iw:= wn r := res_m1{load{src, iw}, c, u64~~1 << (i%64)} store{dst, iw, r}; ++iw @@ -410,10 +410,9 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, n:usz, l:usz) : void = { } } -fn scan_rows_neq(x:*u64, r:*u64, n:usz, l:usz) : void = { +fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = { def scan_word = prefix_byshift{^, <<} assert{l > 0} - nl := n*l nw := cdiv{nl, 64} if (l < 64) { if ((l & (l-1)) == 0) { @@ -460,10 +459,9 @@ fn scan_rows_neq(x:*u64, r:*u64, n:usz, l:usz) : void = { } } -fn scan_rows_left(x:*u64, r:*u64, n:usz, l:usz) : void = { +fn scan_rows_left(x:*u64, r:*u64, nl:usz, l:usz) : void = { def scan_word = prefix_byshift{^, <<} assert{l > 0} - nl := n*l nw := cdiv{nl, 64} if (l < 64) { if ((l & (l-1)) == 0) { @@ -484,7 +482,7 @@ fn scan_rows_left(x:*u64, r:*u64, n:usz, l:usz) : void = { i :usz = 0 # row bit index wn:usz = 0 # starting word of next row c:u64 = 0 # carry - @for (n) { + we:= nl/64; while (wn < we) { iw:= wn m := u64~~1 << (i%64) xw:= -(load{x, iw} & m)