AVX2 k/bool for odd 64<k<256 using shift-based masks

This commit is contained in:
Marshall Lochbaum 2025-03-07 15:26:50 -05:00
parent bd942894bf
commit 302c2f926f
2 changed files with 66 additions and 6 deletions

View File

@ -45,6 +45,7 @@
// COULD use pdep or similar to avoid overhead on small results
// Otherwise, factor into power of 2 times odd
// COULD fuse 2×odd, since 2/odd/ has a larger intermediate
// 𝕨≤256, AVX2: Modular permutation with shift-based masks
// Other typed 𝕩 uses +`, or lots of Singeli
// Fixed shuffles, factorization, partial shuffles, self-overlapping
// Otherwise, cell-by-cell copying
@ -779,7 +780,10 @@ B slash_c2(B t, B w, B x) {
if (xl == 0) {
u64* xp = bitany_ptr(x);
u64* rp; r = m_bitarrv(&rp, s);
#if SINGELI
#if SINGELI_AVX2
if (wv <= 256) si_constrep_bool(wv, xp, rp, s);
else
#elif SINGELI
if (wv <= 64) si_constrep_bool(wv, xp, rp, s);
else
#endif

View File

@ -296,7 +296,7 @@ if_inline (hasarch{'AARCH64'}) {
}
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = {
assert{wv >= 2}; assert{wv <= 64}
assert{wv >= 2}
nw := cdiv{rlen, 64}
if (wv&1 == 0) {
p := ctz{wv | 8} # Power of two for second replicate
@ -555,14 +555,15 @@ def proc_mod_dat{swap_data:W} = {
else tup{8, get_byteperm{}}
}
# Fill in higher steps
def get_mod_permuter{} = {
def get_mod_permuter{width} = {
def get_swap{l} = {
def mask = extract_modperm_mask{swap_data, swap_lane, l}
modperm_shuf_step{., l, mask}
}
def swap_x = on_len_range{get_swap, partwidth, ww}
def swap_x = on_len_range{get_swap, partwidth, width}
{x} => partperm{swap_x{x}}
}
def get_mod_permuter{} = get_mod_permuter{ww}
tup{partperm, get_mod_permuter}
}
@ -577,15 +578,18 @@ def rep_const_bool_odd{k, x, r, nw} = {
# Modular permutation: small-k cases may use a limited permutation
# on bytes or 32-bit ints; general case uses the whole thing
swtab:*W = each{modperm_dat{W, .}, 1+2*iota{32}}
swap_data := load{swtab, k>>1}
swap_data := load{swtab, (k%64)>>1}
def {partperm, get_full_permute} = proc_mod_dat{swap_data}
def sp_max = if (any_sel) 8 else 4
if (k < sp_max) {
rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output}
} else {
} else if (not avx2 or k < 64) {
def get_swap_x = getter{get_full_permute{}}
rep_const_bool_odd_mask4{W, k, get_swap_x, output, cdiv{nw, vcount{W}}}
} else {
def get_swap_x = getter{get_full_permute{64}}
rep_const_bool_odd_loose_mask{W, k, get_swap_x, output}
}
flush{}
}
@ -599,6 +603,7 @@ def rep_const_bool_odd_mask4{
} = {
def ifvec{g} = match (M) { {[_](u64)} => g; {_} => ({v}=>v) }
def scal = ifvec{{v} => M**v}
assert{k < 64}
# Fundamental operation: shifts act as order-k cyclic group on masks
def advance{m, sh} = advance_spaced_mask{k, m, sh}
@ -729,4 +734,55 @@ def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = {
}
}
# Odd factors larger than 64
# AVX2-only because scalar should be about as good otherwise
def rep_const_bool_odd_loose_mask{V=[vl==4](u64), k, get_modperm_x, output if hasarch{'AVX2'}} = {
assert{k > 64}
# Distance from end to previous row boundary (-k <= q < 0)
q := -make{V, 64*(1+iota{vl})}
def q_mod{} = { d:=q+V**k; q = blend_top{q,d, d} }
o:u64 = width{V}; while (o>k) { o-=k; q_mod{} }
km:= k%64
i:usz = 0; iv:usz = 0 # Words and vectors completed
def step{perm, diff} = {
# Mod-64 mask with 1 bit per word
m:= V**1 << (q & V**63)
# Indicator of which bits are actual boundaries
def S = ty_s{V}; a:= S**(-65) < S~~q
q-= V**o; q_mod{}
# Set to bit from perm, but below a&m xor with diff to get previous
base:= (m & perm) == m
md:= (a & m) & diff
output{base ^ (md + (md==m))}
}
{xp, xd, perm, diff}:= 4**(V**0)
while (1) {
# get_modperm_x permutes each 64-bit word
# Each iteration of this loop handles one permuted word
def first = shuf{., 4**0}
if (i%4 == 0) {
xp = get_modperm_x{}
# Shift by 1, or k%64 in mod-space
# Then the low bit of each word has to be moved to the next
# As before, first bit is wrong but unused
xl:= xp>>(64-km)
xo:= (xp<<km | (xl &~ V**1)) | (shuf{xl, 3,0,1,2} & V**1)
xd = xo ^ xp
perm = first{xp}; diff = first{xd}
} else {
def upd{xq, q} = {
xq = shuf{xq, 1,2,3,0} # Next word
qs:= q; q = first{xq}
blend_hom{q, qs, iota{V} < V**(i%4)}
}
step{upd{xp, perm}, upd{xd, diff}} # Do boundary between iterations
++iv
}
i+= k
ip:= iv; iv = i/4
@for (iv - ip) step{perm, diff}
}
}
export{'si_constrep_bool', rep_const_bool{}}