Fast generic and pext-based ∧˝˘ and ∨˝˘ on width<64
This commit is contained in:
parent
40bf3bfd1c
commit
e6f1e04de2
@ -14,6 +14,7 @@
|
||||
#include "../core.h"
|
||||
#include "../builtins.h"
|
||||
#include "../utils/mut.h"
|
||||
#include "../utils/calls.h"
|
||||
|
||||
#if SINGELI
|
||||
#define SINGELI_FILE fold
|
||||
@ -512,12 +513,23 @@ B fold_rows_bit(Md1D* fd, B x) {
|
||||
if (rtid==n_add) return sum_rows_bit(x);
|
||||
#if SINGELI
|
||||
if (rtid==n_ne|rtid==n_eq|rtid==n_or|rtid==n_and) {
|
||||
bool andor = rtid==n_or|rtid==n_and;
|
||||
usz *sh = SH(x); usz n = sh[0]; usz m = sh[1];
|
||||
if (m <= 64) return bi_N;
|
||||
if (andor && m < 256) while (m%8 == 0) {
|
||||
usz f = CTZ(m|32);
|
||||
m >>= f; usz c = m*n;
|
||||
u64* yp; B y = m_bitarrv(&yp, c);
|
||||
u8 e = el_i8 + f-3;
|
||||
CmpASFn cmp = rtid==n_or ? CMP_AS_FN(ne, e) : CMP_AS_FN(eq, e);
|
||||
CMP_AS_CALL(cmp, yp, bitarr_ptr(x), m_f64((rtid==n_or)-1), c);
|
||||
decG(x); if (m==1) return y;
|
||||
x = y;
|
||||
}
|
||||
if (!andor && m <= 64) return bi_N;
|
||||
u64* xp = bitarr_ptr(x);
|
||||
u64* rp; B r = m_bitarrv(&rp, n);
|
||||
if (rtid==n_ne|rtid==n_eq) si_xor_rows_bit(xp, rp, n, m, rtid==n_eq);
|
||||
else si_or_rows_bit(xp, rp, n, m, rtid==n_and);
|
||||
if (andor) si_or_rows_bit(xp, rp, n, m, rtid==n_and);
|
||||
else si_xor_rows_bit(xp, rp, n, m, rtid==n_eq);
|
||||
decG(x); return r;
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
include './base'
|
||||
include './mask'
|
||||
if_inline (hasarch{'BMI2'}) include './bmi2'
|
||||
|
||||
include 'util/tup'
|
||||
|
||||
def opsh64{op}{v:([4]f64), perm} = op{v, shuf{[4]u64, v, perm}}
|
||||
def opsh32{op}{v:([2]f64), perm} = op{v, shuf{[4]u32, v, perm}}
|
||||
@ -120,14 +123,135 @@ fn xor_rows_bit(xp:*u64, rp:*u64, n:usz, m:usz, eq:u1) : void = {
|
||||
flush_bits{n, fixout}
|
||||
}
|
||||
|
||||
fn or_rows_bit(xp:*u64, rp:*u64, n:usz, m:usz, op_and:u1) : void = {
|
||||
def p64 = promote{u64, .}
|
||||
def unaligned_mask{l} = {
|
||||
def d = 64 % l
|
||||
def m = (~u64~~0 >> d) / ((u64~~1 << l)-1)
|
||||
tup{m<<l | 1, d}
|
||||
}
|
||||
|
||||
def or_rows_bit_lt64{xp, rp, n, l, op_and} = {
|
||||
nw := cdiv{n*l,64}
|
||||
xx := -promote{u64, op_and}
|
||||
def loop_pext{T, get, b} = {
|
||||
@for (x in xp, r in *T~~rp over nw) r = cast_i{T, pext{get{x}, b}}
|
||||
}
|
||||
if (l == 2) {
|
||||
t:u64 = 64w2b10
|
||||
def loop = {
|
||||
if (fast_BMI2{}) {
|
||||
{op} => loop_pext{u32, {x}=>op{x, x<<1}, t}
|
||||
} else {
|
||||
def all = 1<<64-1
|
||||
ms:5**u64 = reverse{slice{scan{{x,s}=>x^((x<<s)&all), all, reverse{1<<range{6}}}, 1}}
|
||||
{op} => {
|
||||
def T = u32
|
||||
@for (x in xp, r in *T~~rp over nw) {
|
||||
a := op{x, x>>1}
|
||||
each{{m, sh} => { a &= m; a |= a>>sh }, ms, 1<<range{5}}
|
||||
r = cast_i{T, a}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (op_and) loop{&} else loop{|}
|
||||
} else if (l == 4) {
|
||||
m:u64 = 64w2b0001; t := m<<3
|
||||
def loop = {
|
||||
if (fast_BMI2{}) {
|
||||
loop_pext{u16, ., t}
|
||||
} else {
|
||||
{get} => {
|
||||
def T = u16
|
||||
@for (x in xp, r in *T~~rp over nw) {
|
||||
a := get{x} & t
|
||||
a = (a * 4r2b001) & (64w0xf000)
|
||||
a = (a * fold{+, 1<<(12*iota{4})}) >> (3*16)
|
||||
r = cast_i{T, a}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (op_and) loop{{x} => x & ((x&~t) + m)}
|
||||
else loop{{x} => x | ((x| t) - m)}
|
||||
} else { # odd row length
|
||||
d:usz = 64 % l
|
||||
e := ((~u64~~0 >> d) / ((u64~~1 << l)-1)) << (l-1)
|
||||
r:u64 = 0
|
||||
rh := *u32~~rp
|
||||
ri:ux = 0
|
||||
if (fast_BMI2{}) {
|
||||
@for (xo in xp over i to nw) {
|
||||
m := e<<1 | 1
|
||||
t := e | 1<<63
|
||||
x := xo^xx
|
||||
r |= pext{x | ((x|t) - m), t} << ri
|
||||
ri += popc{e}
|
||||
e = e>>d | e<<(l-d)
|
||||
if (ri >= 32) {
|
||||
store{rh, 0, cast_i{u32,r^xx}}; ++rh;
|
||||
r >>= 32; ri -= 32
|
||||
}
|
||||
}
|
||||
} else {
|
||||
dm:= 64/l
|
||||
e0:= e<<1 | 1
|
||||
c:u64 = 0
|
||||
def loop{...par} = {
|
||||
@for (xo in xp over nw) {
|
||||
m := e<<1 | 1
|
||||
t := e | (1<<63)
|
||||
x := xo^xx
|
||||
s := x | ((x|t) - m) # Fold results
|
||||
cs:= c; c = (s&~e) >> 63
|
||||
nb:= dm + promote{usz, e>=e0} # = popc{e}
|
||||
def extract = match {
|
||||
{{...qs, q}, {...bs, b}} => (extract{qs, bs} & b) * q
|
||||
{ {q}, {}} => (s & e) * (q << clz{e})
|
||||
}
|
||||
rb:= extract{...par} >> (64 - nb)
|
||||
r |= (rb|cs) << ri
|
||||
ri += promote{ux, nb}
|
||||
if (ri >= 32) {
|
||||
store{rh, 0, cast_i{u32,r^xx}}; ++rh;
|
||||
r >>= 32; ri -= 32
|
||||
}
|
||||
e = e>>d | e<<(l-d)
|
||||
}
|
||||
}
|
||||
if (l == 3) {
|
||||
mult0:u64 = base{1<< 2, 3**1}; top3:u64 = base{1<<9, 8**(1<<3-1)}>>2
|
||||
mult1:u64 = base{1<< 6, 3**1}; top9:u64 = base{1<<27, 3**(1<<9-1)}<<1
|
||||
mult2:u64 = base{1<<18, 3**1}
|
||||
loop{tup{mult0,mult1,mult2}, tup{top3, top9}}
|
||||
} else if (l < 8) {
|
||||
assert{l > 4}
|
||||
ld:= l-1; lld:= l*ld
|
||||
{mult0, _} := unaligned_mask{ld}
|
||||
mult0 &= u64~~1<<lld - 1
|
||||
{mult1, _} := unaligned_mask{lld}
|
||||
ll:= l*l
|
||||
{tk, tkd} := unaligned_mask{ll}; tk <<= tkd
|
||||
topk := tk - tk>>l; topk|= topk<<ll | topk>>ll
|
||||
loop{tup{mult0,mult1}, tup{topk}}
|
||||
} else {
|
||||
{mult, _} := unaligned_mask{l-1}
|
||||
loop{tup{mult}, tup{}}
|
||||
}
|
||||
}
|
||||
if (ri > 0) store{rh, 0, cast_i{u32,r^xx}}
|
||||
}
|
||||
}
|
||||
|
||||
fn or_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, op_and:u1) : void = {
|
||||
def {add_bit, set_out, flush_bits} = bit_output{rp}
|
||||
if (m < 128) {
|
||||
c:u64 = (p64{m}-1) &- op_and # a row gives 1 if its sum is >c
|
||||
if (l < 64) {
|
||||
or_rows_bit_lt64{xp, rp, n, l, op_and}
|
||||
return{}
|
||||
} else if (l < 128) {
|
||||
c:u64 = (promote{u64, l}-1) &- op_and # a row gives 1 if its sum is >c
|
||||
o:u64 = 0
|
||||
j:u64 = 0; @for (i to n) {
|
||||
def {jn, sh} = next_start{i, m}
|
||||
def {jn, sh} = next_start{i, l}
|
||||
s := o + popc{load{xp,j}}
|
||||
e := load{xp,jn}
|
||||
s += popc{e & (j - jn)} # mask is 0 if j==jn, or -1
|
||||
@ -140,15 +264,15 @@ fn or_rows_bit(xp:*u64, rp:*u64, n:usz, m:usz, op_and:u1) : void = {
|
||||
def fixout = ^{rx, .}
|
||||
o:u64 = 0 # Saved bits
|
||||
j:u64 = 0; @for (i to n) {
|
||||
def {jn, sh} = next_start{i, m}
|
||||
def {jn, sh} = next_start{i, l}
|
||||
e := load{xp,jn} ^ rx
|
||||
m := ~(u64~~0) << sh
|
||||
l := ~(u64~~0) << sh
|
||||
rb:u64 = 1
|
||||
if ((o | (e &~ m)) == 0) { # Search for shortcut
|
||||
if ((o | (e &~ l)) == 0) { # Search for shortcut
|
||||
@for (i from j to jn-1) if (load{xp,i} != id) goto{'found'}
|
||||
rb = 0; setlabel{'found'}
|
||||
}
|
||||
o = e & m
|
||||
o = e & l
|
||||
add_bit{i, rb, fixout}
|
||||
j = jn+1
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user