Medium-row ⊣˘ and ≠˘ with mask-aware carry

This commit is contained in:
Marshall Lochbaum 2025-02-24 09:37:20 -05:00
parent 1fef51a39e
commit 0973d9b499

View File

@ -337,6 +337,34 @@ def avx2_loop_with_unaligned_mask{xp, rp, nw, l, scan_words, apply_carry} = {
m = m>>d4 | m<<(l-d4) m = m>>d4 | m<<(l-d4)
} }
} }
def avx2_loop_with_loose_mask{xp, rp, nw, l, id, scan_words, propagate, fix_carry, apply_carry} = {
assert{l >= 64}
def V = [4]u64
c := V**id # carry, 0 or 1
q := -make{V, 64*iota{4}} # distance to next row boundary
def q_mod{} = { q = blend_top{q,q+V**l, q} }
o:u64 = 256; while (o>l) { o-=l; q_mod{} }
@for_masked{4} (x in tup{V, xp}, r in tup{V, rp} over nw) {
# Get mask; <=1 bit per word
m:= V**1 << q
q-= V**o; q_mod{}
# Within-word scan and carry info
ml:= m - V**1
{s, k}:= scan_words{x, m, ml}
# Propagate carries and adjust result
p:= propagate{k, c}
t:= blend{shuf{V, p, 3,0,1,2}, c, 1,0,0,0}
r = apply_carry{s, -fix_carry{t}, ml}
c = shuf{V, p, 4**3}
}
}
def avx2_loop_with_loose_mask{...a={xp, rp, nw, l, id, scan_words}, apply_carry} = {
def passthrough{k, c} = {
def bl{b,a} = blend_top{b,a, b}
bl{make_scan_idem{f64, bl}{k}, c} # Can't be -1 now
}
avx2_loop_with_loose_mask{...a, passthrough, {k}=>k, apply_carry}
}
fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = { fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = {
def qand = not id def qand = not id
@ -375,31 +403,17 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = {
a >> 63} # new c a >> 63} # new c
}} }}
} }
} else if (l < 160) { } else if (l < (if (hasarch{'AVX2'}) 256 else 160)) {
if (hasarch{'AVX2'}) { if (hasarch{'AVX2'}) {
assert{l >= 64} def scan_words{x:V, m:V, _} = {
def V = [4]u64 s:= (if (qand) x &~ ((x+V**1) & (x+m))
c := V**id # carry, 0 or 1 else x | ((-x) &~ (x-m)))
v1:= V**1; xk:= c - v1 p:= (if (qand) x&~m == ~V**0
q := -make{V, 64*iota{4}} # distance to next row boundary else x| m == V**0)
def q_mod{} = { q = blend_top{q,q+V**l, q} } k:= s>>63 | p # Carry of 0 or 1, but -1 to propagate previous
o:u64 = 256; while (o>l) { o-=l; q_mod{} } tup{s, k}
@for_masked{4} (x in tup{V, src}, r in tup{V, dst} over nw) {
# Get mask; <=1 bit per word
m:= v1 << q
q-= V**o; q_mod{}
# Within-word scan and carry info
r = (if (qand) x &~ ((x+v1) & (x+m))
else x | ((-x) &~ (x-m)))
p:= (if (qand) x&~m else x|m) == xk
k:= r>>63 | p # Carry of 0 or 1, but -1 to propagate previous
# Propagate carries and adjust result
def bl{b,a} = blend_top{b,a, b}
k = bl{make_scan_idem{f64, bl}{k}, c} # Can't be -1 now
t:= blend{shuf{V, k, 3,0,1,2}, c, 1,0,0,0}
r = apply_carry{r, -t, m-v1}
c = shuf{V, k, 4**3}
} }
avx2_loop_with_loose_mask{src, dst, nw, l, id, scan_words, apply_carry}
} else { } else {
q:usz = 0 # distance to next row boundary q:usz = 0 # distance to next row boundary
c:u64 = id # carry c:u64 = id # carry
@ -443,6 +457,13 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = {
fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = { fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = {
def scan_word = prefix_byshift{^, <<} def scan_word = prefix_byshift{^, <<}
def scan_words = {
def vec_prefix_byshift{op, sh} = {
def pre{v:V, k} = if (k < elwidth{V}) pre{op{v, sh{v,k}}, 2*k} else v
{v:T} => pre{v, 1}
}
vec_prefix_byshift{^, <<}
}
assert{l > 0} assert{l > 0}
nw := cdiv{nl, 64} nw := cdiv{nl, 64}
if (l < 64) { if (l < 64) {
@ -455,11 +476,7 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = {
} }
} else if (hasarch{'AVX2'}) { } else if (hasarch{'AVX2'}) {
def scan_words{x, m} = { def scan_words{x, m} = {
def vec_prefix_byshift{op, sh} = { s:= scan_words{x}
def pre{v:V, k} = if (k < elwidth{V}) pre{op{v, sh{v,k}}, 2*k} else v
{v:T} => pre{v, 1}
}
s:= vec_prefix_byshift{^, <<}{x}
b:= s<<1 & m # last bit of previous row b:= s<<1 & m # last bit of previous row
s ^ (b<<l - b) s ^ (b<<l - b)
} }
@ -473,6 +490,22 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = {
s ^ ((c & f) | (b<<l - b)) s ^ ((c & f) | (b<<l - b))
}} }}
} }
} else if (l < 320 and hasarch{'AVX2'}) {
def scan_words{x:V, m:V, ml:V} = {
s:= scan_words{x}
s^= -(s<<1 & m)
k:= s>>63 | (V**(1<<63) &~ ml) # Top bit 1 to stop, so 0 is identity
tup{s, k}
}
def propagate{k:V, c:V} = {
def bl{b,a} = blend_top{a^b,b, b}
k = bl{k, vec_shift_right_128{k, 1}}
k = bl{k, shuf{V, blend{V**0, k, 0,1,0,1}, 0,0,1,1}}
bl{k, c}
}
def fix_carry{t:V} = t & V**1
def apply_carry{r, a, ml} = r ^ (a & ml)
avx2_loop_with_loose_mask{x, r, nw, l, 0, scan_words, propagate, fix_carry, apply_carry}
} else { } else {
i :usz = 0 # row bit index i :usz = 0 # row bit index
iw:usz = 0 # starting word iw:usz = 0 # starting word
@ -491,7 +524,6 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = {
} }
fn scan_rows_left(x:*u64, r:*u64, nl:usz, l:usz) : void = { fn scan_rows_left(x:*u64, r:*u64, nl:usz, l:usz) : void = {
def scan_word = prefix_byshift{^, <<}
assert{l > 0} assert{l > 0}
nw := cdiv{nl, 64} nw := cdiv{nl, 64}
if (l < 64) { if (l < 64) {
@ -509,6 +541,13 @@ fn scan_rows_left(x:*u64, r:*u64, nl:usz, l:usz) : void = {
(c & f) | (b<<l - b) (c & f) | (b<<l - b)
}} }}
} }
} else if (l < 176 and hasarch{'AVX2'}) {
def scan_words{x:V, m:V, _} = {
s:= -(x & m)
tup{s, s>>63 | (m == V**0)}
}
def apply_carry{r, a, ml} = r | (a & ml)
avx2_loop_with_loose_mask{x, r, nw, l, 0, scan_words, apply_carry}
} else { } else {
i :usz = 0 # row bit index i :usz = 0 # row bit index
wn:usz = 0 # starting word of next row wn:usz = 0 # starting word of next row