From f3e0ea75319f6b3331447ad10bba2faf8e7f1acf Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Fri, 28 Feb 2025 14:32:43 -0500 Subject: [PATCH] =?UTF-8?q?Use=20CLMUL-based=20=E2=89=A0`=20for=20power-of?= =?UTF-8?q?-two=20stride=20up=20to=2064?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/singeli/src/scan.singeli | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/singeli/src/scan.singeli b/src/singeli/src/scan.singeli index a3460b96..30cb3fce 100644 --- a/src/singeli/src/scan.singeli +++ b/src/singeli/src/scan.singeli @@ -106,13 +106,13 @@ def vec_prefix_byshift{op, sh} = { def scan_word_ne = prefix_byshift{^, <<} def scan_words_ne = vec_prefix_byshift{^, <<} -fn scan_neq{}(c:u64, x:*u64, r:*u64, nw:u64) : void = { +fn scan_neq{}(c:u64, x:*u64, r:*u64, nw:usz) : void = { @for (x, r over nw) { r = c ^ scan_word_ne{x} c = -(r>>63) # repeat sign bit } } -fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:u64) : void = { +fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:usz) : void = { def vl = arch_defvw / 64 def V = [vl]u64 c := V**c0 @@ -123,7 +123,7 @@ fn scan_neq{if has_simd}(c0:u64, x:*u64, r:*u64, nw:u64) : void = { c = broadcast_last{p} } } -fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = { +fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:usz, mark:u64) : void = { def V = [2]u64 m := V**mark def xor64{a, i, carry} = { # carry is 64-bit broadcasted current total @@ -144,10 +144,10 @@ fn clmul_scan_ne_any{if hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64 store{*u64~~(rv+e), clmul{load{V, *u64~~(xv+e), 1}, m, 0} ^ c, 1} } } -fn scan_neq{if hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = { +fn scan_neq{if hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:usz) : void = { clmul_scan_ne_any{}(*void~~x, *void~~r, init, nw, -(u64~~1)) } -fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = { +fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u64, nw:usz) : void = { def V = [8]u64 def sse{a} = make{[2]u64, a, 0} carry := sse{init} @@ -561,7 +561,7 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = { c:u64 = 0 # carry while (1) { i+= l; ii := iw; iw = cdiv{i, 64} - scan_neq{}(c, x+ii, r+ii, promote{u64,iw-ii}) + scan_neq{}(c, x+ii, r+ii, iw-ii) if (i == nl) return{} s:= load{r, iw-1} q := i%64 @@ -628,6 +628,10 @@ fn scan_stride_bool_assoc{op}(x:*u64, r:*u64, nl:usz, l:usz) : void = { def {flip,opf} = if (same{op, &}) tup{~,|} else tup{{x}=>x,op} nw:= cdiv{nl, 64} if (l <= 64) { + if (same{op, ^} and hasarch{'PCLMUL'} and (l & (l-1)) == 0) { + clmul_scan_ne_any{}(*void~~x, *void~~r, 0, nw, aligned_spaced_mask{l}) + return{} + } c:u64 = 0 # carry l bits, no matter the alignment @for (r, x over nw) { c = opf{flip{x}, c >> (64-l)}