SSE2 k/bool where k divides 8 implementations

This commit is contained in:
Marshall Lochbaum 2024-08-13 15:59:48 -04:00
parent ee27717c97
commit 048529740b

View File

@ -300,7 +300,7 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
} else {
tlen := rlen>>p
wq := usz~~1<<p
if ((not hasarch{'SSSE3'} and (p == 1 or (p == 2 and wv>=52))) or (p == 1 and wv>=24)) {
if (p == 1 and (not hasarch{'SSSE3'} or wv>=24)) {
# Expanding odd second is faster
tlen = rlen / wf
t:=wf; wf=wq; wq=t
@ -378,49 +378,81 @@ def get_boolvec_writer{T=(u64), r:*T, nw} = {
tup{output, {}=>{}, flush}
}
def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSSE3'}} = {
def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSE2'}} = {
def avx2 = hasarch{'AVX2'}
def vl = if (avx2) 32 else 16
def V = [vl]u8
def iV = iota{vl}
def mkV = make{V, .}
def selH = sel{[16]u8, ., .}
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
def id{xv} = xv
def {output, done, flush} = get_boolvec_writer{V, r, nw}
# Not the same as a 1-byte shift but extra bits are always masked away
def H = re_el{u16, V}
def __shl{x:(V), a} = V~~(H~~x << a)
def __shr{x:(V), a} = V~~(H~~x >> a)
def run24{x, proc_xv, exh} = {
def {output, done, flush} = get_boolvec_writer{V, r, nw}
def run24{x, get_halves} = {
i:usz = 0; while (1) {
xv := proc_xv{load{*V~~(x+i)}}; ++i
# Store 1 or 2 result vectors
def getr = zip128{exh{xv}, exh{V~~(re_el{u16,V}~~xv>>4)}, .}
xv := load{*V~~(x+i)}; ++i
def getr = zip128{...get_halves{xv}, .}
output{V~~getr{0}}
output{V~~getr{1}}
}
}
if (wv == 2) {
def init = if (avx2) shuf{[4]u64, ., 4b3120} else id
# Expander for half byte
def tabr = makeTab{tr_iota{2*iota{4}} * 2b11}
m4 := V**0xf
run24{*V~~x, init, {x} => tabr{x & m4}}
} else if (wv == 4) {
# Unzip 32-bit elements (result lanes) across AVX2 lanes
def pre = if (avx2) sel{[8]u32, ., make{[8]u32,tr_iota{1,2,0}}} else id
def init{xv} = { u:=pre{xv}; zip128{u,u,0} }
# Expander for two bits in either bottom or next-to-bottom position
def tabr = makeTab{tr_iota{0,4,0,4} * 2b1111}
m2 := mkV{2b11 << (2*(iV%2))}
def exh{x} = re_el{u16, V}~~tabr{x & m2}
run24{*(if (avx2) [2]u64 else u64)~~x, init, exh}
} else { # wv == 8
def run24{x, pre, exh, sh} = run24{x,
{xv} => { p := pre{xv}; tup{exh{p}, exh{p>>sh}} }
}
def run8{rep_bytes} = {
i:usz = 0; while (1) {
xh := load{*[16]u8~~(*ty_u{vl}~~x + i)}; ++i
xv := if (avx2) pair{xh, xh} else xh
xe := selH{xv, mkV{iV // 8}}
xe := rep_bytes{xv}
output{(xe & mkV{1 << (iV % 8)}) > V**0}
}
}
if (not hasarch{'SSSE3'}) {
if (wv == 2) {
run24{*V~~x, {xv} => {
def swap{x, l, m} = { def d = (x ^ x>>l) & V**m; x ^ (d|d<<l) }
a := swap{swap{xv, 2, 0x0c}, 1, 0x22}
m := V**0x55
h0 := a & m; o0 := h0 | h0<<1
h1 := a &~ m; o1 := h1 | h1>>1
tup{o0, o1}
}}
} else if (wv == 4) {
def pre{xv} = {
a := xv ^ ((xv & V**0x55) << 1)
e := zip128{a, a>>4, 0}
}
def out{h} = (-(h & V**1)) ^ (-(h<<3 & V**0x10))
run24{*u64~~x, pre, out, 2}
} else { # wv == 8
def z{x} = zip{x,x,0}
run8{{x} => z{z{z{x}}}}
}
} else { # hasarch{'SSSE3'}
def selH = sel{[16]u8, ., .}
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
def id{xv} = xv
if (wv == 2) {
def init = if (avx2) shuf{[4]u64, ., 4b3120} else id
# Expander for half byte
def tabr = makeTab{tr_iota{2*iota{4}} * 2b11}
m4 := V**0xf
run24{*V~~x, init, {x} => tabr{x & m4}, 4}
} else if (wv == 4) {
# Unzip 32-bit elements (result lanes) across AVX2 lanes
def pre = if (avx2) sel{[8]u32, ., make{[8]u32,tr_iota{1,2,0}}} else id
def init{xv} = { u:=pre{xv}; zip128{u,u,0} }
# Expander for two bits in either bottom or next-to-bottom position
def tabr = makeTab{tr_iota{0,4,0,4} * 2b1111}
m2 := mkV{2b11 << (2*(iV%2))}
def exh{x} = re_el{u16, V}~~tabr{x & m2}
run24{*(if (avx2) [2]u64 else u64)~~x, init, exh, 4}
} else { # wv == 8
run8{selH{., mkV{iV // 8}}}
}
}
flush{}
}