Faster boolean +˝˘, and ≠˝˘ on row length >64
This commit is contained in:
parent
ab4e5543a0
commit
c76e175719
@ -7,7 +7,8 @@
|
||||
B fne_c1(B, B);
|
||||
B shape_c2(B, B, B);
|
||||
B transp_c2(B, B, B);
|
||||
B fold_rows(Md1D* d, B x); // from fold.c
|
||||
B fold_rows(Md1D* d, B x); // from fold.c
|
||||
B fold_rows_bit(Md1D* d, B x); // from fold.c
|
||||
B takedrop_highrank(bool take, B w, B x); // from sfns.c
|
||||
B try_interleave_cells(B w, B x, ur xr, ur xk, usz* xsh); // from transpose.c
|
||||
|
||||
@ -455,7 +456,10 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x, with arr
|
||||
u8 frtid = v(fd->f)->flags-1;
|
||||
if (m==1 || frtid==n_ltack) return select_cells(0 , x, cam, k, false);
|
||||
if ( frtid==n_rtack) return select_cells(m-1, x, cam, k, false);
|
||||
if (isPervasiveDyExt(fd->f) && m <= 64 && m < sh[0]) return fold_rows(fd, x);
|
||||
if (isPervasiveDyExt(fd->f)) {
|
||||
if (TI(x,elType)==el_bit) { B r = fold_rows_bit(fd, x); if (!q_N(r)) return r; }
|
||||
if (m <= 64 && m < sh[0]) return fold_rows(fd, x);
|
||||
}
|
||||
}
|
||||
} else if (TY(f) == t_md2D) {
|
||||
Md2D* fd = c(Md2D,f);
|
||||
|
||||
@ -20,9 +20,13 @@
|
||||
#include "../utils/includeSingeli.h"
|
||||
#endif
|
||||
|
||||
static bool fold_ne(u64* x, u64 am) {
|
||||
static u64 xor_words(u64* x, u64 l) {
|
||||
u64 r = 0;
|
||||
for (u64 i = 0; i < (am>>6); i++) r^= x[i];
|
||||
for (u64 i = 0; i < l; i++) r^= x[i];
|
||||
return r;
|
||||
}
|
||||
static bool fold_ne(u64* x, u64 am) {
|
||||
u64 r = xor_words(x, am>>6);
|
||||
if (am&63) r^= x[am>>6]<<(64-am & 63);
|
||||
return POPC(r) & 1;
|
||||
}
|
||||
@ -432,3 +436,115 @@ B fold_rows(Md1D* fd, B x) {
|
||||
return mut_fv(r);
|
||||
}
|
||||
}
|
||||
|
||||
B sum_rows_bit(B x) {
|
||||
usz *sh = SH(x); usz n = sh[0]; usz m = sh[1];
|
||||
u64* xp = bitarr_ptr(x);
|
||||
if (m < 128) {
|
||||
if (m == 2) return bi_N; // Transpose is faster
|
||||
i8* rp; B r = m_i8arrv(&rp, n);
|
||||
if (m <= 64) {
|
||||
if (m%8 == 0) {
|
||||
usz k = m/8; u64 b = (m==64? 0 : 1ull<<m)-1;
|
||||
for (usz i=0; i<n; i++) rp[i] = POPC(b & *(u64*)((u8*)xp+k*i));
|
||||
} else {
|
||||
if (m<=58 || m==60) {
|
||||
u64 b = (1ull<<m)-1;
|
||||
for (usz i=0, j=0; i<n; i++, j+=m) {
|
||||
u64 xw = *(u64*)((u8*)xp+j/8) >> (j%8);
|
||||
rp[i] = POPC(b & xw);
|
||||
}
|
||||
} else {
|
||||
// Row may not fit in an aligned word
|
||||
// Read a word containing the last bit, combine with saved bits
|
||||
u64 b = ~(~(u64)0 >> m);
|
||||
u64 prev = 0;
|
||||
for (usz i=0, j=m; i<n; i++, j+=m) {
|
||||
u64 xw = ((u64*)((u8*)xp+(j+7)/8))[-1];
|
||||
usz sh = (-j)%8;
|
||||
rp[i] = POPC(b & (xw<<sh | prev));
|
||||
prev = xw >> (m-sh);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else { // 64<m<128, specialization of unaligned case below
|
||||
u64 o = 0;
|
||||
for (usz i=0, j=0; i<n; i++) {
|
||||
u64 in = (i+1)*m;
|
||||
u64 s = o + POPC(xp[j]);
|
||||
usz jn = in/64;
|
||||
u64 e = xp[jn];
|
||||
s += POPC(e & ((u64)j - (u64)jn)); // mask is 0 if j==jn, or -1
|
||||
o = POPC(e >> (in%64));
|
||||
rp[i] = s - o;
|
||||
j = jn+1;
|
||||
}
|
||||
}
|
||||
decG(x); return r;
|
||||
} else if (m < 1<<15) {
|
||||
i16* rp; B r = m_i16arrv(&rp, n);
|
||||
usz l = m/64;
|
||||
if (m%64==0) {
|
||||
for (usz i=0; i<n; i++) rp[i] = bit_sum(xp+l*i, m);
|
||||
} else {
|
||||
u64 o = 0; // Carry
|
||||
for (usz i=0, j=0; i<n; i++) {
|
||||
u64 in = (i+1)*m; // bit index of start of next row
|
||||
u64 s = o + bit_sum(xp + j, 64*l);
|
||||
usz jn = in/64;
|
||||
u64 e = xp[jn];
|
||||
s += POPC(e &- (u64)(jn >= j+l));
|
||||
o = POPC(e >> (in%64));
|
||||
rp[i] = s - o;
|
||||
j = jn+1;
|
||||
}
|
||||
}
|
||||
decG(x); return r;
|
||||
} else {
|
||||
return bi_N;
|
||||
}
|
||||
}
|
||||
|
||||
B xor_rows_bit(B x) {
|
||||
usz *sh = SH(x); usz n = sh[0]; usz m = sh[1];
|
||||
if (m <= 64) return bi_N;
|
||||
u64* xp = bitarr_ptr(x);
|
||||
u64* rp; B r = m_bitarrv(&rp, n);
|
||||
u64 rw = 0; // Buffer for result bits
|
||||
#define ADDBIT(I, BIT) \
|
||||
rw = rw>>1 | (u64)(BIT)<<63; \
|
||||
if ((I+1)%64==0) *rp++ = rw;
|
||||
#define XOR_LOOP(LEN, MASK) \
|
||||
u64 o = 0; /* Carry */ \
|
||||
for (usz i=0, j=0; i<n; i++) { \
|
||||
u64 in = (i+1)*m; /* bit index of start of next row */ \
|
||||
u64 s = o ^ xor_words(xp + j, LEN); \
|
||||
usz jn = in/64; \
|
||||
u64 e = xp[jn]; \
|
||||
s ^= e & MASK; \
|
||||
o = e >> (in%64); \
|
||||
ADDBIT(i, POPC(s ^ o)); \
|
||||
j = jn+1; \
|
||||
}
|
||||
usz l = m/64;
|
||||
if (m < 128) {
|
||||
XOR_LOOP(1, (u64)j - (u64)jn)
|
||||
} else if (m%64==0) {
|
||||
for (usz i=0; i<n; i++) { ADDBIT(i, POPC(xor_words(xp+l*i, l))) }
|
||||
} else {
|
||||
XOR_LOOP(l, -(u64)(jn >= j+l))
|
||||
}
|
||||
usz q=(-n)%64; if (q) *rp = rw >> q;
|
||||
#undef XOR_LOOP
|
||||
#undef ADDBIT
|
||||
decG(x); return r;
|
||||
}
|
||||
|
||||
B fold_rows_bit(Md1D* fd, B x) {
|
||||
assert(isArr(x) && RNK(x)==2 && TI(x,elType)==el_bit);
|
||||
if (!v(fd->f)->flags) return bi_N;
|
||||
u8 rtid = v(fd->f)->flags-1;
|
||||
if (rtid==n_add) return sum_rows_bit(x);
|
||||
if (rtid==n_ne ) return xor_rows_bit(x);
|
||||
return bi_N;
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user