Use CLMUL-based ≠` for power-of-two stride up to 64

This commit is contained in:
Marshall Lochbaum 2025-02-28 14:32:43 -05:00
parent ee6b91be8a
commit f3e0ea7531

View File

@ -106,13 +106,13 @@ def vec_prefix_byshift{op, sh} = {
def scan_word_ne = prefix_byshift{^, <<}
def scan_words_ne = vec_prefix_byshift{^, <<}
fn scan_neq{}(c:u64, x:*u64, r:*u64, nw:u64) : void = {
fn scan_neq{}(c:u64, x:*u64, r:*u64, nw:usz) : void = {
@for (x, r over nw) {
r = c ^ scan_word_ne{x}
c = -(r>>63) # repeat sign bit
}
}
fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:u64) : void = {
fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:usz) : void = {
def vl = arch_defvw / 64
def V = [vl]u64
c := V**c0
@ -123,7 +123,7 @@ fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:u64) : void = {
c = broadcast_last{p}
}
}
fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = {
fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:usz, mark:u64) : void = {
def V = [2]u64
m := V**mark
def xor64{a, i, carry} = { # carry is 64-bit broadcasted current total
@ -144,10 +144,10 @@ fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64
store{*u64~~(rv+e), clmul{load{V, *u64~~(xv+e), 1}, m, 0} ^ c, 1}
}
}
fn scan_neq{if hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = {
fn scan_neq{if hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:usz) : void = {
clmul_scan_ne_any{}(*void~~x, *void~~r, init, nw, -(u64~~1))
}
fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = {
fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u64, nw:usz) : void = {
def V = [8]u64
def sse{a} = make{[2]u64, a, 0}
carry := sse{init}
@ -561,7 +561,7 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = {
c:u64 = 0 # carry
while (1) {
i+= l; ii := iw; iw = cdiv{i, 64}
scan_neq{}(c, x+ii, r+ii, promote{u64,iw-ii})
scan_neq{}(c, x+ii, r+ii, iw-ii)
if (i == nl) return{}
s:= load{r, iw-1}
q := i%64
@ -628,6 +628,10 @@ fn scan_stride_bool_assoc{op}(x:*u64, r:*u64, nl:usz, l:usz) : void = {
def {flip,opf} = if (same{op, &}) tup{~,|} else tup{{x}=>x,op}
nw:= cdiv{nl, 64}
if (l <= 64) {
if (same{op, ^} and hasarch{'PCLMUL'} and (l & (l-1)) == 0) {
clmul_scan_ne_any{}(*void~~x, *void~~r, 0, nw, aligned_spaced_mask{l})
return{}
}
c:u64 = 0 # carry l bits, no matter the alignment
@for (r, x over nw) {
c = opf{flip{x}, c >> (64-l)}