Extend SSSE3 k/bool to k<32 using pairwise swaps
This commit is contained in:
parent
696a23af9e
commit
7d7d36b354
@ -291,25 +291,8 @@ exportT{'si_constrep', each{rep_const, dat_types}}
|
||||
|
||||
# Constant replicate on boolean
|
||||
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
||||
def has_pdep = hasarch{'BMI2'}
|
||||
def has_pdep = 0 # Obselete, kept here for descriptiveness
|
||||
if (wv > 32) return{0}
|
||||
if (hasarch{'SSSE3'}) {
|
||||
if (wv&1 == 0) {
|
||||
p := ctz{wv | 8} # Power of two for second replicate
|
||||
if (wv>>p == 1) {
|
||||
rep_const_bool_ssse3_div8{wv, x, r, rlen}
|
||||
} else {
|
||||
tlen := rlen>>p
|
||||
t := r + cdiv{rlen, 64} - cdiv{tlen, 64}
|
||||
rep_const_bool{}(wv >>p, x, t, tlen)
|
||||
rep_const_bool{}(usz~~1<<p, t, r, rlen)
|
||||
}
|
||||
return{1}
|
||||
} else if (wv < 16) {
|
||||
rep_const_bool_ssse3_odd{wv, x, r, rlen}
|
||||
return{1}
|
||||
}
|
||||
}
|
||||
if (not has_pdep and wv <= 8) return{0}
|
||||
m:u64 = spaced_mask_of{wv}
|
||||
xw:u64 = 0
|
||||
@ -352,6 +335,24 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
||||
1
|
||||
}
|
||||
|
||||
fn rep_const_bool{if hasarch{'SSSE3'}}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
||||
if (wv > 32) return{0}
|
||||
if (wv&1 == 0) {
|
||||
p := ctz{wv | 8} # Power of two for second replicate
|
||||
if (wv>>p == 1) {
|
||||
rep_const_bool_ssse3_div8{wv, x, r, rlen}
|
||||
} else {
|
||||
tlen := rlen>>p
|
||||
t := r + cdiv{rlen, 64} - cdiv{tlen, 64}
|
||||
rep_const_bool{}(wv >>p, x, t, tlen)
|
||||
rep_const_bool{}(usz~~1<<p, t, r, rlen)
|
||||
}
|
||||
} else {
|
||||
rep_const_bool_ssse3_odd{wv, x, r, rlen}
|
||||
}
|
||||
1
|
||||
}
|
||||
|
||||
# Generalized flat transpose of iota{1<<length{bs}}
|
||||
# select{tr_iota{bs}, x} sends bit i of x to position select{bs, i}
|
||||
def tr_iota{...bs} = {
|
||||
@ -445,8 +446,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
ttab:*V = join{each{get_ttab, 2*iota{4} + 1}}
|
||||
{t0, t4} := each{load{ttab + (wv & 6), .}, iota{2}}
|
||||
m4 := V**0xf
|
||||
def get_perm_x{i} = {
|
||||
xv := load{*V~~x, i}
|
||||
def perm_x{xv} = {
|
||||
selV{t0, xv & m4} | selV{t4, V~~([8]u16~~xv>>4) & m4}
|
||||
}
|
||||
|
||||
@ -455,7 +455,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
# 3: dedicated loop
|
||||
i:usz = 0; while (1) {
|
||||
# 01234567 to 05316427 on each byte
|
||||
xv := get_perm_x{i}; ++i
|
||||
xv := perm_x{load{*V~~x, i}}; ++i
|
||||
# Overhang from previous 64-bit elements
|
||||
def ix = 64*slice{iota{3},1} // 3 # bits that overhang within a word
|
||||
def ib = ix // 8 # byte index
|
||||
@ -503,7 +503,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
while (1) {
|
||||
--q; if (q == 0) { q = wv
|
||||
# Load and permute bytes
|
||||
xv = get_perm_x{i}; ++i
|
||||
xv = perm_x{load{*V~~x, i}}; ++i
|
||||
# Bytes for overhang
|
||||
xo = V~~((W~~xv & xom) >> (8 - wv))
|
||||
xo += xo > V**0
|
||||
@ -522,35 +522,42 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
o := xo & mkV{255 * (iV%8 == 0)} # overhang
|
||||
output{rv | o}
|
||||
}
|
||||
} else { # wv < 16
|
||||
# 9, 11, 13, 15: shared constant computation and loop
|
||||
# Initializers, relying on vl%wv == vl-wv in various ways
|
||||
wV:= V**cast_i{u8,wv}
|
||||
ind0 := (iota{V} < wV) + V**1 # mkV{iV >= k}
|
||||
} else { # wv < 32
|
||||
# 9 to 31: extend k*i % 8 transform to k*i % 128 by pairwise swaps
|
||||
# Swap data goes in a pre-computed table
|
||||
def swdat{k} = {
|
||||
def bits = (k*iota{128} >> merge{0, replicate{1<<iota{7}, iota{7}}}) & 1
|
||||
mkV{each{base{2, .}, split{8, bits}}}
|
||||
}
|
||||
swtab:*V = each{swdat, 9+2*iota{12}}
|
||||
def swap_lens = reverse{1 << iota{4}}
|
||||
swap_data := load{swtab, (wv>>1)-4}
|
||||
swap_masks := each{{l} => selV{swap_data, mkV{l+iV%l}}, swap_lens}
|
||||
# Every-k-bits mask, same as before
|
||||
{m, d} := unaligned_spaced_mask_mod{wv}
|
||||
mask0:= V~~make{W, m, m>>d|m<<(wv-d)}
|
||||
iu:= iota{V} + V**cast_i{u8,vl-wv}; ia:= iu>=V**vl
|
||||
ind_up := iu - (wV&ia)
|
||||
ind_inc:= V**(wv <= vl) - ia
|
||||
mask_sh:= d+d; if (mask_sh >= wv) mask_sh-= wv
|
||||
mask := V~~make{W, m, m>>d|m<<(wv-d)}
|
||||
mask_sh := d+d; if (mask_sh >= wv) mask_sh-= wv
|
||||
# State
|
||||
xv:V = V**0; o:=W**0; ind:=xv; mask:=xv
|
||||
xv:V = V**0; o:=W**0
|
||||
i:usz = 0; q:usz = 1
|
||||
while (1) {
|
||||
--q; if (q == 0) { q = wv
|
||||
xv = get_perm_x{i}; ++i
|
||||
ind = ind0
|
||||
mask = mask0
|
||||
} else {
|
||||
ind = selV{ind, ind_up} + ind_inc
|
||||
mask = V~~((W~~mask << (wv - mask_sh)) | (W~~mask >> mask_sh))
|
||||
# Load xv, send bit i to position wv*i % 128
|
||||
xv = load{*V~~x, i}; ++i
|
||||
def swap_step{l, m} = {
|
||||
xv = (xv & m) | (selV{xv, mkV{iV + (l-2*(iV&l))}} &~ m)
|
||||
}
|
||||
each{swap_step, swap_lens, swap_masks}
|
||||
xv = perm_x{xv}
|
||||
# Write wv vectors based on that
|
||||
@for (wv) {
|
||||
b := W~~(xv & mask)
|
||||
mask = V~~((W~~mask << (wv - mask_sh)) | (W~~mask >> mask_sh))
|
||||
rv:= V~~((b<<wv) - b)
|
||||
# Handle overhang here; won't fit in a single vector
|
||||
po:= o; o = b>>(64-wv)
|
||||
ro:= [4]u32~~vshl{po, o, 1}
|
||||
output{rv | V~~(ro + (ro > [4]u32**0))}
|
||||
}
|
||||
# Handle overhang here; won't fit in a single vector
|
||||
b := W~~(selV{xv, ind} & mask)
|
||||
rv:= V~~((b<<wv) - b)
|
||||
po:= o; o = b>>(64-wv)
|
||||
ro:= [4]u32~~vshl{po, o, 1}
|
||||
output{rv | V~~(ro + (ro > [4]u32**0))}
|
||||
}
|
||||
}
|
||||
flush{}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user