Implement ≠˝˘ and =˝˘ like ∧∨ on width<64 boolean

This commit is contained in:
Marshall Lochbaum 2024-06-17 09:18:06 -04:00
parent f0f130c42e
commit a8b036ad08
2 changed files with 122 additions and 75 deletions

View File

@ -527,7 +527,6 @@ B fold_rows_bit(Md1D* fd, B x) {
decG(x); if (m==1) return y;
x = y;
}
if (!andor && m <= 64) return bi_N;
u64* xp = bitarr_ptr(x);
u64* rp; B r = m_bitarrv(&rp, n);
if (andor) si_or_rows_bit(xp, rp, n, m, rtid==n_and);

View File

@ -2,8 +2,8 @@ include './base'
include './mask'
if_inline (hasarch{'BMI2'}) include './bmi2'
include './spaced'
include 'util/tup'
include './scan_common'
def opsh64{op}{v:([4]f64), perm} = op{v, shuf{[4]u64, v, perm}}
def opsh32{op}{v:([2]f64), perm} = op{v, shuf{[4]u32, v, perm}}
@ -76,63 +76,17 @@ fn fold_assoc_0{T==f64, op if has_simd}(x:*T, len:u64) : T = {
export{'si_sum_f64', fold_assoc_0{f64,+}}
fn xor_words(init:u64, x:*u64, l:usz):u64 = {
@for (x over l) init ^= x
init
}
def bit_output{rp:*T} = {
buf:u64 = 0 # Buffer for result bits
def output{i, bit, mod} = {
buf = buf>>1 | promote{u64, bit}<<63
if ((i+1)%64==0) { store{rp, 0, mod{buf}}; ++rp }
}
def fixbuf{mod} = { buf = mod{buf} }
def flush_bits{n, mod} = {
q:=(-n)%64; if (q!=0) store{rp, 0, mod{buf} >> q}
}
def flush_bits{n} = flush_bits{n, {b}=>b}
tup{output, fixbuf, flush_bits}
}
# word and alignment of start of next row
def next_start{i, m} = {
bn := promote{u64, i+1} * promote{u64, m}
tup{bn/64, bn%64}
}
fn xor_rows_bit(xp:*u64, rp:*u64, n:usz, m:usz, eq:u1) : void = {
def p64 = promote{u64, .}
def fixout = ^{-(p64{eq} &~ p64{m}), .} # ne to eq conversion
def {add_bit, _, flush_bits} = bit_output{rp}
def add_bit{i, bit} = add_bit{i, bit, fixout}
def xor_loop{len} = {
o:u64 = 0 # Carry
j:u64 = 0; @for (i to n) {
def {jn, sh} = next_start{i, m}
s := xor_words(o, xp + j, len)
e := load{xp, jn}
s ^= e & (if (not same{l,1}) -p64{jn >= j + p64{l}} else j - jn)
o = e >> sh
add_bit{i, popc{s ^ o}}
j = jn+1
}
}
l := m/64
if (m < 128) xor_loop{1}
else if (m%64==0) {
@for (i to n) add_bit{i, popc{xor_words(0, xp+l*i, l)}}
}
else xor_loop{l}
flush_bits{n, fixout}
}
def or_rows_bit_lt64{xp, rp, n, l, op_and} = {
def fold_rows_bit_lt64{
op, run_loop2, run_loop4, pext_res, mult_in, xx, rx, rxs,
xp, rp, n, l
} = {
nw := cdiv{n*l,64}
xx := -promote{u64, op_and}
def loop_pext{T, get, b} = {
@for (x in xp, r in *T~~rp over nw) r = cast_i{T, pext{get{x}, b}}
@for (x in xp, r in *T~~rp over nw) r = cast_i{T, rxs{pext{get{x}, b}}}
}
if (l == 2) {
t:u64 = 64w2b10
def loop = {
run_loop2{
if (fast_BMI2{}) {
{op} => loop_pext{u32, {x}=>op{x, x<<1}, t}
} else {
@ -143,15 +97,14 @@ def or_rows_bit_lt64{xp, rp, n, l, op_and} = {
@for (x in xp, r in *T~~rp over nw) {
a := op{x, x>>1}
each{{m, sh} => { a &= m; a |= a>>sh }, ms, 1<<range{5}}
r = cast_i{T, a}
r = cast_i{T, rxs{a}}
}
}
}
}
if (op_and) loop{&} else loop{|}
} else if (l == 4) {
m:u64 = 64w2b0001; t := m<<3
def loop = {
run_loop4{m, t,
if (fast_BMI2{}) {
loop_pext{u16, ., t}
} else {
@ -161,13 +114,11 @@ def or_rows_bit_lt64{xp, rp, n, l, op_and} = {
a := get{x} & t
a = (a * 4r2b001) & (64w0xf000)
a = (a * fold{+, 1<<(12*iota{4})}) >> (3*16)
r = cast_i{T, a}
r = cast_i{T, rxs{a}}
}
}
}
}
if (op_and) loop{{x} => x & ((x&~t) + m)}
else loop{{x} => x | ((x| t) - m)}
} else {
{e0, d} := unaligned_spaced_mask_mod{l}
e := e0 << (l-1)
@ -175,38 +126,41 @@ def or_rows_bit_lt64{xp, rp, n, l, op_and} = {
rh := *u32~~rp
ri:ux = 0
if (fast_BMI2{}) {
@for (xo in xp over i to nw) {
m := e<<1 | 1
t := e | 1<<63
x := xo^xx
r |= pext{x | ((x|t) - m), t} << ri
c:u64 = 0
@for (x in xp over i to nw) {
r |= pext_res{xx{x}, e, c} << ri
ri += popc{e}
e = e>>d | e<<(l-d)
if (ri >= 32) {
store{rh, 0, cast_i{u32,r^xx}}; ++rh;
store{rh, 0, cast_i{u32,rx{r}}}; ++rh;
r >>= 32; ri -= 32
}
}
} else {
dm:= cast_i{usz, popc{e}}
e0 |= -promote{u64, l & (l-1) == 0}
c:u64 = 0
def loop{...par} = {
@for (xo in xp over nw) {
m := e<<1 | 1
t := e | (1<<63)
x := xo^xx
s := x | ((x|t) - m) # Fold results
cs:= c; c = (s&~e) >> 63
x:= xx{xo}
def {s, mod_rb} = if (length{select{par,0}} == 3) {
# fast path for l==3
s:= op{op{x, x<<1}, op{x<<2, c}}
c = op{x>>63, x>>62}
tup{s, {rb}=>rb}
} else {
mult_in{x, e, c}
}
nb:= dm + promote{usz, e>=e0} # = popc{e}
def extract = match {
{{...qs, q}, {...bs, b}} => (extract{qs, bs} & b) * q
{ {q}, {}} => (s & e) * (q << clz{e})
}
rb:= extract{...par} >> (64 - nb)
r |= (rb|cs) << ri
r |= mod_rb{rb} << ri
ri += promote{ux, nb}
if (ri >= 32) {
store{rh, 0, cast_i{u32,r^xx}}; ++rh;
store{rh, 0, cast_i{u32,rx{r}}}; ++rh;
r >>= 32; ri -= 32
}
e = e>>d | e<<(l-d)
@ -229,17 +183,111 @@ def or_rows_bit_lt64{xp, rp, n, l, op_and} = {
loop{tup{mult0,mult1}, tup{topk}}
} else {
{mult, _} := unaligned_spaced_mask_mod{l-1}
if (l==8) mult &= 1<<(7*8) - 1
loop{tup{mult}, tup{}}
}
}
if (ri > 0) store{rh, 0, cast_i{u32,r^xx}}
if (ri > 0) store{rh, 0, cast_i{u32,rx{r}}}
}
}
fn xor_words(init:u64, x:*u64, l:usz):u64 = {
@for (x over l) init ^= x
init
}
def bit_output{rp:*T} = {
buf:u64 = 0 # Buffer for result bits
def output{i, bit, mod} = {
buf = buf>>1 | promote{u64, bit}<<63
if ((i+1)%64==0) { store{rp, 0, mod{buf}}; ++rp }
}
def fixbuf{mod} = { buf = mod{buf} }
def flush_bits{n, mod} = {
q:=(-n)%64; if (q!=0) store{rp, 0, mod{buf} >> q}
}
def flush_bits{n} = flush_bits{n, {b}=>b}
tup{output, fixbuf, flush_bits}
}
# word and alignment of start of next row
def next_start{i, m} = {
bn := promote{u64, i+1} * promote{u64, m}
tup{bn/64, bn%64}
}
fn xor_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, eq:u1) : void = {
def p64 = promote{u64, .}
rx:= -(p64{eq} &~ p64{l}) # ne to eq conversion
if (l <= 64 and not (l%8==0 and l>16)) {
def run_loop2{loop} = loop{^}
def run_loop4{m, t, loop} = loop{{x} => { x2:= x^(x<<1); x2^(x2<<2) }}
def xor_word = prefix_byshift{^, <<}
def pext_in{xo, e, c} = {
x := xor_word{xo}
rb:= x ^ (x<<l | c)
c = (x ^ -(x>>63))>>(64-l)
rb
}
def pext_res{x, e, c} = pext{pext_in{x, e, c}, e}
def mult_in{x, e, c} = tup{pext_in{x, e, c}, {r}=>r}
fold_rows_bit_lt64{
^, run_loop2, run_loop4, pext_res, mult_in, {x}=>x, ^{rx,.}, ^{rx,.},
xp, rp, n, l
}
} else {
def fixout = ^{rx, .}
def {add_bit, _, flush_bits} = bit_output{rp}
def add_bit{i, bit} = add_bit{i, bit, fixout}
def xor_loop{len} = {
o:u64 = 0 # Carry
j:u64 = 0; @for (i to n) {
def {jn, sh} = next_start{i, l}
s := xor_words(o, xp + j, len)
e := load{xp, jn}
s ^= e & (if (not same{len,1}) -p64{jn >= j + p64{len}} else j - jn)
o = e >> sh
add_bit{i, popc{s ^ o}}
j = jn+1
}
}
ll := l/64
if (l <= 64) {
bm:= u64~~2<<(l-1) - 1
k:= l/8
@for (i to n) add_bit{i, popc{bm & load{*u64~~(*u8~~xp + k*i)}}}
} else if (l < 128) xor_loop{1}
else if (l%64==0) {
@for (i to n) add_bit{i, popc{xor_words(0, xp+ll*i, ll)}}
}
else xor_loop{ll}
flush_bits{n, fixout}
}
}
fn or_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, op_and:u1) : void = {
def {add_bit, set_out, flush_bits} = bit_output{rp}
if (l < 64) {
or_rows_bit_lt64{xp, rp, n, l, op_and}
def run_loop2{loop} = if (op_and) loop{&} else loop{|}
def run_loop4{m, t, loop} = {
if (op_and) loop{{x} => x & ((x&~t) + m)}
else loop{{x} => x | ((x| t) - m)}
}
def xor_word = prefix_byshift{^, <<}
def pext_in{x, e, c} = {
m := e<<1 | 1
t := e | 1<<63
tup{x | ((x|t) - m), t}
}
def pext_res{x, e, c} = pext{...pext_in{x, e, c}}
def mult_in{x, e, c} = {
{s, t} := pext_in{x, e, c}
cs:= c; c = (s&~e) >> 63
tup{s, {rb}=>rb|cs}
}
def xx = ^{-promote{u64, op_and}, .}
fold_rows_bit_lt64{
|, run_loop2, run_loop4, pext_res, mult_in, xx, xx, {r}=>r,
xp, rp, n, l
}
return{}
} else if (l < 128) {
c:u64 = (promote{u64, l}-1) &- op_and # a row gives 1 if its sum is >c