AVX2 support in rep_const_bool_ssse3_div8
This commit is contained in:
parent
b4d84041bc
commit
1a4cada0cb
@ -352,13 +352,25 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
||||
1
|
||||
}
|
||||
|
||||
# Generalized flat transpose of iota{1<<length{bs}}
|
||||
# select{tr_iota{bs}, x} sends bit i of x to position select{bs, i}
|
||||
def tr_iota{...bs} = {
|
||||
def axes = each{tup{0,.}, 1<<bs}
|
||||
fold{flat_table{|,...}, reverse{axes}}
|
||||
}
|
||||
def tr_iota{{...bs}} = tr_iota{...bs}
|
||||
|
||||
def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
|
||||
oper // ({a,b}=>floor{a/b}) infix left 40
|
||||
def bv{bs} = fold{flat_table{|,...}, reverse{each{tup{0,.}, 1<<bs}}}
|
||||
def vl = 16; def V = [vl]u8
|
||||
def avx2 = hasarch{'AVX2'}
|
||||
def vl = if (avx2) 32 else 16
|
||||
def V = [vl]u8
|
||||
def iV = iota{vl}
|
||||
def mkV = make{V, .}; def selV = sel{V, ., .}
|
||||
def mkV = make{V, .}
|
||||
def selH = sel{[16]u8, ., .}
|
||||
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
|
||||
nv := cdiv{rlen, width{V}}
|
||||
def id{xv} = xv
|
||||
|
||||
def run24{x, proc_xv, exh} = {
|
||||
i:usz = 0; j:usz = 0
|
||||
@ -366,34 +378,37 @@ def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
|
||||
def end = makelabel{}; while (j < nv) {
|
||||
xv := proc_xv{load{*V~~(x+i)}}; ++i
|
||||
# Store 1 or 2 result vectors
|
||||
def getr = zip{exh{xv}, exh{V~~([8]u16~~xv>>4)}, .}
|
||||
def getr = zip128{exh{xv}, exh{V~~(re_el{u16,V}~~xv>>4)}, .}
|
||||
store{rv, j, V~~getr{0}}; ++j; if (j==nv) goto{end}
|
||||
store{rv, j, V~~getr{1}}; ++j
|
||||
}
|
||||
setlabel{end}
|
||||
}
|
||||
if (wv == 2) {
|
||||
def init = if (avx2) shuf{[4]u64, ., 4b3120} else id
|
||||
# Expander for half byte
|
||||
tabr := mkV{bv{2*iota{4}} * 2b11}
|
||||
def tabr = makeTab{tr_iota{2*iota{4}} * 2b11}
|
||||
m4 := V**0xf
|
||||
def exh{x} = selV{tabr, x & m4}
|
||||
run24{*V~~x, {xv}=>xv, exh}
|
||||
run24{*V~~x, init, {x} => tabr{x & m4}}
|
||||
} else if (wv == 4) {
|
||||
# Double each byte
|
||||
dup := mkV{iV//2}
|
||||
# Unzip 32-bit elements (result lanes) across AVX2 lanes
|
||||
def pre = if (avx2) sel{[8]u32, ., make{[8]u32,tr_iota{1,2,0}}} else id
|
||||
def init{xv} = { u:=pre{xv}; zip128{u,u,0} }
|
||||
# Expander for two bits in either bottom or next-to-bottom position
|
||||
tabr := mkV{bv{tup{0,4,0,4}} * 2b1111}
|
||||
def tabr = makeTab{tr_iota{0,4,0,4} * 2b1111}
|
||||
m2 := mkV{2b11 << (2*(iV%2))}
|
||||
def exh{x} = re_el{u16, V}~~selV{tabr, x & m2}
|
||||
run24{*u64~~x, selV{., dup}, exh}
|
||||
def exh{x} = re_el{u16, V}~~tabr{x & m2}
|
||||
run24{*(if (avx2) [2]u64 else u64)~~x, init, exh}
|
||||
} else { # wv == 8
|
||||
@for (r in *V~~r over i to nv) {
|
||||
xv := load{*V~~(*u16~~x + i)}
|
||||
xe := selV{xv, mkV{iV//8}}
|
||||
xh := load{*[16]u8~~(*ty_u{vl}~~x + i)}
|
||||
xv := if (avx2) pair{xh, xh} else xh
|
||||
xe := selH{xv, mkV{iV // 8}}
|
||||
r = (xe & mkV{1 << (iV % 8)}) > V**0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# For odd numbers:
|
||||
# - permute each byte sending bit i to position k*i % 8
|
||||
# - replicate each byte by k, making position k*i contain bit i
|
||||
@ -401,7 +416,6 @@ def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
|
||||
# - ...except where it crosses words; handle this overhang separately
|
||||
def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
oper // ({a,b}=>floor{a/b}) infix left 40
|
||||
def bv{bs} = fold{flat_table{|,...}, reverse{each{tup{0,.}, 1<<bs}}}
|
||||
def vl = 16; def V = [vl]u8
|
||||
def iV = iota{vl}
|
||||
def mkV = make{V, .}; def selV = sel{V, ., .}
|
||||
@ -410,7 +424,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
nv := cdiv{rlen, width{V}}
|
||||
|
||||
# Within-byte transformation
|
||||
def get_ttab{k} = each{{is} => mkV{bv{is}}, split{4, k*iota{8} % 8}}
|
||||
def get_ttab{k} = each{{is} => mkV{tr_iota{is}}, split{4, k*iota{8} % 8}}
|
||||
ttab:*V = join{each{get_ttab, 2*iota{4} + 1}}
|
||||
{t0, t4} := each{load{ttab + (wv & 6), .}, iota{2}}
|
||||
m4 := V**0xf
|
||||
|
||||
Loading…
Reference in New Issue
Block a user