NEON k/bool support
This commit is contained in:
parent
b801c7c186
commit
b2566c8f3a
@ -289,6 +289,12 @@ exportT{'si_constrep', each{rep_const, dat_types}}
|
|||||||
|
|
||||||
|
|
||||||
# Constant replicate on boolean
|
# Constant replicate on boolean
|
||||||
|
def any_sel = hasarch{'SSSE3'} or hasarch{'AARCH64'}
|
||||||
|
if_inline (hasarch{'AARCH64'}) {
|
||||||
|
def __shl{a:V=[_]T, b:U if not isvec{U}} = a << V**cast_i{T,b}
|
||||||
|
def __shr{a:V=[_]T, b:U if not isvec{U}} = a << V**cast_i{T,-b}
|
||||||
|
}
|
||||||
|
|
||||||
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = {
|
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = {
|
||||||
assert{wv >= 2}; assert{wv <= 64}
|
assert{wv >= 2}; assert{wv <= 64}
|
||||||
nw := cdiv{rlen, 64}
|
nw := cdiv{rlen, 64}
|
||||||
@ -297,7 +303,7 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = {
|
|||||||
wf := wv>>p
|
wf := wv>>p
|
||||||
if (wf == 1) {
|
if (wf == 1) {
|
||||||
rep_const_bool_div8{wv, x, r, nw}
|
rep_const_bool_div8{wv, x, r, nw}
|
||||||
} else if (hasarch{'SSSE3'} and p == 3 and wv <= 8*select{basic_rep, -1}) {
|
} else if (any_sel and p == 3 and wv <= 8*select{basic_rep, -1}) {
|
||||||
# (higher wv double-factors, which doesn't work with in-place pointers)
|
# (higher wv double-factors, which doesn't work with in-place pointers)
|
||||||
tlen := rlen / wf
|
tlen := rlen / wf
|
||||||
t := r + cdiv{rlen, 64} - cdiv{tlen, 64}
|
t := r + cdiv{rlen, 64} - cdiv{tlen, 64}
|
||||||
@ -306,7 +312,7 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = {
|
|||||||
} else {
|
} else {
|
||||||
tlen := rlen >> p
|
tlen := rlen >> p
|
||||||
wq := usz~~1 << p
|
wq := usz~~1 << p
|
||||||
if (p == 1 and (not hasarch{'SSSE3'} or wv>=24)) {
|
if (p == 1 and (not any_sel or wv>=24)) {
|
||||||
# Expanding odd second is faster
|
# Expanding odd second is faster
|
||||||
tlen = rlen / wf
|
tlen = rlen / wf
|
||||||
t:=wf; wf=wq; wq=t
|
t:=wf; wf=wq; wq=t
|
||||||
@ -383,7 +389,7 @@ def get_boolvec_writer{T=(u64), r:*T, nw} = {
|
|||||||
tup{output, {}=>{}, flush}
|
tup{output, {}=>{}, flush}
|
||||||
}
|
}
|
||||||
|
|
||||||
def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSE2'}} = {
|
def rep_const_bool_div8{wv, x, r, nw if has_simd} = {
|
||||||
def avx2 = hasarch{'AVX2'}
|
def avx2 = hasarch{'AVX2'}
|
||||||
def vl = if (avx2) 32 else 16
|
def vl = if (avx2) 32 else 16
|
||||||
def V = [vl]u8
|
def V = [vl]u8
|
||||||
@ -414,7 +420,7 @@ def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSE2'}} = {
|
|||||||
output{(xe & mkV{1 << (iV % 8)}) > V**0}
|
output{(xe & mkV{1 << (iV % 8)}) > V**0}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (not hasarch{'SSSE3'}) {
|
if (not any_sel) {
|
||||||
if (wv == 2) {
|
if (wv == 2) {
|
||||||
run24{*V~~x, {xv} => {
|
run24{*V~~x, {xv} => {
|
||||||
def swap{x, l, m} = { def d = (x ^ x>>l) & V**m; x ^ (d|d<<l) }
|
def swap{x, l, m} = { def d = (x ^ x>>l) & V**m; x ^ (d|d<<l) }
|
||||||
@ -435,7 +441,7 @@ def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSE2'}} = {
|
|||||||
def z{x} = zip{x,x,0}
|
def z{x} = zip{x,x,0}
|
||||||
run8{{x} => z{z{z{x}}}}
|
run8{{x} => z{z{z{x}}}}
|
||||||
}
|
}
|
||||||
} else { # hasarch{'SSSE3'}
|
} else { # any_sel
|
||||||
def selH = sel{[16]u8, ., .}
|
def selH = sel{[16]u8, ., .}
|
||||||
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
|
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
|
||||||
def id{xv} = xv
|
def id{xv} = xv
|
||||||
@ -502,9 +508,9 @@ def swap_elts{x:V=[_]_, el_bytes} = {
|
|||||||
def rev = i + (l - 2*(i&l))
|
def rev = i + (l - 2*(i&l))
|
||||||
sel_imm{[n]ty_u{wd*8}, x, rev}
|
sel_imm{[n]ty_u{wd*8}, x, rev}
|
||||||
}
|
}
|
||||||
if (el_bytes >= 4) {
|
if (hasarch{'SSE2'} and el_bytes >= 4) {
|
||||||
selx{4, max{4, el_bytes/2}}
|
selx{4, max{4, el_bytes/2}}
|
||||||
} else if (hasarch{'SSSE3'}) {
|
} else if (any_sel) {
|
||||||
selx{16, 1}
|
selx{16, 1}
|
||||||
} else {
|
} else {
|
||||||
def sh = el_bytes*8
|
def sh = el_bytes*8
|
||||||
@ -515,7 +521,8 @@ def swap_elts{x:V=[_]_, el_bytes} = {
|
|||||||
# Extract swap functions from modperm_dat
|
# Extract swap functions from modperm_dat
|
||||||
def extract_modperm_mask{full:W, lane, l} = {
|
def extract_modperm_mask{full:W, lane, l} = {
|
||||||
def f = l == 128 # Use full width
|
def f = l == 128 # Use full width
|
||||||
def w = if (l<32) 8 else if (f) 64 else 32 # Shuffle element width
|
def w = if (l<32 or not hasarch{'SSE2'}) 8
|
||||||
|
else if (f) 64 else 32 # Shuffle element width
|
||||||
def n = (if (f) 256 else 128) / w # and number of elements
|
def n = (if (f) 256 else 128) / w # and number of elements
|
||||||
def o = l / w # Extract [o,2*o)
|
def o = l / w # Extract [o,2*o)
|
||||||
def i = o + iota{n}%o
|
def i = o + iota{n}%o
|
||||||
@ -556,9 +563,9 @@ def proc_mod_dat{swap_data:W} = {
|
|||||||
}
|
}
|
||||||
def {partwidth, partperm} = {
|
def {partwidth, partperm} = {
|
||||||
def sh{w, d} = tup{w, get_shiftperm{d, w}}
|
def sh{w, d} = tup{w, get_shiftperm{d, w}}
|
||||||
if (ww==64) sh{64, swap_data}
|
if (ww==64) sh{64, swap_data}
|
||||||
else if (not hasarch{'SSSE3'}) sh{32, shuf{[4]u32, swap_data, 4b0000}}
|
else if (not any_sel) sh{32, shuf{[4]u32, swap_data, 4b0000}}
|
||||||
else tup{8, get_byteperm{}}
|
else tup{8, get_byteperm{}}
|
||||||
}
|
}
|
||||||
# Fill in higher steps
|
# Fill in higher steps
|
||||||
def get_mod_permuter{} = {
|
def get_mod_permuter{} = {
|
||||||
@ -574,7 +581,7 @@ def proc_mod_dat{swap_data:W} = {
|
|||||||
|
|
||||||
def rep_const_bool_odd{k, x, r, nw} = {
|
def rep_const_bool_odd{k, x, r, nw} = {
|
||||||
def avx2 = hasarch{'AVX2'}
|
def avx2 = hasarch{'AVX2'}
|
||||||
def W = if (hasarch{'SSE2'}) [if (avx2) 4 else 2]u64 else u64
|
def W = if (has_simd) [if (avx2) 4 else 2]u64 else u64
|
||||||
|
|
||||||
def {output, check_done, flush} = get_boolvec_writer{W, r, nw}
|
def {output, check_done, flush} = get_boolvec_writer{W, r, nw}
|
||||||
xp := *W~~x
|
xp := *W~~x
|
||||||
@ -586,7 +593,7 @@ def rep_const_bool_odd{k, x, r, nw} = {
|
|||||||
swap_data := load{swtab, k>>1}
|
swap_data := load{swtab, k>>1}
|
||||||
def {partperm, get_full_permute} = proc_mod_dat{swap_data}
|
def {partperm, get_full_permute} = proc_mod_dat{swap_data}
|
||||||
|
|
||||||
def sp_max = if (hasarch{'SSSE3'}) 8 else 4
|
def sp_max = if (any_sel) 8 else 4
|
||||||
if (k < sp_max) {
|
if (k < sp_max) {
|
||||||
rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output}
|
rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output}
|
||||||
} else {
|
} else {
|
||||||
@ -639,7 +646,7 @@ def rep_const_bool_odd_mask4{
|
|||||||
mr := scal{u64~~1<<k - 1} # Mask out carry bit before output
|
mr := scal{u64~~1<<k - 1} # Mask out carry bit before output
|
||||||
def sub_carry{a, c} = match (M) {
|
def sub_carry{a, c} = match (M) {
|
||||||
{[l](u64)} => {
|
{[l](u64)} => {
|
||||||
ca := if (hasarch{'SSE4.2'}) { def S = [l]i64; S~~c > S**0 }
|
ca := if (hasarch{'SSE4.2'} or hasarch{'AARCH64'}) { def S = [l]i64; S~~c > S**0 }
|
||||||
else { def S = [2*l]i32; cm := S~~c; cm != shuf{S, cm, 4b2301} }
|
else { def S = [2*l]i32; cm := S~~c; cm != shuf{S, cm, 4b2301} }
|
||||||
a + M~~ca
|
a + M~~ca
|
||||||
}
|
}
|
||||||
@ -648,8 +655,11 @@ def rep_const_bool_odd_mask4{
|
|||||||
|
|
||||||
while (1) {
|
while (1) {
|
||||||
x:M = get_modperm_x{}
|
x:M = get_modperm_x{}
|
||||||
def vrot1 = ifvec{{x} => if (w128{M}) shuf{[4]u32, x, 4b1032}
|
def vrot1 = ifvec{{x} => {
|
||||||
else shuf{M, x, 4b2103}}
|
if (w256{M}) shuf{M, x, 4b2103}
|
||||||
|
else if (any_sel) vshl{x, x, vcount{type{x}}-1}
|
||||||
|
else shuf{[4]u32, x, 4b1032}
|
||||||
|
}}
|
||||||
xo := x<<k | vrot1{x>>(64-k)}
|
xo := x<<k | vrot1{x>>(64-k)}
|
||||||
# Write result word given starting bits
|
# Write result word given starting bits
|
||||||
def step{b, c} = output{sub_carry{c - b, c & mr}}
|
def step{b, c} = output{sub_carry{c - b, c & mr}}
|
||||||
@ -689,7 +699,7 @@ def rep_const_bool_small_odd{(u64), 4, wv, get_swap_x, output} = {
|
|||||||
def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = {
|
def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = {
|
||||||
def ov_bytes{o} = { def V = re_el{u8, W}; v := V~~o; v += v > V**0; W~~v }
|
def ov_bytes{o} = { def V = re_el{u8, W}; v := V~~o; v += v > V**0; W~~v }
|
||||||
def ww = width{W}
|
def ww = width{W}
|
||||||
def ew = if (hasarch{'SSSE3'}) 8 else 32 # width of shuffle-able elements
|
def ew = if (any_sel) 8 else 32 # width of shuffle-able elements
|
||||||
def ne = ww/ew; def se = 64/ew
|
def ne = ww/ew; def se = 64/ew
|
||||||
def lanes = ww > 128
|
def lanes = ww > 128
|
||||||
def dup{v} = if (lanes) merge{v,v} else v
|
def dup{v} = if (lanes) merge{v,v} else v
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user