Mask off last k/bool vector properly

This commit is contained in:
Marshall Lochbaum 2024-08-07 17:38:57 -04:00
parent 1a4cada0cb
commit 696a23af9e

View File

@ -360,6 +360,26 @@ def tr_iota{...bs} = {
} }
def tr_iota{{...bs}} = tr_iota{...bs} def tr_iota{{...bs}} = tr_iota{...bs}
def get_boolvec_writer{V, r, rlen} = {
def vwords = width{V}/64
nw := cdiv{rlen, 64}
rv := *V~~r
re := rv + nw / vwords
last_res:V = V**0
def end = makelabel{}
def output{v:(V)} = {
last_res = v
if (rv==re) goto{end}
store{rv, 0, v}; ++rv
}
def flush{} = {
setlabel{end}
q := nw & (vwords-1)
if (q != 0) homMaskStoreF{rv, V~~maskOf{re_el{u64,V}, q}, last_res}
}
tup{output, flush}
}
def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8 def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
oper // ({a,b}=>floor{a/b}) infix left 40 oper // ({a,b}=>floor{a/b}) infix left 40
def avx2 = hasarch{'AVX2'} def avx2 = hasarch{'AVX2'}
@ -369,20 +389,17 @@ def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
def mkV = make{V, .} def mkV = make{V, .}
def selH = sel{[16]u8, ., .} def selH = sel{[16]u8, ., .}
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .} def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
nv := cdiv{rlen, width{V}}
def id{xv} = xv def id{xv} = xv
def {output, flush} = get_boolvec_writer{V, r, rlen}
def run24{x, proc_xv, exh} = { def run24{x, proc_xv, exh} = {
i:usz = 0; j:usz = 0 i:usz = 0; while (1) {
rv := *V~~r
def end = makelabel{}; while (j < nv) {
xv := proc_xv{load{*V~~(x+i)}}; ++i xv := proc_xv{load{*V~~(x+i)}}; ++i
# Store 1 or 2 result vectors # Store 1 or 2 result vectors
def getr = zip128{exh{xv}, exh{V~~(re_el{u16,V}~~xv>>4)}, .} def getr = zip128{exh{xv}, exh{V~~(re_el{u16,V}~~xv>>4)}, .}
store{rv, j, V~~getr{0}}; ++j; if (j==nv) goto{end} output{V~~getr{0}}
store{rv, j, V~~getr{1}}; ++j output{V~~getr{1}}
} }
setlabel{end}
} }
if (wv == 2) { if (wv == 2) {
def init = if (avx2) shuf{[4]u64, ., 4b3120} else id def init = if (avx2) shuf{[4]u64, ., 4b3120} else id
@ -400,13 +417,14 @@ def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
def exh{x} = re_el{u16, V}~~tabr{x & m2} def exh{x} = re_el{u16, V}~~tabr{x & m2}
run24{*(if (avx2) [2]u64 else u64)~~x, init, exh} run24{*(if (avx2) [2]u64 else u64)~~x, init, exh}
} else { # wv == 8 } else { # wv == 8
@for (r in *V~~r over i to nv) { i:usz = 0; while (1) {
xh := load{*[16]u8~~(*ty_u{vl}~~x + i)} xh := load{*[16]u8~~(*ty_u{vl}~~x + i)}; ++i
xv := if (avx2) pair{xh, xh} else xh xv := if (avx2) pair{xh, xh} else xh
xe := selH{xv, mkV{iV // 8}} xe := selH{xv, mkV{iV // 8}}
r = (xe & mkV{1 << (iV % 8)}) > V**0 output{(xe & mkV{1 << (iV % 8)}) > V**0}
} }
} }
flush{}
} }
# For odd numbers: # For odd numbers:
@ -420,8 +438,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
def iV = iota{vl} def iV = iota{vl}
def mkV = make{V, .}; def selV = sel{V, ., .} def mkV = make{V, .}; def selV = sel{V, ., .}
def W = [2]u64 def W = [2]u64
def {output, flush} = get_boolvec_writer{V, r, rlen}
nv := cdiv{rlen, width{V}}
# Within-byte transformation # Within-byte transformation
def get_ttab{k} = each{{is} => mkV{tr_iota{is}}, split{4, k*iota{8} % 8}} def get_ttab{k} = each{{is} => mkV{tr_iota{is}}, split{4, k*iota{8} % 8}}
@ -436,9 +453,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
# Cases are 3; 5 7; 9 11 13 15 # Cases are 3; 5 7; 9 11 13 15
if (wv < 4) { if (wv < 4) {
# 3: dedicated loop # 3: dedicated loop
i:usz = 0; j:usz = 0 i:usz = 0; while (1) {
rv := *V~~r
def end = makelabel{}; while (j < nv) {
# 01234567 to 05316427 on each byte # 01234567 to 05316427 on each byte
xv := get_perm_x{i}; ++i xv := get_perm_x{i}; ++i
# Overhang from previous 64-bit elements # Overhang from previous 64-bit elements
@ -453,7 +468,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
b := W~~(selV{xv, ind} & mask) b := W~~(selV{xv, ind} & mask)
r := V~~((b<<3) - b) r := V~~((b<<3) - b)
o := selV{xo, mkV{flat_table{max, oi, 255*(0<iota{8})}}} o := selV{xo, mkV{flat_table{max, oi, 255*(0<iota{8})}}}
store{rv, j, r|o}; ++j; if (jj<2 and j==nv) goto{end} output{r|o}
} }
def make3V{vs} = each{make{V,.}, split{vl, vs}} def make3V{vs} = each{make{V,.}, split{vl, vs}}
each{step, each{step,
@ -462,7 +477,6 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
make3V{8w2b001 << ((-8)*iota{3*vl} % 3)} make3V{8w2b001 << ((-8)*iota{3*vl} % 3)}
} }
} }
setlabel{end}
} else if (wv < 8) { } else if (wv < 8) {
# 5, 7: precompute constants, then shared loop # 5, 7: precompute constants, then shared loop
{xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{ {xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{
@ -486,7 +500,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
if (wv == 5) set_consts{5} else set_consts{7} if (wv == 5) set_consts{5} else set_consts{7}
xv:V = V**0; xo:=xv; ind:=xv; mask:=xv # state xv:V = V**0; xo:=xv; ind:=xv; mask:=xv # state
i:usz = 0; q:usz = 1 i:usz = 0; q:usz = 1
@for (r in *V~~r over j to nv) { while (1) {
--q; if (q == 0) { q = wv --q; if (q == 0) { q = wv
# Load and permute bytes # Load and permute bytes
xv = get_perm_x{i}; ++i xv = get_perm_x{i}; ++i
@ -504,8 +518,9 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
mask = V~~((W~~mask << (wv - mask_sh)) | (W~~mask >> mask_sh)) mask = V~~((W~~mask << (wv - mask_sh)) | (W~~mask >> mask_sh))
} }
b := W~~(selV{xv, ind} & mask) b := W~~(selV{xv, ind} & mask)
r = V~~((b<<wv) - b) rv:= V~~((b<<wv) - b)
r |= xo & mkV{255 * (iV%8 == 0)} # overhang o := xo & mkV{255 * (iV%8 == 0)} # overhang
output{rv | o}
} }
} else { # wv < 16 } else { # wv < 16
# 9, 11, 13, 15: shared constant computation and loop # 9, 11, 13, 15: shared constant computation and loop
@ -521,7 +536,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
# State # State
xv:V = V**0; o:=W**0; ind:=xv; mask:=xv xv:V = V**0; o:=W**0; ind:=xv; mask:=xv
i:usz = 0; q:usz = 1 i:usz = 0; q:usz = 1
@for (r in *V~~r over j to nv) { while (1) {
--q; if (q == 0) { q = wv --q; if (q == 0) { q = wv
xv = get_perm_x{i}; ++i xv = get_perm_x{i}; ++i
ind = ind0 ind = ind0
@ -532,11 +547,12 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
} }
# Handle overhang here; won't fit in a single vector # Handle overhang here; won't fit in a single vector
b := W~~(selV{xv, ind} & mask) b := W~~(selV{xv, ind} & mask)
r = V~~((b<<wv) - b) rv:= V~~((b<<wv) - b)
po:= o; o = b>>(64-wv) po:= o; o = b>>(64-wv)
ro:= [4]u32~~vshl{po, o, 1} ro:= [4]u32~~vshl{po, o, 1}
r |= V~~(ro + (ro > [4]u32**0)) output{rv | V~~(ro + (ro > [4]u32**0))}
} }
} }
flush{}
} }
export{'si_constrep_bool', rep_const_bool{}} export{'si_constrep_bool', rep_const_bool{}}