Boolean and, or, eq folds for row length >64

This commit is contained in:
Marshall Lochbaum 2024-05-25 07:04:46 -04:00
parent c76e175719
commit d29b4df50c

View File

@ -505,15 +505,12 @@ B sum_rows_bit(B x) {
}
}
B xor_rows_bit(B x) {
usz *sh = SH(x); usz n = sh[0]; usz m = sh[1];
if (m <= 64) return bi_N;
u64* xp = bitarr_ptr(x);
u64* rp; B r = m_bitarrv(&rp, n);
void xor_rows_bit(u64* xp, u64* rp, usz n, usz m, bool eq) {
u64 rw = 0; // Buffer for result bits
u64 rx = -(u64)(eq &~ m); // ne to eq conversion if needed
#define ADDBIT(I, BIT) \
rw = rw>>1 | (u64)(BIT)<<63; \
if ((I+1)%64==0) *rp++ = rw;
if ((I+1)%64==0) *rp++ = rw ^ rx;
#define XOR_LOOP(LEN, MASK) \
u64 o = 0; /* Carry */ \
for (usz i=0, j=0; i<n; i++) { \
@ -534,10 +531,50 @@ B xor_rows_bit(B x) {
} else {
XOR_LOOP(l, -(u64)(jn >= j+l))
}
usz q=(-n)%64; if (q) *rp = rw >> q;
usz q=(-n)%64; if (q) *rp = (rw^rx) >> q;
#undef XOR_LOOP
#undef ADDBIT
decG(x); return r;
}
void or_rows_bit(u64* xp, u64* rp, usz n, usz m, u64 and) {
u64 rw = 0; // Buffer for result bits
#define ADDBIT(I, BIT, XOR) \
rw = rw>>1 | (u64)(BIT)<<63; \
if ((I+1)%64==0) *rp++ = XOR ^ rw;
if (m < 128) {
usz c = and? m-1 : 0;
u64 o = 0;
for (usz i=0, j=0; i<n; i++) {
u64 in = (i+1)*m;
u64 s = o + POPC(xp[j]);
usz jn = in/64;
u64 e = xp[jn];
s += POPC(e & ((u64)j - (u64)jn)); // mask is 0 if j==jn, or -1
o = POPC(e >> (in%64));
ADDBIT(i, s > c+o, 0);
j = jn+1;
}
} else {
u64 rx = -and; u64 id = ~rx;
u64 o = 0; // Saved bits
for (usz i=0, j=0; i<n; i++) {
u64 in = (i+1)*m;
usz jn = in/64;
u64 e = xp[jn] ^ rx;
u64 m = ~(u64)0 << (in%64);
u64 rb = 1;
if ((o | (e &~ m)) == 0) {
rb = 0;
for (usz i = j; i < jn-1; i++) if (xp[i]^id) { rb = 1; break; }
}
o = e & m;
ADDBIT(i, rb, rx);
j = jn+1;
}
rw ^= rx;
}
usz q=(-n)%64; if (q) *rp = rw >> q;
#undef ADDBIT
}
B fold_rows_bit(Md1D* fd, B x) {
@ -545,6 +582,14 @@ B fold_rows_bit(Md1D* fd, B x) {
if (!v(fd->f)->flags) return bi_N;
u8 rtid = v(fd->f)->flags-1;
if (rtid==n_add) return sum_rows_bit(x);
if (rtid==n_ne ) return xor_rows_bit(x);
if (rtid==n_ne|rtid==n_eq|rtid==n_or|rtid==n_and) {
usz *sh = SH(x); usz n = sh[0]; usz m = sh[1];
if (m <= 64) return bi_N;
u64* xp = bitarr_ptr(x);
u64* rp; B r = m_bitarrv(&rp, n);
if (rtid==n_ne|rtid==n_eq) xor_rows_bit(xp, rp, n, m, rtid==n_eq);
else or_rows_bit(xp, rp, n, m, rtid==n_and);
decG(x); return r;
}
return bi_N;
}