From 1a4cada0cbeeeee15021ab27ad2b5c4e053e1ccf Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Wed, 7 Aug 2024 10:35:21 -0400 Subject: [PATCH] AVX2 support in rep_const_bool_ssse3_div8 --- src/singeli/src/replicate.singeli | 46 ++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/src/singeli/src/replicate.singeli b/src/singeli/src/replicate.singeli index 5a29cf96..3b4aa0f7 100644 --- a/src/singeli/src/replicate.singeli +++ b/src/singeli/src/replicate.singeli @@ -352,13 +352,25 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = { 1 } +# Generalized flat transpose of iota{1<floor{a/b}) infix left 40 - def bv{bs} = fold{flat_table{|,...}, reverse{each{tup{0,.}, 1<>4)}, .} + def getr = zip128{exh{xv}, exh{V~~(re_el{u16,V}~~xv>>4)}, .} store{rv, j, V~~getr{0}}; ++j; if (j==nv) goto{end} store{rv, j, V~~getr{1}}; ++j } setlabel{end} } if (wv == 2) { + def init = if (avx2) shuf{[4]u64, ., 4b3120} else id # Expander for half byte - tabr := mkV{bv{2*iota{4}} * 2b11} + def tabr = makeTab{tr_iota{2*iota{4}} * 2b11} m4 := V**0xf - def exh{x} = selV{tabr, x & m4} - run24{*V~~x, {xv}=>xv, exh} + run24{*V~~x, init, {x} => tabr{x & m4}} } else if (wv == 4) { - # Double each byte - dup := mkV{iV//2} + # Unzip 32-bit elements (result lanes) across AVX2 lanes + def pre = if (avx2) sel{[8]u32, ., make{[8]u32,tr_iota{1,2,0}}} else id + def init{xv} = { u:=pre{xv}; zip128{u,u,0} } # Expander for two bits in either bottom or next-to-bottom position - tabr := mkV{bv{tup{0,4,0,4}} * 2b1111} + def tabr = makeTab{tr_iota{0,4,0,4} * 2b1111} m2 := mkV{2b11 << (2*(iV%2))} - def exh{x} = re_el{u16, V}~~selV{tabr, x & m2} - run24{*u64~~x, selV{., dup}, exh} + def exh{x} = re_el{u16, V}~~tabr{x & m2} + run24{*(if (avx2) [2]u64 else u64)~~x, init, exh} } else { # wv == 8 @for (r in *V~~r over i to nv) { - xv := load{*V~~(*u16~~x + i)} - xe := selV{xv, mkV{iV//8}} + xh := load{*[16]u8~~(*ty_u{vl}~~x + i)} + xv := if (avx2) pair{xh, xh} else xh + xe := selH{xv, mkV{iV // 8}} r = (xe & mkV{1 << (iV % 8)}) > V**0 } } } + # For odd numbers: # - permute each byte sending bit i to position k*i % 8 # - replicate each byte by k, making position k*i contain bit i @@ -401,7 +416,6 @@ def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8 # - ...except where it crosses words; handle this overhang separately def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15 oper // ({a,b}=>floor{a/b}) infix left 40 - def bv{bs} = fold{flat_table{|,...}, reverse{each{tup{0,.}, 1< mkV{bv{is}}, split{4, k*iota{8} % 8}} + def get_ttab{k} = each{{is} => mkV{tr_iota{is}}, split{4, k*iota{8} % 8}} ttab:*V = join{each{get_ttab, 2*iota{4} + 1}} {t0, t4} := each{load{ttab + (wv & 6), .}, iota{2}} m4 := V**0xf