Handle boolean Compress SIMD tail with conditional write, not scalar loop
This commit is contained in:
parent
07ace41d6c
commit
9046dd6b53
@ -345,18 +345,13 @@ def pext_popc{x:T, m:T} = {
|
||||
tup{pe, scal{w} - z}
|
||||
}
|
||||
|
||||
def pext_popc{xs:T, ms:T & hasarch{'PCLMUL'} & width{T}<=128} = {
|
||||
def vt = isvec{T}
|
||||
def V = if (vt) T else [2]T
|
||||
def vec{s} = if (vt) s else make{V, s, 0}
|
||||
def clmul{a, b} = {
|
||||
if (vt) zipLo{...@collect (j to 2) clmul{a,b,j}} else clmul{a, b, 0}
|
||||
}
|
||||
m := vec{ms}
|
||||
x := vec{xs} & m
|
||||
def pext_popc{x0:V, m0:V & hasarch{'PCLMUL'} & V==[2]u64} = {
|
||||
def clmul{a, b} = zipLo{...@collect (j to 2) clmul{a,b,j}}
|
||||
m := m0
|
||||
x := x0 & m
|
||||
d := ~m << 1 # One bit of the position difference at x
|
||||
c := V**(1<<64-1)
|
||||
@unroll (i to lb{scalwidth{T}}) {
|
||||
@unroll (i to lb{scalwidth{V}}) {
|
||||
def sh = 1 << i
|
||||
def shift_at{v, s} = { v = (v&~s) | (v&s)>>sh }
|
||||
p := clmul{d, c} # xor-scan
|
||||
@ -365,8 +360,7 @@ def pext_popc{xs:T, ms:T & hasarch{'PCLMUL'} & width{T}<=128} = {
|
||||
shift_at{m, p}
|
||||
shift_at{x, p}
|
||||
}
|
||||
if (vt) tup{x, @collect (j to 2) popc{extract{ms,j}}}
|
||||
else tup{extract{x, 0}, popc{ms}}
|
||||
tup{x, @collect (j to 2) popc{extract{m0,j}}}
|
||||
}
|
||||
|
||||
def pext_popc{x:T, m:T & hasarch{'BMI2'} & T==u64} = tup{pext{x, m}, popc{m}}
|
||||
@ -387,12 +381,19 @@ fn compress_bool(w:*u64, x:*u64, r:*u64, n:u64) : void = {
|
||||
if (hasarch{'PCLMUL'} or hasarch{'AVX2'}) {
|
||||
def v = if (hasarch{'AVX2'}) 4 else 2
|
||||
def V = [v]u64
|
||||
nv := n/(v*64)
|
||||
@for (w in *V~~w, x in *V~~x over i to nv) {
|
||||
d := cdiv{n,64}; e := d/v
|
||||
@for (w in *V~~w, x in *V~~x over i to cdiv{d,v}) {
|
||||
vc := pext_popc{x, w}
|
||||
@unroll (j to v) add_bits{each{extract{., j}, vc}}
|
||||
def add{j} = add_bits{each{extract{., j}, vc}}
|
||||
if (i < e) {
|
||||
@unroll (j to v) add{j}
|
||||
} else {
|
||||
# last write: between 1 and v-1 words
|
||||
m := d%v
|
||||
def ar{j} = { add{j}; def jn=j+1; if (jn<v-1 and jn<m) ar{jn} }
|
||||
ar{0}
|
||||
}
|
||||
}
|
||||
@for (w, x over i from nv*v to cdiv{n,64}) add_bits{pext_popc{x, w}}
|
||||
} else {
|
||||
@for (w, x over i to cdiv{n,64}) add_bits{pext_popc{x, w}}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user