Extend odd/bool SSSE3 code to support AVX2
This commit is contained in:
parent
2be022921e
commit
36e9ca5814
@ -300,7 +300,7 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
|||||||
} else {
|
} else {
|
||||||
tlen := rlen>>p
|
tlen := rlen>>p
|
||||||
wq := usz~~1<<p
|
wq := usz~~1<<p
|
||||||
if ((not hasarch{'SSSE3'} and (p == 1 or (p == 2 and wv>=52))) or (not hasarch{'AVX2'} and p == 1 and wv>=24)) {
|
if ((not hasarch{'SSSE3'} and (p == 1 or (p == 2 and wv>=52))) or (p == 1 and wv>=24)) {
|
||||||
# Expanding odd second is faster
|
# Expanding odd second is faster
|
||||||
tlen = rlen / wf
|
tlen = rlen / wf
|
||||||
t:=wf; wf=wq; wq=t
|
t:=wf; wf=wq; wq=t
|
||||||
@ -412,8 +412,7 @@ def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSSE3'}} = {
|
|||||||
# Data for the permutation that sends bit i to k*i % width{T}
|
# Data for the permutation that sends bit i to k*i % width{T}
|
||||||
def modperm_dat{T, k} = {
|
def modperm_dat{T, k} = {
|
||||||
def w = width{T}
|
def w = width{T}
|
||||||
def lw= lb{w}
|
def i = iota{lb{w}}
|
||||||
def i = iota{lw}
|
|
||||||
def bits = ~(1 & (k*iota{w} >> merge{0, replicate{1<<i, i}}))
|
def bits = ~(1 & (k*iota{w} >> merge{0, replicate{1<<i, i}}))
|
||||||
match (T) {
|
match (T) {
|
||||||
{[_]E} => make{T, each{base{2,.}, split{width{E}, bits}}}
|
{[_]E} => make{T, each{base{2,.}, split{width{E}, bits}}}
|
||||||
@ -432,24 +431,29 @@ def modperm_step{x:T=[_](u8), l, m} = {
|
|||||||
def W = re_el{u64, T}
|
def W = re_el{u64, T}
|
||||||
T~~modperm_step{W~~x, l, W~~m}
|
T~~modperm_step{W~~x, l, W~~m}
|
||||||
}
|
}
|
||||||
def swap_elts{x:V, el_bytes if w128{V}} = {
|
def swap_elts{x:V=[vl](u8), el_bytes} = { # Reverse each pair of elements
|
||||||
def n = 16; def I = [n]u8
|
|
||||||
def swi{len, l} = { def i = iota{len}; i + (l - 2*(i&l)) }
|
def swi{len, l} = { def i = iota{len}; i + (l - 2*(i&l)) }
|
||||||
if (el_bytes >= 4) shuf{[4]u32, x, base{4, swi{4, el_bytes/4}}}
|
if (el_bytes >= 4) {
|
||||||
else sel{I, x, make{I, swi{16, el_bytes}}}
|
def wd = max{4, el_bytes/2}
|
||||||
|
shuf{[4]ty_u{wd*8}, x, base{4, swi{4, el_bytes/wd}}}
|
||||||
|
} else {
|
||||||
|
def i = swi{16, el_bytes}
|
||||||
|
sel{[16]u8, x, make{V, cycle{vl, i}}}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
def modperm_step{x, l, m:V=[_]T if l%8==0} = {
|
def modperm_step{x, l, m:V=[_]T if l%8==0} = {
|
||||||
(x & m) | (swap_elts{x, l/8} &~ m)
|
(x & m) | (swap_elts{x, l/8} &~ m)
|
||||||
}
|
}
|
||||||
def modperm_get_byteperm{sw_bytes:V=([16]u8) if hasarch{'SSSE3'}} = {
|
def modperm_get_byteperm{sw_bytes:V=[_](u8) if hasarch{'SSSE3'}} = {
|
||||||
def shW{op, v, s} = V~~op{re_el{u64,V}~~v, s}
|
def shW{op, v, s} = V~~op{re_el{u64,V}~~v, s}
|
||||||
m4 := V**0xf
|
m4 := V**0xf
|
||||||
t0 := fold{{v,a}=>modperm_step{v,...a}, iota{V}, tup{
|
t0 := fold{{v,a}=>modperm_step{v,...a}, make{V,iota{vcount{V}}%16}, tup{
|
||||||
tup{4, sw_bytes &~ m4},
|
tup{4, sw_bytes &~ m4},
|
||||||
tup{2, ({v} => v|shW{<<, v, 4}){sw_bytes&(V**0xc)}}
|
tup{2, ({v} => v|shW{<<, v, 4}){sw_bytes&(V**0xc)}}
|
||||||
}}
|
}}
|
||||||
t4 := shW{<<, t0&m4, 4} | shW{>>, t0&~m4, 4}
|
t4 := shW{<<, t0&m4, 4} | shW{>>, t0&~m4, 4}
|
||||||
{xv} => sel{V, t0, xv & m4} | sel{V, t4, shW{>>, xv, 4} & m4}
|
def selI = sel{[16]u8, ., .}
|
||||||
|
{xv} => selI{t0, xv & m4} | selI{t4, shW{>>, xv, 4} & m4}
|
||||||
}
|
}
|
||||||
def fold_multi{f, init, ...ls} = fold{{v,a}=>f{v,...a}, init, flip{ls}}
|
def fold_multi{f, init, ...ls} = fold{{v,a}=>f{v,...a}, init, flip{ls}}
|
||||||
|
|
||||||
@ -477,7 +481,12 @@ def rep_const_bool_odd_mask4{
|
|||||||
tup{mn, ss - (k &- (ss>k))}
|
tup{mn, ss - (k &- (ss>k))}
|
||||||
}
|
}
|
||||||
# Mask and shift for one iteration
|
# Mask and shift for one iteration
|
||||||
{sm0, s1} := ifvec{double_gen{make{M,...}}}{tup{mask, mask_sh}}
|
def fillmask{T} = match (T) {
|
||||||
|
{(u64)} => tup{mask, mask_sh}
|
||||||
|
{[2]E} => double_gen{make{T,...}}{fillmask{E}}
|
||||||
|
{[4]E} => double_gen{pair}{fillmask{[2]E}}
|
||||||
|
}
|
||||||
|
{sm0, s1} := fillmask{M}
|
||||||
# Combined mask for 4 iterations, and shift to advance 4
|
# Combined mask for 4 iterations, and shift to advance 4
|
||||||
def double = double_gen{|}
|
def double = double_gen{|}
|
||||||
{mc4, s4} := double{double{tup{sm0, s1}}}
|
{mc4, s4} := double{double{tup{sm0, s1}}}
|
||||||
@ -488,7 +497,8 @@ def rep_const_bool_odd_mask4{
|
|||||||
mask_tail := advance{sm0, s4} &~ sm0
|
mask_tail := advance{sm0, s4} &~ sm0
|
||||||
|
|
||||||
# Carry: shifting and word-crossing is done on the initial permuted x
|
# Carry: shifting and word-crossing is done on the initial permuted x
|
||||||
o:M = scal{0} # Carry for x
|
# No need to carry across input words since they align with output words
|
||||||
|
# First bit of each word in xo below is wrong, but it doesn't matter!
|
||||||
mr := scal{u64~~1<<k - 1} # Mask out carry bit before output
|
mr := scal{u64~~1<<k - 1} # Mask out carry bit before output
|
||||||
def sub_carry{a, c} = match (M) {
|
def sub_carry{a, c} = match (M) {
|
||||||
{[l](u64)} => {
|
{[l](u64)} => {
|
||||||
@ -501,9 +511,9 @@ def rep_const_bool_odd_mask4{
|
|||||||
|
|
||||||
while (1) {
|
while (1) {
|
||||||
x:M = get_modperm_x{}
|
x:M = get_modperm_x{}
|
||||||
def vrot1 = ifvec{{x} => vshl{x, x, vcount{type{x}}-1}}
|
def vrot1 = ifvec{{x} => if (w128{M}) vshl{x, x, vcount{type{x}}-1}
|
||||||
k1:M = scal{1}
|
else shuf{M, x, 4b2103}}
|
||||||
os:=o; xo:=x<<k|vrot1{x>>(64-k)}; o=xo&k1; xo=(xo&~k1)|os
|
xo := x<<k | vrot1{x>>(64-k)}
|
||||||
# Write result word given starting bits
|
# Write result word given starting bits
|
||||||
def step{b, c} = output{sub_carry{c - b, c & mr}}
|
def step{b, c} = output{sub_carry{c - b, c & mr}}
|
||||||
def step{b, c, m} = step{b&m, c&m}
|
def step{b, c, m} = step{b&m, c&m}
|
||||||
@ -524,10 +534,7 @@ def rep_const_bool_odd_mask4{
|
|||||||
|
|
||||||
def rep_const_bool_odd{k, xp, rp, nw} = {
|
def rep_const_bool_odd{k, xp, rp, nw} = {
|
||||||
# Every-k-bits mask
|
# Every-k-bits mask
|
||||||
m:u64 = spaced_mask_of{k}
|
{mask, mask_sh} := unaligned_spaced_mask_mod{k}
|
||||||
d := cast_i{usz, popc{m}} # == 64/k
|
|
||||||
mask_sh := cast_i{usz, ctz{m}} # == 64%k
|
|
||||||
mask := m<<(k-mask_sh) | 1
|
|
||||||
|
|
||||||
# Transform sending bit i to k*i % 64 by pairwise swaps
|
# Transform sending bit i to k*i % 64 by pairwise swaps
|
||||||
# Swap data goes in a pre-computed table
|
# Swap data goes in a pre-computed table
|
||||||
@ -553,19 +560,18 @@ def rep_const_bool_odd{k, xp, rp, nw} = {
|
|||||||
store{rp, j, rw}
|
store{rp, j, rw}
|
||||||
++j; if (j==nw) return{1}
|
++j; if (j==nw) return{1}
|
||||||
}
|
}
|
||||||
o:u64 = 0 # carry
|
|
||||||
# Dedicated loop for 3, shared for other factors
|
# Dedicated loop for 3, shared for other factors
|
||||||
if (k == 3) {
|
if (k == 3) while (1) {
|
||||||
def unrolled_iter{k} = {
|
def unrolled_iter{k} = {
|
||||||
def reps{f, n, init} = scan{{x,_}=>f{x}, n**init}
|
def reps{f, n, init} = scan{{x,_}=>f{x}, n**init}
|
||||||
def masks = reps{advance_spaced_mask{k, ., 64%k}, k, mask}
|
def masks = reps{advance_spaced_mask{k, ., 64%k}, k, mask}
|
||||||
def step{x}{m, n_over} = {
|
def step{x}{o, m, n_over} = {
|
||||||
b := x & m
|
b := x & m
|
||||||
def os = (64-k) + 1 + iota{n_over}
|
def os = (64-k) + 1 + iota{n_over}
|
||||||
output{fold{|, (b<<k) - b, each{>>{o,.}, os}}}
|
output{fold{|, (b<<k) - b, each{>>{o,.}, os}}}
|
||||||
o = b
|
b
|
||||||
}
|
}
|
||||||
while (1) each{step{get_swap_x{}}, masks, (-iota{k}*64)%k}
|
fold_multi{step{get_swap_x{}}, 0, masks, (-iota{k}*64)%k}
|
||||||
}
|
}
|
||||||
unrolled_iter{3}
|
unrolled_iter{3}
|
||||||
} else {
|
} else {
|
||||||
@ -578,21 +584,45 @@ def rep_const_bool_odd{k, xp, rp, nw} = {
|
|||||||
# - replicate each byte by k, making position k*i contain bit i
|
# - replicate each byte by k, making position k*i contain bit i
|
||||||
# - mask out those bits and spread over [ k*i, k*(i+1) )
|
# - mask out those bits and spread over [ k*i, k*(i+1) )
|
||||||
# - ...except where it crosses words; handle this overhang separately
|
# - ...except where it crosses words; handle this overhang separately
|
||||||
def rep_const_bool_odd{wv, x, r, nw if hasarch{'SSSE3'}} = { # wv odd, wv<=15
|
def rep_const_bool_odd{k, x, r, nw if hasarch{'SSSE3'}} = {
|
||||||
oper // ({a,b}=>floor{a/b}) infix left 40
|
def avx2 = hasarch{'AVX2'}
|
||||||
def vl = 16; def V = [vl]u8
|
def vl = if (avx2) 32 else 16; def V = [vl]u8
|
||||||
def iV = iota{vl}
|
def iV = iota{vl}
|
||||||
def mkV = make{V, .}; def selV = sel{V, ., .}
|
def mkV = make{V, .}; def selV = sel{[16]u8, ., .}
|
||||||
def W = [2]u64
|
def W = re_el{u64, V}
|
||||||
def {output, flush} = get_boolvec_writer{V, r, nw}
|
def {output, flush} = get_boolvec_writer{W, r, nw}
|
||||||
|
|
||||||
# Swap data goes in a pre-computed table
|
# Swap data goes in a pre-computed table
|
||||||
swtab:*V = each{modperm_dat{V, .}, 1+2*iota{32}}
|
swtab:*V = each{modperm_dat{V, .}, 1+2*iota{32}}
|
||||||
swap_data := load{swtab, wv>>1}
|
swap_data := load{swtab, k>>1}
|
||||||
|
swap_lane := if (avx2) shuf{[4]u64, swap_data, 4b1010} else swap_data
|
||||||
# Within-byte transformation
|
# Within-byte transformation
|
||||||
def perm_x = modperm_get_byteperm{selV{swap_data, V**0}}
|
def perm_x = modperm_get_byteperm{selV{swap_lane, V**0}}
|
||||||
|
def sp_max = if (avx2) 4 else 8
|
||||||
|
if (k < sp_max) {
|
||||||
|
rep_const_bool_odd_special{V, sp_max, k, x, perm_x, {v}=>output{W~~v}}
|
||||||
|
} else {
|
||||||
|
{m, d} := unaligned_spaced_mask_mod{k}
|
||||||
|
# General case
|
||||||
|
def get_mask{l} = selV{~swap_lane, mkV{l+iV%l}}
|
||||||
|
def get_mask{16} = shuf{[4]u64, ~swap_data, 4b3232}
|
||||||
|
def swap_lens = reverse{1 << iota{4 + avx2}}
|
||||||
|
swap_masks := each{get_mask, swap_lens}
|
||||||
|
def swap_x = fold_multi{modperm_step, ., 8*swap_lens, swap_masks}
|
||||||
|
i:usz = 0
|
||||||
|
def get_swap_x{} = { xv := W~~perm_x{swap_x{load{*V~~x, i}}}; ++i; xv }
|
||||||
|
rep_const_bool_odd_mask4{W, k, get_swap_x, output, cdiv{nw, vcount{W}}, m, d}
|
||||||
|
}
|
||||||
|
flush{}
|
||||||
|
}
|
||||||
|
|
||||||
if (wv < 4) {
|
def rep_const_bool_odd_special{V=[vl](u8), max_wv, wv, x, perm_x, output} = {
|
||||||
|
oper // ({a,b}=>floor{a/b}) infix left 40
|
||||||
|
def iV = iota{vl}
|
||||||
|
def mkV = make{V, .}; def selV = sel{[16]u8, ., .}
|
||||||
|
def W = re_el{u64, V}
|
||||||
|
if (max_wv <= 4 or wv < 4) {
|
||||||
|
def ll = 16; def dup{v} = if (vl>ll) merge{v,v} else v
|
||||||
# 3: dedicated loop
|
# 3: dedicated loop
|
||||||
i:usz = 0; while (1) {
|
i:usz = 0; while (1) {
|
||||||
# 01234567 to 05316427 on each byte
|
# 01234567 to 05316427 on each byte
|
||||||
@ -601,24 +631,29 @@ def rep_const_bool_odd{wv, x, r, nw if hasarch{'SSSE3'}} = { # wv odd, wv<=15
|
|||||||
def ix = 64*slice{iota{3},1} // 3 # bits that overhang within a word
|
def ix = 64*slice{iota{3},1} // 3 # bits that overhang within a word
|
||||||
def ib = ix // 8 # byte index
|
def ib = ix // 8 # byte index
|
||||||
def io = 8*ib + 3*ix%8 # where they are in xv
|
def io = 8*ib + 3*ix%8 # where they are in xv
|
||||||
def wi = split{2, tup{255, ...ib, 255, ...8+ib}}
|
def wi = split{vl/8, dup{tup{255, ...ib, 255, ...8+ib}}}
|
||||||
xo := V~~((W~~xv & W**fold{|, 1<<io}) >> (8-3))
|
xo := V~~((W~~xv & W**fold{|, 1<<io}) >> (8-3))
|
||||||
xo += xo > V**0
|
xo += xo > V**0
|
||||||
# Permute and mask bytes
|
# Permute and mask bytes
|
||||||
def step{jj, oi, ind, mask} = {
|
def step{jj, oi, ind, mask} = {
|
||||||
b := W~~(selV{xv, ind} & mask)
|
def getv = if (vl==ll or jj==1) ({x}=>x)
|
||||||
|
else shuf{[4]u64, ., 4b1010 + 4b2222*(jj>1)}
|
||||||
|
def selx{x, i} = sel{[ll]u8, getv{x}, i}
|
||||||
|
b := W~~(selx{xv, ind} & mask)
|
||||||
r := V~~((b<<3) - b)
|
r := V~~((b<<3) - b)
|
||||||
o := selV{xo, mkV{flat_table{max, oi, 255*(0<iota{8})}}}
|
o := selx{xo, mkV{flat_table{max, oi, 255*(0<iota{8})}}}
|
||||||
output{r|o}
|
output{r|o}
|
||||||
}
|
}
|
||||||
def make3V{vs} = each{make{V,.}, split{vl, vs}}
|
def make3V{vs} = each{make{V,.}, split{vl, dup{vs}}}
|
||||||
each{step,
|
each{step,
|
||||||
iota{3}, wi,
|
iota{3}, wi,
|
||||||
make3V{replicate{3, iota{vl}}},
|
make3V{replicate{3, iota{ll}}},
|
||||||
make3V{8w2b001 << ((-8)*iota{3*vl} % 3)}
|
make3V{8w2b001 << ((-8)*iota{3*ll} % 3)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (wv < 8) {
|
} else {
|
||||||
|
assert{w128{V}}
|
||||||
|
def selV = sel{V, ., .}
|
||||||
# 5, 7: precompute constants, then shared loop
|
# 5, 7: precompute constants, then shared loop
|
||||||
{xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{
|
{xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{
|
||||||
W, V, V, V, V, V, usz }}
|
W, V, V, V, V, V, usz }}
|
||||||
@ -663,17 +698,7 @@ def rep_const_bool_odd{wv, x, r, nw if hasarch{'SSSE3'}} = { # wv odd, wv<=15
|
|||||||
o := xo & mkV{255 * (iV%8 == 0)} # overhang
|
o := xo & mkV{255 * (iV%8 == 0)} # overhang
|
||||||
output{rv | o}
|
output{rv | o}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
# General case
|
|
||||||
def swap_lens = reverse{1 << iota{4}}
|
|
||||||
swap_masks := each{{l} => selV{~swap_data, mkV{l+iV%l}}, swap_lens}
|
|
||||||
def swap_x = fold_multi{modperm_step, ., 8*swap_lens, swap_masks}
|
|
||||||
i:usz = 0
|
|
||||||
def get_swap_x{} = { xv := perm_x{swap_x{load{*V~~x, i}}}; ++i; xv }
|
|
||||||
# Every-k-bits mask
|
|
||||||
{m, d} := unaligned_spaced_mask_mod{wv}
|
|
||||||
rep_const_bool_odd_mask4{W, wv, {}=>W~~get_swap_x{}, {v}=>output{V~~v}, cdiv{nw, vcount{W}}, m, d}
|
|
||||||
}
|
}
|
||||||
flush{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export{'si_constrep_bool', rep_const_bool{}}
|
export{'si_constrep_bool', rep_const_bool{}}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user