Generic-arch k/bool algorithm for odd k with mask unpacking
This commit is contained in:
parent
7d7d36b354
commit
1621c7d07c
@ -293,12 +293,15 @@ exportT{'si_constrep', each{rep_const, dat_types}}
|
|||||||
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
||||||
def has_pdep = 0 # Obselete, kept here for descriptiveness
|
def has_pdep = 0 # Obselete, kept here for descriptiveness
|
||||||
if (wv > 32) return{0}
|
if (wv > 32) return{0}
|
||||||
if (not has_pdep and wv <= 8) return{0}
|
|
||||||
m:u64 = spaced_mask_of{wv}
|
m:u64 = spaced_mask_of{wv}
|
||||||
xw:u64 = 0
|
xw:u64 = 0
|
||||||
d := cast_i{usz, popc{m}} # == 64/wv
|
d := cast_i{usz, popc{m}} # == 64/wv
|
||||||
nw := cdiv{rlen, 64}
|
nw := cdiv{rlen, 64}
|
||||||
if (m&1 != 0) { # Power of two
|
if (wv&1 != 0) {
|
||||||
|
rep_const_bool_generic_odd{wv, x, r, nw, m, d}
|
||||||
|
} else if (not has_pdep and wv <= 8) {
|
||||||
|
return{0}
|
||||||
|
} else if (m&1 != 0) { # Power of two
|
||||||
i := -usz~~1
|
i := -usz~~1
|
||||||
def expand = if (has_pdep) pdep{., m} else {
|
def expand = if (has_pdep) pdep{., m} else {
|
||||||
mult:u64 = spaced_mask_of{wv-1} >> d
|
mult:u64 = spaced_mask_of{wv-1} >> d
|
||||||
@ -335,6 +338,105 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
|||||||
1
|
1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def rep_const_bool_generic_odd{k, xp, rp, nw, m, d} = {
|
||||||
|
# Every-k-bits mask
|
||||||
|
mask_sh := cast_i{usz, ctz{m}} # == 64%k
|
||||||
|
mask := m<<(k-mask_sh) | 1
|
||||||
|
|
||||||
|
# Transform sending bit i to k*i % 64 by pairwise swaps
|
||||||
|
# Swap data goes in a pre-computed table
|
||||||
|
def swdat{lw, k} = {
|
||||||
|
def i = iota{lw}
|
||||||
|
def bits = (k*iota{1<<lw} >> merge{0, replicate{1<<i, i}}) & 1
|
||||||
|
~u64~~base{2, bits}
|
||||||
|
}
|
||||||
|
swtab:*u64 = each{swdat{6,.}, 1+2*iota{16}}
|
||||||
|
def swap_lens = reverse{2 << iota{5}}
|
||||||
|
swap_data := load{swtab, k>>1}
|
||||||
|
swsel:u64 = ~u64~~0
|
||||||
|
def gsw{l} = {
|
||||||
|
swsel ^= swsel << l # Low l bits out of every 2*l
|
||||||
|
sm := swap_data &~ swsel
|
||||||
|
swap_data &= swsel; swap_data |= swap_data<<l
|
||||||
|
sm
|
||||||
|
}
|
||||||
|
swap_masks := each{gsw, swap_lens}
|
||||||
|
i:usz = 0
|
||||||
|
def get_swap_x{} = {
|
||||||
|
# Load x, send bit i to position k*i % 64
|
||||||
|
x := load{xp, i}; ++i
|
||||||
|
def swap_step{l, m} = {
|
||||||
|
xx := (x ^ x<<l) & m
|
||||||
|
x ^= xx | (xx>>l)
|
||||||
|
}
|
||||||
|
def swap_step{l==32, m} = { # Use rotate
|
||||||
|
mm := m | m>>l
|
||||||
|
x = (x &~ mm) | ((x<<l | x>>l) & mm)
|
||||||
|
}
|
||||||
|
each{swap_step, swap_lens, swap_masks}
|
||||||
|
x
|
||||||
|
}
|
||||||
|
|
||||||
|
# Output
|
||||||
|
j:usz = 0
|
||||||
|
def output{rw} = {
|
||||||
|
store{rp, j, rw}
|
||||||
|
++j; if (j==nw) return{1}
|
||||||
|
}
|
||||||
|
o:u64 = 0 # carry
|
||||||
|
# Dedicated loop for 3, shared for other factors
|
||||||
|
if (k == 3) {
|
||||||
|
while (1) {
|
||||||
|
x := get_swap_x{}
|
||||||
|
@unroll (jj to 3) {
|
||||||
|
b := x & mask
|
||||||
|
mask = (mask << (3 - 64%3)) | (mask >> (64%3))
|
||||||
|
def os = (64-3) + 1 + iota{select{tup{0,2,1}, jj}}
|
||||||
|
output{fold{|, (b<<3) - b, each{>>{o,.}, os}}}
|
||||||
|
o = b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
# Fundamental operation: shifts act as order-k cyclic group on masks
|
||||||
|
def advance{m, sh} = m<<(k-sh) | m>>sh
|
||||||
|
sm0 := mask # starting mask
|
||||||
|
s1 := mask_sh # single iteration shift
|
||||||
|
# Get cumulative mask for 4 iterations, and shift to advance 4
|
||||||
|
def double{{mc, s}} = {
|
||||||
|
def ss = s+s
|
||||||
|
tup{mc|advance{mc,s}, ss - (k &- (ss>k))}
|
||||||
|
}
|
||||||
|
{mc4, s4} := double{double{tup{sm0, s1}}}
|
||||||
|
# Submasks pick one mask out of a combination of 4
|
||||||
|
def or_adv{m, s} = { m |= advance{m,s} }
|
||||||
|
@for (min{k/4 - 1, nw/4}) or_adv{sm0,s4}
|
||||||
|
submasks := scan{advance, tup{sm0, ...3**s1}}
|
||||||
|
mask_tail := advance{sm0, s4} &~ sm0
|
||||||
|
while (1) {
|
||||||
|
x := get_swap_x{}
|
||||||
|
mask = mc4
|
||||||
|
# Write result word given starting bits
|
||||||
|
def step{b} = {
|
||||||
|
r := (b<<k) - b
|
||||||
|
output{r | (o - promote{u64, o > 0})}
|
||||||
|
o = b>>(64-k)
|
||||||
|
}
|
||||||
|
# Fast unrolled iterations
|
||||||
|
@for (k/4) {
|
||||||
|
xm := x & mask
|
||||||
|
each{{mm} => step{xm & mm}, submasks}
|
||||||
|
mask = advance{mask, s4}
|
||||||
|
}
|
||||||
|
# Single-step for tail
|
||||||
|
mask = mask_tail
|
||||||
|
@for (k%4) {
|
||||||
|
step{x & mask}
|
||||||
|
mask = advance{mask, s1}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn rep_const_bool{if hasarch{'SSSE3'}}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
fn rep_const_bool{if hasarch{'SSSE3'}}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
||||||
if (wv > 32) return{0}
|
if (wv > 32) return{0}
|
||||||
if (wv&1 == 0) {
|
if (wv&1 == 0) {
|
||||||
@ -539,7 +641,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
|||||||
mask_sh := d+d; if (mask_sh >= wv) mask_sh-= wv
|
mask_sh := d+d; if (mask_sh >= wv) mask_sh-= wv
|
||||||
# State
|
# State
|
||||||
xv:V = V**0; o:=W**0
|
xv:V = V**0; o:=W**0
|
||||||
i:usz = 0; q:usz = 1
|
i:usz = 0
|
||||||
while (1) {
|
while (1) {
|
||||||
# Load xv, send bit i to position wv*i % 128
|
# Load xv, send bit i to position wv*i % 128
|
||||||
xv = load{*V~~x, i}; ++i
|
xv = load{*V~~x, i}; ++i
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user