Generic-arch k/bool algorithm for odd k with mask unpacking
This commit is contained in:
parent
7d7d36b354
commit
1621c7d07c
@ -293,12 +293,15 @@ exportT{'si_constrep', each{rep_const, dat_types}}
|
||||
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
||||
def has_pdep = 0 # Obselete, kept here for descriptiveness
|
||||
if (wv > 32) return{0}
|
||||
if (not has_pdep and wv <= 8) return{0}
|
||||
m:u64 = spaced_mask_of{wv}
|
||||
xw:u64 = 0
|
||||
d := cast_i{usz, popc{m}} # == 64/wv
|
||||
nw := cdiv{rlen, 64}
|
||||
if (m&1 != 0) { # Power of two
|
||||
if (wv&1 != 0) {
|
||||
rep_const_bool_generic_odd{wv, x, r, nw, m, d}
|
||||
} else if (not has_pdep and wv <= 8) {
|
||||
return{0}
|
||||
} else if (m&1 != 0) { # Power of two
|
||||
i := -usz~~1
|
||||
def expand = if (has_pdep) pdep{., m} else {
|
||||
mult:u64 = spaced_mask_of{wv-1} >> d
|
||||
@ -335,6 +338,105 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
||||
1
|
||||
}
|
||||
|
||||
def rep_const_bool_generic_odd{k, xp, rp, nw, m, d} = {
|
||||
# Every-k-bits mask
|
||||
mask_sh := cast_i{usz, ctz{m}} # == 64%k
|
||||
mask := m<<(k-mask_sh) | 1
|
||||
|
||||
# Transform sending bit i to k*i % 64 by pairwise swaps
|
||||
# Swap data goes in a pre-computed table
|
||||
def swdat{lw, k} = {
|
||||
def i = iota{lw}
|
||||
def bits = (k*iota{1<<lw} >> merge{0, replicate{1<<i, i}}) & 1
|
||||
~u64~~base{2, bits}
|
||||
}
|
||||
swtab:*u64 = each{swdat{6,.}, 1+2*iota{16}}
|
||||
def swap_lens = reverse{2 << iota{5}}
|
||||
swap_data := load{swtab, k>>1}
|
||||
swsel:u64 = ~u64~~0
|
||||
def gsw{l} = {
|
||||
swsel ^= swsel << l # Low l bits out of every 2*l
|
||||
sm := swap_data &~ swsel
|
||||
swap_data &= swsel; swap_data |= swap_data<<l
|
||||
sm
|
||||
}
|
||||
swap_masks := each{gsw, swap_lens}
|
||||
i:usz = 0
|
||||
def get_swap_x{} = {
|
||||
# Load x, send bit i to position k*i % 64
|
||||
x := load{xp, i}; ++i
|
||||
def swap_step{l, m} = {
|
||||
xx := (x ^ x<<l) & m
|
||||
x ^= xx | (xx>>l)
|
||||
}
|
||||
def swap_step{l==32, m} = { # Use rotate
|
||||
mm := m | m>>l
|
||||
x = (x &~ mm) | ((x<<l | x>>l) & mm)
|
||||
}
|
||||
each{swap_step, swap_lens, swap_masks}
|
||||
x
|
||||
}
|
||||
|
||||
# Output
|
||||
j:usz = 0
|
||||
def output{rw} = {
|
||||
store{rp, j, rw}
|
||||
++j; if (j==nw) return{1}
|
||||
}
|
||||
o:u64 = 0 # carry
|
||||
# Dedicated loop for 3, shared for other factors
|
||||
if (k == 3) {
|
||||
while (1) {
|
||||
x := get_swap_x{}
|
||||
@unroll (jj to 3) {
|
||||
b := x & mask
|
||||
mask = (mask << (3 - 64%3)) | (mask >> (64%3))
|
||||
def os = (64-3) + 1 + iota{select{tup{0,2,1}, jj}}
|
||||
output{fold{|, (b<<3) - b, each{>>{o,.}, os}}}
|
||||
o = b
|
||||
}
|
||||
}
|
||||
} else {
|
||||
# Fundamental operation: shifts act as order-k cyclic group on masks
|
||||
def advance{m, sh} = m<<(k-sh) | m>>sh
|
||||
sm0 := mask # starting mask
|
||||
s1 := mask_sh # single iteration shift
|
||||
# Get cumulative mask for 4 iterations, and shift to advance 4
|
||||
def double{{mc, s}} = {
|
||||
def ss = s+s
|
||||
tup{mc|advance{mc,s}, ss - (k &- (ss>k))}
|
||||
}
|
||||
{mc4, s4} := double{double{tup{sm0, s1}}}
|
||||
# Submasks pick one mask out of a combination of 4
|
||||
def or_adv{m, s} = { m |= advance{m,s} }
|
||||
@for (min{k/4 - 1, nw/4}) or_adv{sm0,s4}
|
||||
submasks := scan{advance, tup{sm0, ...3**s1}}
|
||||
mask_tail := advance{sm0, s4} &~ sm0
|
||||
while (1) {
|
||||
x := get_swap_x{}
|
||||
mask = mc4
|
||||
# Write result word given starting bits
|
||||
def step{b} = {
|
||||
r := (b<<k) - b
|
||||
output{r | (o - promote{u64, o > 0})}
|
||||
o = b>>(64-k)
|
||||
}
|
||||
# Fast unrolled iterations
|
||||
@for (k/4) {
|
||||
xm := x & mask
|
||||
each{{mm} => step{xm & mm}, submasks}
|
||||
mask = advance{mask, s4}
|
||||
}
|
||||
# Single-step for tail
|
||||
mask = mask_tail
|
||||
@for (k%4) {
|
||||
step{x & mask}
|
||||
mask = advance{mask, s1}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rep_const_bool{if hasarch{'SSSE3'}}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
||||
if (wv > 32) return{0}
|
||||
if (wv&1 == 0) {
|
||||
@ -539,7 +641,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
mask_sh := d+d; if (mask_sh >= wv) mask_sh-= wv
|
||||
# State
|
||||
xv:V = V**0; o:=W**0
|
||||
i:usz = 0; q:usz = 1
|
||||
i:usz = 0
|
||||
while (1) {
|
||||
# Load xv, send bit i to position wv*i % 128
|
||||
xv = load{*V~~x, i}; ++i
|
||||
|
||||
Loading…
Reference in New Issue
Block a user