Move logical fold-rows functions to Singeli

This commit is contained in:
Marshall Lochbaum 2024-05-25 15:46:19 -04:00
parent 2590222988
commit 7b4468c394
4 changed files with 94 additions and 76 deletions

View File

@ -505,91 +505,21 @@ B sum_rows_bit(B x) {
}
}
void xor_rows_bit(u64* xp, u64* rp, usz n, usz m, bool eq) {
u64 rw = 0; // Buffer for result bits
u64 rx = -(u64)(eq &~ m); // ne to eq conversion if needed
#define ADDBIT(I, BIT) \
rw = rw>>1 | (u64)(BIT)<<63; \
if ((I+1)%64==0) *rp++ = rw ^ rx;
#define XOR_LOOP(LEN, MASK) \
u64 o = 0; /* Carry */ \
for (usz i=0, j=0; i<n; i++) { \
u64 in = (i+1)*m; /* bit index of start of next row */ \
u64 s = o ^ xor_words(xp + j, LEN); \
usz jn = in/64; \
u64 e = xp[jn]; \
s ^= e & MASK; \
o = e >> (in%64); \
ADDBIT(i, POPC(s ^ o)); \
j = jn+1; \
}
usz l = m/64;
if (m < 128) {
XOR_LOOP(1, (u64)j - (u64)jn)
} else if (m%64==0) {
for (usz i=0; i<n; i++) { ADDBIT(i, POPC(xor_words(xp+l*i, l))) }
} else {
XOR_LOOP(l, -(u64)(jn >= j+l))
}
usz q=(-n)%64; if (q) *rp = (rw^rx) >> q;
#undef XOR_LOOP
#undef ADDBIT
}
void or_rows_bit(u64* xp, u64* rp, usz n, usz m, u64 and) {
u64 rw = 0; // Buffer for result bits
#define ADDBIT(I, BIT, XOR) \
rw = rw>>1 | (u64)(BIT)<<63; \
if ((I+1)%64==0) *rp++ = XOR ^ rw;
if (m < 128) {
usz c = and? m-1 : 0;
u64 o = 0;
for (usz i=0, j=0; i<n; i++) {
u64 in = (i+1)*m;
u64 s = o + POPC(xp[j]);
usz jn = in/64;
u64 e = xp[jn];
s += POPC(e & ((u64)j - (u64)jn)); // mask is 0 if j==jn, or -1
o = POPC(e >> (in%64));
ADDBIT(i, s > c+o, 0);
j = jn+1;
}
} else {
u64 rx = -and; u64 id = ~rx;
u64 o = 0; // Saved bits
for (usz i=0, j=0; i<n; i++) {
u64 in = (i+1)*m;
usz jn = in/64;
u64 e = xp[jn] ^ rx;
u64 m = ~(u64)0 << (in%64);
u64 rb = 1;
if ((o | (e &~ m)) == 0) {
rb = 0;
for (usz i = j; i < jn-1; i++) if (xp[i]^id) { rb = 1; break; }
}
o = e & m;
ADDBIT(i, rb, rx);
j = jn+1;
}
rw ^= rx;
}
usz q=(-n)%64; if (q) *rp = rw >> q;
#undef ADDBIT
}
B fold_rows_bit(Md1D* fd, B x) {
assert(isArr(x) && RNK(x)==2 && TI(x,elType)==el_bit);
if (!v(fd->f)->flags) return bi_N;
u8 rtid = v(fd->f)->flags-1;
if (rtid==n_add) return sum_rows_bit(x);
#if SINGELI
if (rtid==n_ne|rtid==n_eq|rtid==n_or|rtid==n_and) {
usz *sh = SH(x); usz n = sh[0]; usz m = sh[1];
if (m <= 64) return bi_N;
u64* xp = bitarr_ptr(x);
u64* rp; B r = m_bitarrv(&rp, n);
if (rtid==n_ne|rtid==n_eq) xor_rows_bit(xp, rp, n, m, rtid==n_eq);
else or_rows_bit(xp, rp, n, m, rtid==n_and);
if (rtid==n_ne|rtid==n_eq) si_xor_rows_bit(xp, rp, n, m, rtid==n_eq);
else si_or_rows_bit(xp, rp, n, m, rtid==n_and);
decG(x); return r;
}
#endif
return bi_N;
}

View File

@ -26,6 +26,8 @@ def elwidth{T} = width{eltype{T}}
oper &~ andnot infix none 35
def andnot{a, b if anyNum{a} and anyNum{b}} = a & ~b
oper &- ({v:T,m:(u1)} => v & -promote{T,m}) infix left 35
def load {p:*[_]_, n } = assert{0}

View File

@ -70,3 +70,91 @@ fn fold_assoc_0{T==f64, op if has_simd}(x:*T, len:u64) : T = {
else extract{mix{op, r}, 0}
}
export{'si_sum_f64', fold_assoc_0{f64,+}}
fn xor_words(init:u64, x:*u64, l:usz):u64 = {
@for (x over l) init ^= x
init
}
def bit_output{rp:*T} = {
buf:u64 = 0 # Buffer for result bits
def output{i, bit, mod} = {
buf = buf>>1 | promote{u64, bit}<<63
if ((i+1)%64==0) { store{rp, 0, mod{buf}}; ++rp }
}
def fixbuf{mod} = { buf = mod{buf} }
def flush_bits{n, mod} = {
q:=(-n)%64; if (q!=0) store{rp, 0, mod{buf} >> q}
}
def flush_bits{n} = flush_bits{n, {b}=>b}
tup{output, fixbuf, flush_bits}
}
# word and alignment of start of next row
def next_start{i, m} = {
bn := promote{u64, i+1} * promote{u64, m}
tup{bn/64, bn%64}
}
fn xor_rows_bit(xp:*u64, rp:*u64, n:usz, m:usz, eq:u1) : void = {
def p64 = promote{u64, .}
def fixout = ^{-(p64{eq} &~ p64{m}), .} # ne to eq conversion
def {add_bit, _, flush_bits} = bit_output{rp}
def add_bit{i, bit} = add_bit{i, bit, fixout}
def xor_loop{len} = {
o:u64 = 0 # Carry
j:u64 = 0; @for (i to n) {
def {jn, sh} = next_start{i, m}
s := xor_words(o, xp + j, len)
e := load{xp, jn}
s ^= e & (if (not same{l,1}) -p64{jn >= j + p64{l}} else j - jn)
o = e >> sh
add_bit{i, popc{s ^ o}}
j = jn+1
}
}
l := m/64
if (m < 128) xor_loop{1}
else if (m%64==0) {
@for (i to n) add_bit{i, popc{xor_words(0, xp+l*i, l)}}
}
else xor_loop{l}
flush_bits{n, fixout}
}
fn or_rows_bit(xp:*u64, rp:*u64, n:usz, m:usz, op_and:u1) : void = {
def p64 = promote{u64, .}
def {add_bit, set_out, flush_bits} = bit_output{rp}
if (m < 128) {
c:u64 = (p64{m}-1) &- op_and # a row gives 1 if its sum is >c
o:u64 = 0
j:u64 = 0; @for (i to n) {
def {jn, sh} = next_start{i, m}
s := o + popc{load{xp,j}}
e := load{xp,jn}
s += popc{e & (j - jn)} # mask is 0 if j==jn, or -1
o = popc{e >> sh}
add_bit{i, s > c+o, {rw}=>rw}
j = jn+1
}
} else {
rx := -promote{u64, op_and}; id := ~rx
def fixout = ^{rx, .}
o:u64 = 0 # Saved bits
j:u64 = 0; @for (i to n) {
def {jn, sh} = next_start{i, m}
e := load{xp,jn} ^ rx
m := ~(u64~~0) << sh
rb:u64 = 1
if ((o | (e &~ m)) == 0) { # Search for shortcut
@for (i from j to jn-1) if (load{xp,i} != id) goto{'found'}
rb = 0; setlabel{'found'}
}
o = e & m
add_bit{i, rb, fixout}
j = jn+1
}
set_out{fixout}
}
flush_bits{n}
}
export{'si_xor_rows_bit', xor_rows_bit}
export{'si_or_rows_bit', or_rows_bit}

View File

@ -335,8 +335,6 @@ exportT{'simd_getRangeRaw', each{getRange, tup{i8,i16,i32,f64}}}
# Hash tables
oper &- ({v:T,m} => v & -promote{T,m}) infix left 35
def rty{name} = if (to_prim{name}=='∊') i8 else i32
def ity{name} = (to_prim{name}=='⊒')**(*u32)
fn hashtab{T, name}(rpi:*rty{name}, iv:*void, mi:usz, fv:*void, ni:usz, links:ity{name}) = {