From b2566c8f3ad3e51f9cf54f9b37542b8baad2e725 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Wed, 14 Aug 2024 14:08:35 -0400 Subject: [PATCH] NEON k/bool support --- src/singeli/src/replicate.singeli | 44 +++++++++++++++++++------------ 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/src/singeli/src/replicate.singeli b/src/singeli/src/replicate.singeli index 6b65d7dc..8b80c07a 100644 --- a/src/singeli/src/replicate.singeli +++ b/src/singeli/src/replicate.singeli @@ -289,6 +289,12 @@ exportT{'si_constrep', each{rep_const, dat_types}} # Constant replicate on boolean +def any_sel = hasarch{'SSSE3'} or hasarch{'AARCH64'} +if_inline (hasarch{'AARCH64'}) { + def __shl{a:V=[_]T, b:U if not isvec{U}} = a << V**cast_i{T,b} + def __shr{a:V=[_]T, b:U if not isvec{U}} = a << V**cast_i{T,-b} +} + fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = { assert{wv >= 2}; assert{wv <= 64} nw := cdiv{rlen, 64} @@ -297,7 +303,7 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = { wf := wv>>p if (wf == 1) { rep_const_bool_div8{wv, x, r, nw} - } else if (hasarch{'SSSE3'} and p == 3 and wv <= 8*select{basic_rep, -1}) { + } else if (any_sel and p == 3 and wv <= 8*select{basic_rep, -1}) { # (higher wv double-factors, which doesn't work with in-place pointers) tlen := rlen / wf t := r + cdiv{rlen, 64} - cdiv{tlen, 64} @@ -306,7 +312,7 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = { } else { tlen := rlen >> p wq := usz~~1 << p - if (p == 1 and (not hasarch{'SSSE3'} or wv>=24)) { + if (p == 1 and (not any_sel or wv>=24)) { # Expanding odd second is faster tlen = rlen / wf t:=wf; wf=wq; wq=t @@ -383,7 +389,7 @@ def get_boolvec_writer{T=(u64), r:*T, nw} = { tup{output, {}=>{}, flush} } -def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSE2'}} = { +def rep_const_bool_div8{wv, x, r, nw if has_simd} = { def avx2 = hasarch{'AVX2'} def vl = if (avx2) 32 else 16 def V = [vl]u8 @@ -414,7 +420,7 @@ def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSE2'}} = { output{(xe & mkV{1 << (iV % 8)}) > V**0} } } - if (not hasarch{'SSSE3'}) { + if (not any_sel) { if (wv == 2) { run24{*V~~x, {xv} => { def swap{x, l, m} = { def d = (x ^ x>>l) & V**m; x ^ (d|d< z{z{z{x}}}} } - } else { # hasarch{'SSSE3'} + } else { # any_sel def selH = sel{[16]u8, ., .} def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .} def id{xv} = xv @@ -502,9 +508,9 @@ def swap_elts{x:V=[_]_, el_bytes} = { def rev = i + (l - 2*(i&l)) sel_imm{[n]ty_u{wd*8}, x, rev} } - if (el_bytes >= 4) { + if (hasarch{'SSE2'} and el_bytes >= 4) { selx{4, max{4, el_bytes/2}} - } else if (hasarch{'SSSE3'}) { + } else if (any_sel) { selx{16, 1} } else { def sh = el_bytes*8 @@ -515,7 +521,8 @@ def swap_elts{x:V=[_]_, el_bytes} = { # Extract swap functions from modperm_dat def extract_modperm_mask{full:W, lane, l} = { def f = l == 128 # Use full width - def w = if (l<32) 8 else if (f) 64 else 32 # Shuffle element width + def w = if (l<32 or not hasarch{'SSE2'}) 8 + else if (f) 64 else 32 # Shuffle element width def n = (if (f) 256 else 128) / w # and number of elements def o = l / w # Extract [o,2*o) def i = o + iota{n}%o @@ -556,9 +563,9 @@ def proc_mod_dat{swap_data:W} = { } def {partwidth, partperm} = { def sh{w, d} = tup{w, get_shiftperm{d, w}} - if (ww==64) sh{64, swap_data} - else if (not hasarch{'SSSE3'}) sh{32, shuf{[4]u32, swap_data, 4b0000}} - else tup{8, get_byteperm{}} + if (ww==64) sh{64, swap_data} + else if (not any_sel) sh{32, shuf{[4]u32, swap_data, 4b0000}} + else tup{8, get_byteperm{}} } # Fill in higher steps def get_mod_permuter{} = { @@ -574,7 +581,7 @@ def proc_mod_dat{swap_data:W} = { def rep_const_bool_odd{k, x, r, nw} = { def avx2 = hasarch{'AVX2'} - def W = if (hasarch{'SSE2'}) [if (avx2) 4 else 2]u64 else u64 + def W = if (has_simd) [if (avx2) 4 else 2]u64 else u64 def {output, check_done, flush} = get_boolvec_writer{W, r, nw} xp := *W~~x @@ -586,7 +593,7 @@ def rep_const_bool_odd{k, x, r, nw} = { swap_data := load{swtab, k>>1} def {partperm, get_full_permute} = proc_mod_dat{swap_data} - def sp_max = if (hasarch{'SSSE3'}) 8 else 4 + def sp_max = if (any_sel) 8 else 4 if (k < sp_max) { rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output} } else { @@ -639,7 +646,7 @@ def rep_const_bool_odd_mask4{ mr := scal{u64~~1< { - ca := if (hasarch{'SSE4.2'}) { def S = [l]i64; S~~c > S**0 } + ca := if (hasarch{'SSE4.2'} or hasarch{'AARCH64'}) { def S = [l]i64; S~~c > S**0 } else { def S = [2*l]i32; cm := S~~c; cm != shuf{S, cm, 4b2301} } a + M~~ca } @@ -648,8 +655,11 @@ def rep_const_bool_odd_mask4{ while (1) { x:M = get_modperm_x{} - def vrot1 = ifvec{{x} => if (w128{M}) shuf{[4]u32, x, 4b1032} - else shuf{M, x, 4b2103}} + def vrot1 = ifvec{{x} => { + if (w256{M}) shuf{M, x, 4b2103} + else if (any_sel) vshl{x, x, vcount{type{x}}-1} + else shuf{[4]u32, x, 4b1032} + }} xo := x<>(64-k)} # Write result word given starting bits def step{b, c} = output{sub_carry{c - b, c & mr}} @@ -689,7 +699,7 @@ def rep_const_bool_small_odd{(u64), 4, wv, get_swap_x, output} = { def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = { def ov_bytes{o} = { def V = re_el{u8, W}; v := V~~o; v += v > V**0; W~~v } def ww = width{W} - def ew = if (hasarch{'SSSE3'}) 8 else 32 # width of shuffle-able elements + def ew = if (any_sel) 8 else 32 # width of shuffle-able elements def ne = ww/ew; def se = 64/ew def lanes = ww > 128 def dup{v} = if (lanes) merge{v,v} else v