Factor out some common code

This commit is contained in:
Marshall Lochbaum 2025-02-24 15:01:02 -05:00
parent 0973d9b499
commit b131914cbe

View File

@ -464,23 +464,19 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = {
}
vec_prefix_byshift{^, <<}
}
def apply_carry{s, c, f} = s ^ (f & c)
assert{l > 0}
nw := cdiv{nl, 64}
if (l < 64) {
def apply_mask{s, m} = {
b:= s<<1 & m # last bit of previous row
s ^ (b<<l - b)
}
if ((l & (l-1)) == 0) {
m:u64 = aligned_spaced_mask{l}
@for (r, x over nw) {
s:= scan_word{x}
b:= s<<1 & m # last bit of previous row
r = s ^ (b<<l - b)
}
@for (r, x over nw) r = apply_mask{scan_word{x}, m}
} else if (hasarch{'AVX2'}) {
def scan_words{x, m} = {
s:= scan_words{x}
b:= s<<1 & m # last bit of previous row
s ^ (b<<l - b)
}
def apply_carry{s, c, f} = s ^ (f & c)
def scan_words{x, m} = apply_mask{scan_words{x}, m}
avx2_loop_with_unaligned_mask{x, r, nw, l, scan_words, apply_carry}
} else {
loop_with_unaligned_mask{x, r, nw, l, {x, c, m} => {
@ -504,7 +500,6 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = {
bl{k, c}
}
def fix_carry{t:V} = t & V**1
def apply_carry{r, a, ml} = r ^ (a & ml)
avx2_loop_with_loose_mask{x, r, nw, l, 0, scan_words, propagate, fix_carry, apply_carry}
} else {
i :usz = 0 # row bit index
@ -526,19 +521,18 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = {
fn scan_rows_left(x:*u64, r:*u64, nl:usz, l:usz) : void = {
assert{l > 0}
nw := cdiv{nl, 64}
def apply_carry{s, c, f} = s | (f & c)
if (l < 64) {
def apply_mask{x, m} = { b:= x & m; b<<l - b }
if ((l & (l-1)) == 0) {
m:u64 = aligned_spaced_mask{l}
@for (r, x over nw) { b:= x & m; r = b<<l - b }
@for (r, x over nw) r = apply_mask{x, m}
} else if (hasarch{'AVX2'}) {
def scan_words{x, m} = { b:= x&m; b<<l - b }
def apply_carry{s, c, f} = s | (f & c)
avx2_loop_with_unaligned_mask{x, r, nw, l, scan_words, apply_carry}
avx2_loop_with_unaligned_mask{x, r, nw, l, apply_mask, apply_carry}
} else {
loop_with_unaligned_mask{x, r, nw, l, {x, c, m} => {
f:= (m-1)&~m # bits before first full row
b:= x & m
(c & f) | (b<<l - b)
(c & f) | apply_mask{x, m}
}}
}
} else if (l < 176 and hasarch{'AVX2'}) {
@ -546,7 +540,6 @@ fn scan_rows_left(x:*u64, r:*u64, nl:usz, l:usz) : void = {
s:= -(x & m)
tup{s, s>>63 | (m == V**0)}
}
def apply_carry{r, a, ml} = r | (a & ml)
avx2_loop_with_loose_mask{x, r, nw, l, 0, scan_words, apply_carry}
} else {
i :usz = 0 # row bit index