NEON k/bool support

This commit is contained in:
Marshall Lochbaum 2024-08-14 14:08:35 -04:00
parent b801c7c186
commit b2566c8f3a

View File

@ -289,6 +289,12 @@ exportT{'si_constrep', each{rep_const, dat_types}}
# Constant replicate on boolean
def any_sel = hasarch{'SSSE3'} or hasarch{'AARCH64'}
if_inline (hasarch{'AARCH64'}) {
def __shl{a:V=[_]T, b:U if not isvec{U}} = a << V**cast_i{T,b}
def __shr{a:V=[_]T, b:U if not isvec{U}} = a << V**cast_i{T,-b}
}
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = {
assert{wv >= 2}; assert{wv <= 64}
nw := cdiv{rlen, 64}
@ -297,7 +303,7 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = {
wf := wv>>p
if (wf == 1) {
rep_const_bool_div8{wv, x, r, nw}
} else if (hasarch{'SSSE3'} and p == 3 and wv <= 8*select{basic_rep, -1}) {
} else if (any_sel and p == 3 and wv <= 8*select{basic_rep, -1}) {
# (higher wv double-factors, which doesn't work with in-place pointers)
tlen := rlen / wf
t := r + cdiv{rlen, 64} - cdiv{tlen, 64}
@ -306,7 +312,7 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = {
} else {
tlen := rlen >> p
wq := usz~~1 << p
if (p == 1 and (not hasarch{'SSSE3'} or wv>=24)) {
if (p == 1 and (not any_sel or wv>=24)) {
# Expanding odd second is faster
tlen = rlen / wf
t:=wf; wf=wq; wq=t
@ -383,7 +389,7 @@ def get_boolvec_writer{T=(u64), r:*T, nw} = {
tup{output, {}=>{}, flush}
}
def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSE2'}} = {
def rep_const_bool_div8{wv, x, r, nw if has_simd} = {
def avx2 = hasarch{'AVX2'}
def vl = if (avx2) 32 else 16
def V = [vl]u8
@ -414,7 +420,7 @@ def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSE2'}} = {
output{(xe & mkV{1 << (iV % 8)}) > V**0}
}
}
if (not hasarch{'SSSE3'}) {
if (not any_sel) {
if (wv == 2) {
run24{*V~~x, {xv} => {
def swap{x, l, m} = { def d = (x ^ x>>l) & V**m; x ^ (d|d<<l) }
@ -435,7 +441,7 @@ def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSE2'}} = {
def z{x} = zip{x,x,0}
run8{{x} => z{z{z{x}}}}
}
} else { # hasarch{'SSSE3'}
} else { # any_sel
def selH = sel{[16]u8, ., .}
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
def id{xv} = xv
@ -502,9 +508,9 @@ def swap_elts{x:V=[_]_, el_bytes} = {
def rev = i + (l - 2*(i&l))
sel_imm{[n]ty_u{wd*8}, x, rev}
}
if (el_bytes >= 4) {
if (hasarch{'SSE2'} and el_bytes >= 4) {
selx{4, max{4, el_bytes/2}}
} else if (hasarch{'SSSE3'}) {
} else if (any_sel) {
selx{16, 1}
} else {
def sh = el_bytes*8
@ -515,7 +521,8 @@ def swap_elts{x:V=[_]_, el_bytes} = {
# Extract swap functions from modperm_dat
def extract_modperm_mask{full:W, lane, l} = {
def f = l == 128 # Use full width
def w = if (l<32) 8 else if (f) 64 else 32 # Shuffle element width
def w = if (l<32 or not hasarch{'SSE2'}) 8
else if (f) 64 else 32 # Shuffle element width
def n = (if (f) 256 else 128) / w # and number of elements
def o = l / w # Extract [o,2*o)
def i = o + iota{n}%o
@ -556,9 +563,9 @@ def proc_mod_dat{swap_data:W} = {
}
def {partwidth, partperm} = {
def sh{w, d} = tup{w, get_shiftperm{d, w}}
if (ww==64) sh{64, swap_data}
else if (not hasarch{'SSSE3'}) sh{32, shuf{[4]u32, swap_data, 4b0000}}
else tup{8, get_byteperm{}}
if (ww==64) sh{64, swap_data}
else if (not any_sel) sh{32, shuf{[4]u32, swap_data, 4b0000}}
else tup{8, get_byteperm{}}
}
# Fill in higher steps
def get_mod_permuter{} = {
@ -574,7 +581,7 @@ def proc_mod_dat{swap_data:W} = {
def rep_const_bool_odd{k, x, r, nw} = {
def avx2 = hasarch{'AVX2'}
def W = if (hasarch{'SSE2'}) [if (avx2) 4 else 2]u64 else u64
def W = if (has_simd) [if (avx2) 4 else 2]u64 else u64
def {output, check_done, flush} = get_boolvec_writer{W, r, nw}
xp := *W~~x
@ -586,7 +593,7 @@ def rep_const_bool_odd{k, x, r, nw} = {
swap_data := load{swtab, k>>1}
def {partperm, get_full_permute} = proc_mod_dat{swap_data}
def sp_max = if (hasarch{'SSSE3'}) 8 else 4
def sp_max = if (any_sel) 8 else 4
if (k < sp_max) {
rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output}
} else {
@ -639,7 +646,7 @@ def rep_const_bool_odd_mask4{
mr := scal{u64~~1<<k - 1} # Mask out carry bit before output
def sub_carry{a, c} = match (M) {
{[l](u64)} => {
ca := if (hasarch{'SSE4.2'}) { def S = [l]i64; S~~c > S**0 }
ca := if (hasarch{'SSE4.2'} or hasarch{'AARCH64'}) { def S = [l]i64; S~~c > S**0 }
else { def S = [2*l]i32; cm := S~~c; cm != shuf{S, cm, 4b2301} }
a + M~~ca
}
@ -648,8 +655,11 @@ def rep_const_bool_odd_mask4{
while (1) {
x:M = get_modperm_x{}
def vrot1 = ifvec{{x} => if (w128{M}) shuf{[4]u32, x, 4b1032}
else shuf{M, x, 4b2103}}
def vrot1 = ifvec{{x} => {
if (w256{M}) shuf{M, x, 4b2103}
else if (any_sel) vshl{x, x, vcount{type{x}}-1}
else shuf{[4]u32, x, 4b1032}
}}
xo := x<<k | vrot1{x>>(64-k)}
# Write result word given starting bits
def step{b, c} = output{sub_carry{c - b, c & mr}}
@ -689,7 +699,7 @@ def rep_const_bool_small_odd{(u64), 4, wv, get_swap_x, output} = {
def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = {
def ov_bytes{o} = { def V = re_el{u8, W}; v := V~~o; v += v > V**0; W~~v }
def ww = width{W}
def ew = if (hasarch{'SSSE3'}) 8 else 32 # width of shuffle-able elements
def ew = if (any_sel) 8 else 32 # width of shuffle-able elements
def ne = ww/ew; def se = 64/ew
def lanes = ww > 128
def dup{v} = if (lanes) merge{v,v} else v