Faster boolean +˝˘, and ≠˝˘ on row length >64

This commit is contained in:
Marshall Lochbaum 2024-05-23 22:27:12 -04:00
parent ab4e5543a0
commit c76e175719
2 changed files with 124 additions and 4 deletions

View File

@ -7,7 +7,8 @@
B fne_c1(B, B);
B shape_c2(B, B, B);
B transp_c2(B, B, B);
B fold_rows(Md1D* d, B x); // from fold.c
B fold_rows(Md1D* d, B x); // from fold.c
B fold_rows_bit(Md1D* d, B x); // from fold.c
B takedrop_highrank(bool take, B w, B x); // from sfns.c
B try_interleave_cells(B w, B x, ur xr, ur xk, usz* xsh); // from transpose.c
@ -455,7 +456,10 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x, with arr
u8 frtid = v(fd->f)->flags-1;
if (m==1 || frtid==n_ltack) return select_cells(0 , x, cam, k, false);
if ( frtid==n_rtack) return select_cells(m-1, x, cam, k, false);
if (isPervasiveDyExt(fd->f) && m <= 64 && m < sh[0]) return fold_rows(fd, x);
if (isPervasiveDyExt(fd->f)) {
if (TI(x,elType)==el_bit) { B r = fold_rows_bit(fd, x); if (!q_N(r)) return r; }
if (m <= 64 && m < sh[0]) return fold_rows(fd, x);
}
}
} else if (TY(f) == t_md2D) {
Md2D* fd = c(Md2D,f);

View File

@ -20,9 +20,13 @@
#include "../utils/includeSingeli.h"
#endif
static bool fold_ne(u64* x, u64 am) {
static u64 xor_words(u64* x, u64 l) {
u64 r = 0;
for (u64 i = 0; i < (am>>6); i++) r^= x[i];
for (u64 i = 0; i < l; i++) r^= x[i];
return r;
}
static bool fold_ne(u64* x, u64 am) {
u64 r = xor_words(x, am>>6);
if (am&63) r^= x[am>>6]<<(64-am & 63);
return POPC(r) & 1;
}
@ -432,3 +436,115 @@ B fold_rows(Md1D* fd, B x) {
return mut_fv(r);
}
}
B sum_rows_bit(B x) {
usz *sh = SH(x); usz n = sh[0]; usz m = sh[1];
u64* xp = bitarr_ptr(x);
if (m < 128) {
if (m == 2) return bi_N; // Transpose is faster
i8* rp; B r = m_i8arrv(&rp, n);
if (m <= 64) {
if (m%8 == 0) {
usz k = m/8; u64 b = (m==64? 0 : 1ull<<m)-1;
for (usz i=0; i<n; i++) rp[i] = POPC(b & *(u64*)((u8*)xp+k*i));
} else {
if (m<=58 || m==60) {
u64 b = (1ull<<m)-1;
for (usz i=0, j=0; i<n; i++, j+=m) {
u64 xw = *(u64*)((u8*)xp+j/8) >> (j%8);
rp[i] = POPC(b & xw);
}
} else {
// Row may not fit in an aligned word
// Read a word containing the last bit, combine with saved bits
u64 b = ~(~(u64)0 >> m);
u64 prev = 0;
for (usz i=0, j=m; i<n; i++, j+=m) {
u64 xw = ((u64*)((u8*)xp+(j+7)/8))[-1];
usz sh = (-j)%8;
rp[i] = POPC(b & (xw<<sh | prev));
prev = xw >> (m-sh);
}
}
}
} else { // 64<m<128, specialization of unaligned case below
u64 o = 0;
for (usz i=0, j=0; i<n; i++) {
u64 in = (i+1)*m;
u64 s = o + POPC(xp[j]);
usz jn = in/64;
u64 e = xp[jn];
s += POPC(e & ((u64)j - (u64)jn)); // mask is 0 if j==jn, or -1
o = POPC(e >> (in%64));
rp[i] = s - o;
j = jn+1;
}
}
decG(x); return r;
} else if (m < 1<<15) {
i16* rp; B r = m_i16arrv(&rp, n);
usz l = m/64;
if (m%64==0) {
for (usz i=0; i<n; i++) rp[i] = bit_sum(xp+l*i, m);
} else {
u64 o = 0; // Carry
for (usz i=0, j=0; i<n; i++) {
u64 in = (i+1)*m; // bit index of start of next row
u64 s = o + bit_sum(xp + j, 64*l);
usz jn = in/64;
u64 e = xp[jn];
s += POPC(e &- (u64)(jn >= j+l));
o = POPC(e >> (in%64));
rp[i] = s - o;
j = jn+1;
}
}
decG(x); return r;
} else {
return bi_N;
}
}
B xor_rows_bit(B x) {
usz *sh = SH(x); usz n = sh[0]; usz m = sh[1];
if (m <= 64) return bi_N;
u64* xp = bitarr_ptr(x);
u64* rp; B r = m_bitarrv(&rp, n);
u64 rw = 0; // Buffer for result bits
#define ADDBIT(I, BIT) \
rw = rw>>1 | (u64)(BIT)<<63; \
if ((I+1)%64==0) *rp++ = rw;
#define XOR_LOOP(LEN, MASK) \
u64 o = 0; /* Carry */ \
for (usz i=0, j=0; i<n; i++) { \
u64 in = (i+1)*m; /* bit index of start of next row */ \
u64 s = o ^ xor_words(xp + j, LEN); \
usz jn = in/64; \
u64 e = xp[jn]; \
s ^= e & MASK; \
o = e >> (in%64); \
ADDBIT(i, POPC(s ^ o)); \
j = jn+1; \
}
usz l = m/64;
if (m < 128) {
XOR_LOOP(1, (u64)j - (u64)jn)
} else if (m%64==0) {
for (usz i=0; i<n; i++) { ADDBIT(i, POPC(xor_words(xp+l*i, l))) }
} else {
XOR_LOOP(l, -(u64)(jn >= j+l))
}
usz q=(-n)%64; if (q) *rp = rw >> q;
#undef XOR_LOOP
#undef ADDBIT
decG(x); return r;
}
B fold_rows_bit(Md1D* fd, B x) {
assert(isArr(x) && RNK(x)==2 && TI(x,elType)==el_bit);
if (!v(fd->f)->flags) return bi_N;
u8 rtid = v(fd->f)->flags-1;
if (rtid==n_add) return sum_rows_bit(x);
if (rtid==n_ne ) return xor_rows_bit(x);
return bi_N;
}