Generic-arch k/bool algorithm for odd k with mask unpacking

This commit is contained in:
Marshall Lochbaum 2024-08-08 18:08:49 -04:00
parent 7d7d36b354
commit 1621c7d07c

View File

@ -293,12 +293,15 @@ exportT{'si_constrep', each{rep_const, dat_types}}
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
def has_pdep = 0 # Obselete, kept here for descriptiveness
if (wv > 32) return{0}
if (not has_pdep and wv <= 8) return{0}
m:u64 = spaced_mask_of{wv}
xw:u64 = 0
d := cast_i{usz, popc{m}} # == 64/wv
nw := cdiv{rlen, 64}
if (m&1 != 0) { # Power of two
if (wv&1 != 0) {
rep_const_bool_generic_odd{wv, x, r, nw, m, d}
} else if (not has_pdep and wv <= 8) {
return{0}
} else if (m&1 != 0) { # Power of two
i := -usz~~1
def expand = if (has_pdep) pdep{., m} else {
mult:u64 = spaced_mask_of{wv-1} >> d
@ -335,6 +338,105 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
1
}
def rep_const_bool_generic_odd{k, xp, rp, nw, m, d} = {
# Every-k-bits mask
mask_sh := cast_i{usz, ctz{m}} # == 64%k
mask := m<<(k-mask_sh) | 1
# Transform sending bit i to k*i % 64 by pairwise swaps
# Swap data goes in a pre-computed table
def swdat{lw, k} = {
def i = iota{lw}
def bits = (k*iota{1<<lw} >> merge{0, replicate{1<<i, i}}) & 1
~u64~~base{2, bits}
}
swtab:*u64 = each{swdat{6,.}, 1+2*iota{16}}
def swap_lens = reverse{2 << iota{5}}
swap_data := load{swtab, k>>1}
swsel:u64 = ~u64~~0
def gsw{l} = {
swsel ^= swsel << l # Low l bits out of every 2*l
sm := swap_data &~ swsel
swap_data &= swsel; swap_data |= swap_data<<l
sm
}
swap_masks := each{gsw, swap_lens}
i:usz = 0
def get_swap_x{} = {
# Load x, send bit i to position k*i % 64
x := load{xp, i}; ++i
def swap_step{l, m} = {
xx := (x ^ x<<l) & m
x ^= xx | (xx>>l)
}
def swap_step{l==32, m} = { # Use rotate
mm := m | m>>l
x = (x &~ mm) | ((x<<l | x>>l) & mm)
}
each{swap_step, swap_lens, swap_masks}
x
}
# Output
j:usz = 0
def output{rw} = {
store{rp, j, rw}
++j; if (j==nw) return{1}
}
o:u64 = 0 # carry
# Dedicated loop for 3, shared for other factors
if (k == 3) {
while (1) {
x := get_swap_x{}
@unroll (jj to 3) {
b := x & mask
mask = (mask << (3 - 64%3)) | (mask >> (64%3))
def os = (64-3) + 1 + iota{select{tup{0,2,1}, jj}}
output{fold{|, (b<<3) - b, each{>>{o,.}, os}}}
o = b
}
}
} else {
# Fundamental operation: shifts act as order-k cyclic group on masks
def advance{m, sh} = m<<(k-sh) | m>>sh
sm0 := mask # starting mask
s1 := mask_sh # single iteration shift
# Get cumulative mask for 4 iterations, and shift to advance 4
def double{{mc, s}} = {
def ss = s+s
tup{mc|advance{mc,s}, ss - (k &- (ss>k))}
}
{mc4, s4} := double{double{tup{sm0, s1}}}
# Submasks pick one mask out of a combination of 4
def or_adv{m, s} = { m |= advance{m,s} }
@for (min{k/4 - 1, nw/4}) or_adv{sm0,s4}
submasks := scan{advance, tup{sm0, ...3**s1}}
mask_tail := advance{sm0, s4} &~ sm0
while (1) {
x := get_swap_x{}
mask = mc4
# Write result word given starting bits
def step{b} = {
r := (b<<k) - b
output{r | (o - promote{u64, o > 0})}
o = b>>(64-k)
}
# Fast unrolled iterations
@for (k/4) {
xm := x & mask
each{{mm} => step{xm & mm}, submasks}
mask = advance{mask, s4}
}
# Single-step for tail
mask = mask_tail
@for (k%4) {
step{x & mask}
mask = advance{mask, s1}
}
}
}
}
fn rep_const_bool{if hasarch{'SSSE3'}}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
if (wv > 32) return{0}
if (wv&1 == 0) {
@ -539,7 +641,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
mask_sh := d+d; if (mask_sh >= wv) mask_sh-= wv
# State
xv:V = V**0; o:=W**0
i:usz = 0; q:usz = 1
i:usz = 0
while (1) {
# Load xv, send bit i to position wv*i % 128
xv = load{*V~~x, i}; ++i