Fast generic and pext-based ∧˝˘ and ∨˝˘ on width<64

This commit is contained in:
Marshall Lochbaum 2024-06-13 16:52:52 -04:00
parent 40bf3bfd1c
commit e6f1e04de2
2 changed files with 148 additions and 12 deletions

View File

@ -14,6 +14,7 @@
#include "../core.h"
#include "../builtins.h"
#include "../utils/mut.h"
#include "../utils/calls.h"
#if SINGELI
#define SINGELI_FILE fold
@ -512,12 +513,23 @@ B fold_rows_bit(Md1D* fd, B x) {
if (rtid==n_add) return sum_rows_bit(x);
#if SINGELI
if (rtid==n_ne|rtid==n_eq|rtid==n_or|rtid==n_and) {
bool andor = rtid==n_or|rtid==n_and;
usz *sh = SH(x); usz n = sh[0]; usz m = sh[1];
if (m <= 64) return bi_N;
if (andor && m < 256) while (m%8 == 0) {
usz f = CTZ(m|32);
m >>= f; usz c = m*n;
u64* yp; B y = m_bitarrv(&yp, c);
u8 e = el_i8 + f-3;
CmpASFn cmp = rtid==n_or ? CMP_AS_FN(ne, e) : CMP_AS_FN(eq, e);
CMP_AS_CALL(cmp, yp, bitarr_ptr(x), m_f64((rtid==n_or)-1), c);
decG(x); if (m==1) return y;
x = y;
}
if (!andor && m <= 64) return bi_N;
u64* xp = bitarr_ptr(x);
u64* rp; B r = m_bitarrv(&rp, n);
if (rtid==n_ne|rtid==n_eq) si_xor_rows_bit(xp, rp, n, m, rtid==n_eq);
else si_or_rows_bit(xp, rp, n, m, rtid==n_and);
if (andor) si_or_rows_bit(xp, rp, n, m, rtid==n_and);
else si_xor_rows_bit(xp, rp, n, m, rtid==n_eq);
decG(x); return r;
}
#endif

View File

@ -1,5 +1,8 @@
include './base'
include './mask'
if_inline (hasarch{'BMI2'}) include './bmi2'
include 'util/tup'
def opsh64{op}{v:([4]f64), perm} = op{v, shuf{[4]u64, v, perm}}
def opsh32{op}{v:([2]f64), perm} = op{v, shuf{[4]u32, v, perm}}
@ -120,14 +123,135 @@ fn xor_rows_bit(xp:*u64, rp:*u64, n:usz, m:usz, eq:u1) : void = {
flush_bits{n, fixout}
}
fn or_rows_bit(xp:*u64, rp:*u64, n:usz, m:usz, op_and:u1) : void = {
def p64 = promote{u64, .}
def unaligned_mask{l} = {
def d = 64 % l
def m = (~u64~~0 >> d) / ((u64~~1 << l)-1)
tup{m<<l | 1, d}
}
def or_rows_bit_lt64{xp, rp, n, l, op_and} = {
nw := cdiv{n*l,64}
xx := -promote{u64, op_and}
def loop_pext{T, get, b} = {
@for (x in xp, r in *T~~rp over nw) r = cast_i{T, pext{get{x}, b}}
}
if (l == 2) {
t:u64 = 64w2b10
def loop = {
if (fast_BMI2{}) {
{op} => loop_pext{u32, {x}=>op{x, x<<1}, t}
} else {
def all = 1<<64-1
ms:5**u64 = reverse{slice{scan{{x,s}=>x^((x<<s)&all), all, reverse{1<<range{6}}}, 1}}
{op} => {
def T = u32
@for (x in xp, r in *T~~rp over nw) {
a := op{x, x>>1}
each{{m, sh} => { a &= m; a |= a>>sh }, ms, 1<<range{5}}
r = cast_i{T, a}
}
}
}
}
if (op_and) loop{&} else loop{|}
} else if (l == 4) {
m:u64 = 64w2b0001; t := m<<3
def loop = {
if (fast_BMI2{}) {
loop_pext{u16, ., t}
} else {
{get} => {
def T = u16
@for (x in xp, r in *T~~rp over nw) {
a := get{x} & t
a = (a * 4r2b001) & (64w0xf000)
a = (a * fold{+, 1<<(12*iota{4})}) >> (3*16)
r = cast_i{T, a}
}
}
}
}
if (op_and) loop{{x} => x & ((x&~t) + m)}
else loop{{x} => x | ((x| t) - m)}
} else { # odd row length
d:usz = 64 % l
e := ((~u64~~0 >> d) / ((u64~~1 << l)-1)) << (l-1)
r:u64 = 0
rh := *u32~~rp
ri:ux = 0
if (fast_BMI2{}) {
@for (xo in xp over i to nw) {
m := e<<1 | 1
t := e | 1<<63
x := xo^xx
r |= pext{x | ((x|t) - m), t} << ri
ri += popc{e}
e = e>>d | e<<(l-d)
if (ri >= 32) {
store{rh, 0, cast_i{u32,r^xx}}; ++rh;
r >>= 32; ri -= 32
}
}
} else {
dm:= 64/l
e0:= e<<1 | 1
c:u64 = 0
def loop{...par} = {
@for (xo in xp over nw) {
m := e<<1 | 1
t := e | (1<<63)
x := xo^xx
s := x | ((x|t) - m) # Fold results
cs:= c; c = (s&~e) >> 63
nb:= dm + promote{usz, e>=e0} # = popc{e}
def extract = match {
{{...qs, q}, {...bs, b}} => (extract{qs, bs} & b) * q
{ {q}, {}} => (s & e) * (q << clz{e})
}
rb:= extract{...par} >> (64 - nb)
r |= (rb|cs) << ri
ri += promote{ux, nb}
if (ri >= 32) {
store{rh, 0, cast_i{u32,r^xx}}; ++rh;
r >>= 32; ri -= 32
}
e = e>>d | e<<(l-d)
}
}
if (l == 3) {
mult0:u64 = base{1<< 2, 3**1}; top3:u64 = base{1<<9, 8**(1<<3-1)}>>2
mult1:u64 = base{1<< 6, 3**1}; top9:u64 = base{1<<27, 3**(1<<9-1)}<<1
mult2:u64 = base{1<<18, 3**1}
loop{tup{mult0,mult1,mult2}, tup{top3, top9}}
} else if (l < 8) {
assert{l > 4}
ld:= l-1; lld:= l*ld
{mult0, _} := unaligned_mask{ld}
mult0 &= u64~~1<<lld - 1
{mult1, _} := unaligned_mask{lld}
ll:= l*l
{tk, tkd} := unaligned_mask{ll}; tk <<= tkd
topk := tk - tk>>l; topk|= topk<<ll | topk>>ll
loop{tup{mult0,mult1}, tup{topk}}
} else {
{mult, _} := unaligned_mask{l-1}
loop{tup{mult}, tup{}}
}
}
if (ri > 0) store{rh, 0, cast_i{u32,r^xx}}
}
}
fn or_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, op_and:u1) : void = {
def {add_bit, set_out, flush_bits} = bit_output{rp}
if (m < 128) {
c:u64 = (p64{m}-1) &- op_and # a row gives 1 if its sum is >c
if (l < 64) {
or_rows_bit_lt64{xp, rp, n, l, op_and}
return{}
} else if (l < 128) {
c:u64 = (promote{u64, l}-1) &- op_and # a row gives 1 if its sum is >c
o:u64 = 0
j:u64 = 0; @for (i to n) {
def {jn, sh} = next_start{i, m}
def {jn, sh} = next_start{i, l}
s := o + popc{load{xp,j}}
e := load{xp,jn}
s += popc{e & (j - jn)} # mask is 0 if j==jn, or -1
@ -140,15 +264,15 @@ fn or_rows_bit(xp:*u64, rp:*u64, n:usz, m:usz, op_and:u1) : void = {
def fixout = ^{rx, .}
o:u64 = 0 # Saved bits
j:u64 = 0; @for (i to n) {
def {jn, sh} = next_start{i, m}
def {jn, sh} = next_start{i, l}
e := load{xp,jn} ^ rx
m := ~(u64~~0) << sh
l := ~(u64~~0) << sh
rb:u64 = 1
if ((o | (e &~ m)) == 0) { # Search for shortcut
if ((o | (e &~ l)) == 0) { # Search for shortcut
@for (i from j to jn-1) if (load{xp,i} != id) goto{'found'}
rb = 0; setlabel{'found'}
}
o = e & m
o = e & l
add_bit{i, rb, fixout}
j = jn+1
}