AVX2 support in rep_const_bool_ssse3_div8

This commit is contained in:
Marshall Lochbaum 2024-08-07 10:35:21 -04:00
parent b4d84041bc
commit 1a4cada0cb

View File

@ -352,13 +352,25 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
1
}
# Generalized flat transpose of iota{1<<length{bs}}
# select{tr_iota{bs}, x} sends bit i of x to position select{bs, i}
def tr_iota{...bs} = {
def axes = each{tup{0,.}, 1<<bs}
fold{flat_table{|,...}, reverse{axes}}
}
def tr_iota{{...bs}} = tr_iota{...bs}
def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
oper // ({a,b}=>floor{a/b}) infix left 40
def bv{bs} = fold{flat_table{|,...}, reverse{each{tup{0,.}, 1<<bs}}}
def vl = 16; def V = [vl]u8
def avx2 = hasarch{'AVX2'}
def vl = if (avx2) 32 else 16
def V = [vl]u8
def iV = iota{vl}
def mkV = make{V, .}; def selV = sel{V, ., .}
def mkV = make{V, .}
def selH = sel{[16]u8, ., .}
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
nv := cdiv{rlen, width{V}}
def id{xv} = xv
def run24{x, proc_xv, exh} = {
i:usz = 0; j:usz = 0
@ -366,34 +378,37 @@ def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
def end = makelabel{}; while (j < nv) {
xv := proc_xv{load{*V~~(x+i)}}; ++i
# Store 1 or 2 result vectors
def getr = zip{exh{xv}, exh{V~~([8]u16~~xv>>4)}, .}
def getr = zip128{exh{xv}, exh{V~~(re_el{u16,V}~~xv>>4)}, .}
store{rv, j, V~~getr{0}}; ++j; if (j==nv) goto{end}
store{rv, j, V~~getr{1}}; ++j
}
setlabel{end}
}
if (wv == 2) {
def init = if (avx2) shuf{[4]u64, ., 4b3120} else id
# Expander for half byte
tabr := mkV{bv{2*iota{4}} * 2b11}
def tabr = makeTab{tr_iota{2*iota{4}} * 2b11}
m4 := V**0xf
def exh{x} = selV{tabr, x & m4}
run24{*V~~x, {xv}=>xv, exh}
run24{*V~~x, init, {x} => tabr{x & m4}}
} else if (wv == 4) {
# Double each byte
dup := mkV{iV//2}
# Unzip 32-bit elements (result lanes) across AVX2 lanes
def pre = if (avx2) sel{[8]u32, ., make{[8]u32,tr_iota{1,2,0}}} else id
def init{xv} = { u:=pre{xv}; zip128{u,u,0} }
# Expander for two bits in either bottom or next-to-bottom position
tabr := mkV{bv{tup{0,4,0,4}} * 2b1111}
def tabr = makeTab{tr_iota{0,4,0,4} * 2b1111}
m2 := mkV{2b11 << (2*(iV%2))}
def exh{x} = re_el{u16, V}~~selV{tabr, x & m2}
run24{*u64~~x, selV{., dup}, exh}
def exh{x} = re_el{u16, V}~~tabr{x & m2}
run24{*(if (avx2) [2]u64 else u64)~~x, init, exh}
} else { # wv == 8
@for (r in *V~~r over i to nv) {
xv := load{*V~~(*u16~~x + i)}
xe := selV{xv, mkV{iV//8}}
xh := load{*[16]u8~~(*ty_u{vl}~~x + i)}
xv := if (avx2) pair{xh, xh} else xh
xe := selH{xv, mkV{iV // 8}}
r = (xe & mkV{1 << (iV % 8)}) > V**0
}
}
}
# For odd numbers:
# - permute each byte sending bit i to position k*i % 8
# - replicate each byte by k, making position k*i contain bit i
@ -401,7 +416,6 @@ def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
# - ...except where it crosses words; handle this overhang separately
def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
oper // ({a,b}=>floor{a/b}) infix left 40
def bv{bs} = fold{flat_table{|,...}, reverse{each{tup{0,.}, 1<<bs}}}
def vl = 16; def V = [vl]u8
def iV = iota{vl}
def mkV = make{V, .}; def selV = sel{V, ., .}
@ -410,7 +424,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
nv := cdiv{rlen, width{V}}
# Within-byte transformation
def get_ttab{k} = each{{is} => mkV{bv{is}}, split{4, k*iota{8} % 8}}
def get_ttab{k} = each{{is} => mkV{tr_iota{is}}, split{4, k*iota{8} % 8}}
ttab:*V = join{each{get_ttab, 2*iota{4} + 1}}
{t0, t4} := each{load{ttab + (wv & 6), .}, iota{2}}
m4 := V**0xf