Mask off last k/bool vector properly

This commit is contained in:
Marshall Lochbaum 2024-08-07 17:38:57 -04:00
parent 1a4cada0cb
commit 696a23af9e

View File

@ -360,6 +360,26 @@ def tr_iota{...bs} = {
}
def tr_iota{{...bs}} = tr_iota{...bs}
def get_boolvec_writer{V, r, rlen} = {
def vwords = width{V}/64
nw := cdiv{rlen, 64}
rv := *V~~r
re := rv + nw / vwords
last_res:V = V**0
def end = makelabel{}
def output{v:(V)} = {
last_res = v
if (rv==re) goto{end}
store{rv, 0, v}; ++rv
}
def flush{} = {
setlabel{end}
q := nw & (vwords-1)
if (q != 0) homMaskStoreF{rv, V~~maskOf{re_el{u64,V}, q}, last_res}
}
tup{output, flush}
}
def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
oper // ({a,b}=>floor{a/b}) infix left 40
def avx2 = hasarch{'AVX2'}
@ -369,20 +389,17 @@ def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
def mkV = make{V, .}
def selH = sel{[16]u8, ., .}
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
nv := cdiv{rlen, width{V}}
def id{xv} = xv
def {output, flush} = get_boolvec_writer{V, r, rlen}
def run24{x, proc_xv, exh} = {
i:usz = 0; j:usz = 0
rv := *V~~r
def end = makelabel{}; while (j < nv) {
i:usz = 0; while (1) {
xv := proc_xv{load{*V~~(x+i)}}; ++i
# Store 1 or 2 result vectors
def getr = zip128{exh{xv}, exh{V~~(re_el{u16,V}~~xv>>4)}, .}
store{rv, j, V~~getr{0}}; ++j; if (j==nv) goto{end}
store{rv, j, V~~getr{1}}; ++j
output{V~~getr{0}}
output{V~~getr{1}}
}
setlabel{end}
}
if (wv == 2) {
def init = if (avx2) shuf{[4]u64, ., 4b3120} else id
@ -400,13 +417,14 @@ def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
def exh{x} = re_el{u16, V}~~tabr{x & m2}
run24{*(if (avx2) [2]u64 else u64)~~x, init, exh}
} else { # wv == 8
@for (r in *V~~r over i to nv) {
xh := load{*[16]u8~~(*ty_u{vl}~~x + i)}
i:usz = 0; while (1) {
xh := load{*[16]u8~~(*ty_u{vl}~~x + i)}; ++i
xv := if (avx2) pair{xh, xh} else xh
xe := selH{xv, mkV{iV // 8}}
r = (xe & mkV{1 << (iV % 8)}) > V**0
output{(xe & mkV{1 << (iV % 8)}) > V**0}
}
}
flush{}
}
# For odd numbers:
@ -420,8 +438,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
def iV = iota{vl}
def mkV = make{V, .}; def selV = sel{V, ., .}
def W = [2]u64
nv := cdiv{rlen, width{V}}
def {output, flush} = get_boolvec_writer{V, r, rlen}
# Within-byte transformation
def get_ttab{k} = each{{is} => mkV{tr_iota{is}}, split{4, k*iota{8} % 8}}
@ -436,9 +453,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
# Cases are 3; 5 7; 9 11 13 15
if (wv < 4) {
# 3: dedicated loop
i:usz = 0; j:usz = 0
rv := *V~~r
def end = makelabel{}; while (j < nv) {
i:usz = 0; while (1) {
# 01234567 to 05316427 on each byte
xv := get_perm_x{i}; ++i
# Overhang from previous 64-bit elements
@ -453,7 +468,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
b := W~~(selV{xv, ind} & mask)
r := V~~((b<<3) - b)
o := selV{xo, mkV{flat_table{max, oi, 255*(0<iota{8})}}}
store{rv, j, r|o}; ++j; if (jj<2 and j==nv) goto{end}
output{r|o}
}
def make3V{vs} = each{make{V,.}, split{vl, vs}}
each{step,
@ -462,7 +477,6 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
make3V{8w2b001 << ((-8)*iota{3*vl} % 3)}
}
}
setlabel{end}
} else if (wv < 8) {
# 5, 7: precompute constants, then shared loop
{xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{
@ -486,7 +500,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
if (wv == 5) set_consts{5} else set_consts{7}
xv:V = V**0; xo:=xv; ind:=xv; mask:=xv # state
i:usz = 0; q:usz = 1
@for (r in *V~~r over j to nv) {
while (1) {
--q; if (q == 0) { q = wv
# Load and permute bytes
xv = get_perm_x{i}; ++i
@ -504,8 +518,9 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
mask = V~~((W~~mask << (wv - mask_sh)) | (W~~mask >> mask_sh))
}
b := W~~(selV{xv, ind} & mask)
r = V~~((b<<wv) - b)
r |= xo & mkV{255 * (iV%8 == 0)} # overhang
rv:= V~~((b<<wv) - b)
o := xo & mkV{255 * (iV%8 == 0)} # overhang
output{rv | o}
}
} else { # wv < 16
# 9, 11, 13, 15: shared constant computation and loop
@ -521,7 +536,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
# State
xv:V = V**0; o:=W**0; ind:=xv; mask:=xv
i:usz = 0; q:usz = 1
@for (r in *V~~r over j to nv) {
while (1) {
--q; if (q == 0) { q = wv
xv = get_perm_x{i}; ++i
ind = ind0
@ -532,11 +547,12 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
}
# Handle overhang here; won't fit in a single vector
b := W~~(selV{xv, ind} & mask)
r = V~~((b<<wv) - b)
rv:= V~~((b<<wv) - b)
po:= o; o = b>>(64-wv)
ro:= [4]u32~~vshl{po, o, 1}
r |= V~~(ro + (ro > [4]u32**0))
output{rv | V~~(ro + (ro > [4]u32**0))}
}
}
flush{}
}
export{'si_constrep_bool', rep_const_bool{}}