Mask off last k/bool vector properly
This commit is contained in:
parent
1a4cada0cb
commit
696a23af9e
@ -360,6 +360,26 @@ def tr_iota{...bs} = {
|
||||
}
|
||||
def tr_iota{{...bs}} = tr_iota{...bs}
|
||||
|
||||
def get_boolvec_writer{V, r, rlen} = {
|
||||
def vwords = width{V}/64
|
||||
nw := cdiv{rlen, 64}
|
||||
rv := *V~~r
|
||||
re := rv + nw / vwords
|
||||
last_res:V = V**0
|
||||
def end = makelabel{}
|
||||
def output{v:(V)} = {
|
||||
last_res = v
|
||||
if (rv==re) goto{end}
|
||||
store{rv, 0, v}; ++rv
|
||||
}
|
||||
def flush{} = {
|
||||
setlabel{end}
|
||||
q := nw & (vwords-1)
|
||||
if (q != 0) homMaskStoreF{rv, V~~maskOf{re_el{u64,V}, q}, last_res}
|
||||
}
|
||||
tup{output, flush}
|
||||
}
|
||||
|
||||
def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
|
||||
oper // ({a,b}=>floor{a/b}) infix left 40
|
||||
def avx2 = hasarch{'AVX2'}
|
||||
@ -369,20 +389,17 @@ def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
|
||||
def mkV = make{V, .}
|
||||
def selH = sel{[16]u8, ., .}
|
||||
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
|
||||
nv := cdiv{rlen, width{V}}
|
||||
def id{xv} = xv
|
||||
def {output, flush} = get_boolvec_writer{V, r, rlen}
|
||||
|
||||
def run24{x, proc_xv, exh} = {
|
||||
i:usz = 0; j:usz = 0
|
||||
rv := *V~~r
|
||||
def end = makelabel{}; while (j < nv) {
|
||||
i:usz = 0; while (1) {
|
||||
xv := proc_xv{load{*V~~(x+i)}}; ++i
|
||||
# Store 1 or 2 result vectors
|
||||
def getr = zip128{exh{xv}, exh{V~~(re_el{u16,V}~~xv>>4)}, .}
|
||||
store{rv, j, V~~getr{0}}; ++j; if (j==nv) goto{end}
|
||||
store{rv, j, V~~getr{1}}; ++j
|
||||
output{V~~getr{0}}
|
||||
output{V~~getr{1}}
|
||||
}
|
||||
setlabel{end}
|
||||
}
|
||||
if (wv == 2) {
|
||||
def init = if (avx2) shuf{[4]u64, ., 4b3120} else id
|
||||
@ -400,13 +417,14 @@ def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
|
||||
def exh{x} = re_el{u16, V}~~tabr{x & m2}
|
||||
run24{*(if (avx2) [2]u64 else u64)~~x, init, exh}
|
||||
} else { # wv == 8
|
||||
@for (r in *V~~r over i to nv) {
|
||||
xh := load{*[16]u8~~(*ty_u{vl}~~x + i)}
|
||||
i:usz = 0; while (1) {
|
||||
xh := load{*[16]u8~~(*ty_u{vl}~~x + i)}; ++i
|
||||
xv := if (avx2) pair{xh, xh} else xh
|
||||
xe := selH{xv, mkV{iV // 8}}
|
||||
r = (xe & mkV{1 << (iV % 8)}) > V**0
|
||||
output{(xe & mkV{1 << (iV % 8)}) > V**0}
|
||||
}
|
||||
}
|
||||
flush{}
|
||||
}
|
||||
|
||||
# For odd numbers:
|
||||
@ -420,8 +438,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
def iV = iota{vl}
|
||||
def mkV = make{V, .}; def selV = sel{V, ., .}
|
||||
def W = [2]u64
|
||||
|
||||
nv := cdiv{rlen, width{V}}
|
||||
def {output, flush} = get_boolvec_writer{V, r, rlen}
|
||||
|
||||
# Within-byte transformation
|
||||
def get_ttab{k} = each{{is} => mkV{tr_iota{is}}, split{4, k*iota{8} % 8}}
|
||||
@ -436,9 +453,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
# Cases are 3; 5 7; 9 11 13 15
|
||||
if (wv < 4) {
|
||||
# 3: dedicated loop
|
||||
i:usz = 0; j:usz = 0
|
||||
rv := *V~~r
|
||||
def end = makelabel{}; while (j < nv) {
|
||||
i:usz = 0; while (1) {
|
||||
# 01234567 to 05316427 on each byte
|
||||
xv := get_perm_x{i}; ++i
|
||||
# Overhang from previous 64-bit elements
|
||||
@ -453,7 +468,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
b := W~~(selV{xv, ind} & mask)
|
||||
r := V~~((b<<3) - b)
|
||||
o := selV{xo, mkV{flat_table{max, oi, 255*(0<iota{8})}}}
|
||||
store{rv, j, r|o}; ++j; if (jj<2 and j==nv) goto{end}
|
||||
output{r|o}
|
||||
}
|
||||
def make3V{vs} = each{make{V,.}, split{vl, vs}}
|
||||
each{step,
|
||||
@ -462,7 +477,6 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
make3V{8w2b001 << ((-8)*iota{3*vl} % 3)}
|
||||
}
|
||||
}
|
||||
setlabel{end}
|
||||
} else if (wv < 8) {
|
||||
# 5, 7: precompute constants, then shared loop
|
||||
{xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{
|
||||
@ -486,7 +500,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
if (wv == 5) set_consts{5} else set_consts{7}
|
||||
xv:V = V**0; xo:=xv; ind:=xv; mask:=xv # state
|
||||
i:usz = 0; q:usz = 1
|
||||
@for (r in *V~~r over j to nv) {
|
||||
while (1) {
|
||||
--q; if (q == 0) { q = wv
|
||||
# Load and permute bytes
|
||||
xv = get_perm_x{i}; ++i
|
||||
@ -504,8 +518,9 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
mask = V~~((W~~mask << (wv - mask_sh)) | (W~~mask >> mask_sh))
|
||||
}
|
||||
b := W~~(selV{xv, ind} & mask)
|
||||
r = V~~((b<<wv) - b)
|
||||
r |= xo & mkV{255 * (iV%8 == 0)} # overhang
|
||||
rv:= V~~((b<<wv) - b)
|
||||
o := xo & mkV{255 * (iV%8 == 0)} # overhang
|
||||
output{rv | o}
|
||||
}
|
||||
} else { # wv < 16
|
||||
# 9, 11, 13, 15: shared constant computation and loop
|
||||
@ -521,7 +536,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
# State
|
||||
xv:V = V**0; o:=W**0; ind:=xv; mask:=xv
|
||||
i:usz = 0; q:usz = 1
|
||||
@for (r in *V~~r over j to nv) {
|
||||
while (1) {
|
||||
--q; if (q == 0) { q = wv
|
||||
xv = get_perm_x{i}; ++i
|
||||
ind = ind0
|
||||
@ -532,11 +547,12 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
}
|
||||
# Handle overhang here; won't fit in a single vector
|
||||
b := W~~(selV{xv, ind} & mask)
|
||||
r = V~~((b<<wv) - b)
|
||||
rv:= V~~((b<<wv) - b)
|
||||
po:= o; o = b>>(64-wv)
|
||||
ro:= [4]u32~~vshl{po, o, 1}
|
||||
r |= V~~(ro + (ro > [4]u32**0))
|
||||
output{rv | V~~(ro + (ro > [4]u32**0))}
|
||||
}
|
||||
}
|
||||
flush{}
|
||||
}
|
||||
export{'si_constrep_bool', rep_const_bool{}}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user