diff --git a/src/builtins/cells.c b/src/builtins/cells.c index 47e8da8e..00e55267 100644 --- a/src/builtins/cells.c +++ b/src/builtins/cells.c @@ -7,7 +7,8 @@ B fne_c1(B, B); B shape_c2(B, B, B); B transp_c2(B, B, B); -B fold_rows(Md1D* d, B x); // from fold.c +B fold_rows(Md1D* d, B x); // from fold.c +B fold_rows_bit(Md1D* d, B x); // from fold.c B takedrop_highrank(bool take, B w, B x); // from sfns.c B try_interleave_cells(B w, B x, ur xr, ur xk, usz* xsh); // from transpose.c @@ -455,7 +456,10 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x, with arr u8 frtid = v(fd->f)->flags-1; if (m==1 || frtid==n_ltack) return select_cells(0 , x, cam, k, false); if ( frtid==n_rtack) return select_cells(m-1, x, cam, k, false); - if (isPervasiveDyExt(fd->f) && m <= 64 && m < sh[0]) return fold_rows(fd, x); + if (isPervasiveDyExt(fd->f)) { + if (TI(x,elType)==el_bit) { B r = fold_rows_bit(fd, x); if (!q_N(r)) return r; } + if (m <= 64 && m < sh[0]) return fold_rows(fd, x); + } } } else if (TY(f) == t_md2D) { Md2D* fd = c(Md2D,f); diff --git a/src/builtins/fold.c b/src/builtins/fold.c index 1a94a6ae..272772b6 100644 --- a/src/builtins/fold.c +++ b/src/builtins/fold.c @@ -20,9 +20,13 @@ #include "../utils/includeSingeli.h" #endif -static bool fold_ne(u64* x, u64 am) { +static u64 xor_words(u64* x, u64 l) { u64 r = 0; - for (u64 i = 0; i < (am>>6); i++) r^= x[i]; + for (u64 i = 0; i < l; i++) r^= x[i]; + return r; +} +static bool fold_ne(u64* x, u64 am) { + u64 r = xor_words(x, am>>6); if (am&63) r^= x[am>>6]<<(64-am & 63); return POPC(r) & 1; } @@ -432,3 +436,115 @@ B fold_rows(Md1D* fd, B x) { return mut_fv(r); } } + +B sum_rows_bit(B x) { + usz *sh = SH(x); usz n = sh[0]; usz m = sh[1]; + u64* xp = bitarr_ptr(x); + if (m < 128) { + if (m == 2) return bi_N; // Transpose is faster + i8* rp; B r = m_i8arrv(&rp, n); + if (m <= 64) { + if (m%8 == 0) { + usz k = m/8; u64 b = (m==64? 0 : 1ull<> (j%8); + rp[i] = POPC(b & xw); + } + } else { + // Row may not fit in an aligned word + // Read a word containing the last bit, combine with saved bits + u64 b = ~(~(u64)0 >> m); + u64 prev = 0; + for (usz i=0, j=m; i> (m-sh); + } + } + } + } else { // 64> (in%64)); + rp[i] = s - o; + j = jn+1; + } + } + decG(x); return r; + } else if (m < 1<<15) { + i16* rp; B r = m_i16arrv(&rp, n); + usz l = m/64; + if (m%64==0) { + for (usz i=0; i= j+l)); + o = POPC(e >> (in%64)); + rp[i] = s - o; + j = jn+1; + } + } + decG(x); return r; + } else { + return bi_N; + } +} + +B xor_rows_bit(B x) { + usz *sh = SH(x); usz n = sh[0]; usz m = sh[1]; + if (m <= 64) return bi_N; + u64* xp = bitarr_ptr(x); + u64* rp; B r = m_bitarrv(&rp, n); + u64 rw = 0; // Buffer for result bits + #define ADDBIT(I, BIT) \ + rw = rw>>1 | (u64)(BIT)<<63; \ + if ((I+1)%64==0) *rp++ = rw; + #define XOR_LOOP(LEN, MASK) \ + u64 o = 0; /* Carry */ \ + for (usz i=0, j=0; i> (in%64); \ + ADDBIT(i, POPC(s ^ o)); \ + j = jn+1; \ + } + usz l = m/64; + if (m < 128) { + XOR_LOOP(1, (u64)j - (u64)jn) + } else if (m%64==0) { + for (usz i=0; i= j+l)) + } + usz q=(-n)%64; if (q) *rp = rw >> q; + #undef XOR_LOOP + #undef ADDBIT + decG(x); return r; +} + +B fold_rows_bit(Md1D* fd, B x) { + assert(isArr(x) && RNK(x)==2 && TI(x,elType)==el_bit); + if (!v(fd->f)->flags) return bi_N; + u8 rtid = v(fd->f)->flags-1; + if (rtid==n_add) return sum_rows_bit(x); + if (rtid==n_ne ) return xor_rows_bit(x); + return bi_N; +}