Move to dedicated k/bool loops for k=5,7 and at least SSSE3

This commit is contained in:
Marshall Lochbaum 2024-08-14 10:20:21 -04:00
parent b2758d355c
commit b801c7c186

View File

@ -586,7 +586,7 @@ def rep_const_bool_odd{k, x, r, nw} = {
swap_data := load{swtab, k>>1}
def {partperm, get_full_permute} = proc_mod_dat{swap_data}
def sp_max = if (hasarch{'SSSE3'} and not avx2) 8 else 4
def sp_max = if (hasarch{'SSSE3'}) 8 else 4
if (k < sp_max) {
rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output}
} else {
@ -688,87 +688,47 @@ def rep_const_bool_small_odd{(u64), 4, wv, get_swap_x, output} = {
# If 1-byte shuffle isn't available, use 4-byte units instead
def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = {
def ov_bytes{o} = { def V = re_el{u8, W}; v := V~~o; v += v > V**0; W~~v }
if (max_wv <= 4 or wv < 4) {
def ww = width{W}
def ew = if (hasarch{'SSSE3'}) 8 else 32 # width of shuffle-able elements
def ne = ww/ew; def se = 64/ew
def lanes = ww > 128
def dup{v} = if (lanes) merge{v,v} else v
# 3: dedicated loop
def ww = width{W}
def ew = if (hasarch{'SSSE3'}) 8 else 32 # width of shuffle-able elements
def ne = ww/ew; def se = 64/ew
def lanes = ww > 128
def dup{v} = if (lanes) merge{v,v} else v
def fixed_loop{k} = {
assert{wv == k}
while (1) {
# 01234567 to 05316427 on each byte
# e.g. 01234567 to 05316427 on each byte for k==3, ew==8
xv := get_perm_x{}
# Overhang from previous 64-bit elements
def ix = 64*slice{iota{3},1} // 3 # bits that overhang within a word
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
def ib = ix // ew # byte index
def io = ew*ib + 3*ix%ew # where they are in xv
def io = ew*ib + k*ix%ew # where they are in xv
def wi = split{wl, dup{tup{255, ...ib, 255, ...se+ib}}}
xo := ov_bytes{(xv & W**fold{|, 1<<io}) >> (ew-3)}
xo := ov_bytes{(xv & W**fold{|, 1<<io}) >> (ew-k)}
# Permute and mask bytes
def step{jj, oi, ind, mask} = {
def getv = if (not lanes or jj==1) ({x}=>x)
else shuf{[4]u64, ., 4b1010 + 4b2222*(jj>1)}
def hk = (k-1) / 2
def getv = if (not lanes or jj==hk) ({x}=>x)
else sel_imm{[4]u64, ., 2*(jj>hk) + iota{4}%2}
def selx{x, i} = sel_imm{[128/ew]ty_u{ew}, getv{x}, i}
b := selx{xv, ind} & make{W, mask}
r := (b<<3) - b
r := (b<<k) - b
def selx_nz{x, i} = { def nz = i!=255; selx{x, i * nz} & W~~make{[4]i32, -nz} }
o := (if (ew==8) selx else selx_nz){xo, flat_table{max, oi, 255*(0<iota{se})}}
output{r|o}
}
each{step,
iota{3}, wi,
split{ne, replicate{3, iota{ne}}},
split{wl, each{base{2,.}, split{64, cycle{3*ww, 0==iota{3}}}}}
iota{k}, wi,
split{ne, replicate{k, iota{ne}}},
split{wl, each{base{2,.}, split{64, cycle{k*ww, 0==iota{k}}}}}
}
}
}
if (max_wv <= 4 or wv < 4) {
fixed_loop{3}
} else if (wv < 6) {
fixed_loop{5}
} else {
assert{w128{W}}
def V = re_el{u8, W}; def [vl]_ = V
def mkV = make{V, .}; def selV = sel{V, ., .}
def iV = iota{vl}
# 5, 7: precompute constants, then shared loop
{xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{
W, V, V, W, V, V, usz }}
def set_consts{k} = {
# Overhang from previous 64-bit elements
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
def ib = ix // 8 # byte index
def io = 8*ib + k*ix%8 # where they are in xv
def wi = tup{255, ...ib, 255, ...8+ib}
xom = W**fold{|, 1<<io}
xse = mkV{join{flip{split{2, shiftright{wi, vl**255}}}}}
# Permutation to expand by k bytes, and every-k-bits mask
ind0 = mkV{iV // k}
mask0 = W~~mkV{((1<<k|1) << ((-8)*iV % k)) % 256}
def iu = iV + vl%k; def ia = iu>=vl
ind_up = mkV{iu - k*ia}
ind_inc = mkV{vl//k + ia}
mask_sh = width{V} % k
}
if (wv == 5) set_consts{5} else set_consts{7}
xv:=W**0; xo:=xv; mask:=xv; ind:=V**0 # state
q:usz = 1
while (1) {
--q; if (q == 0) { q = wv
# Load and permute bytes
xv = get_perm_x{}
# Bytes for overhang
xo = ov_bytes{(xv & xom) >> (8 - wv)}
xo = selV{xo, xse}
# Initialize state vectors
ind = ind0
mask = mask0
} else {
# Update state vectors
xo = shr{V, xo, 1}
ind = selV{ind, ind_up} + ind_inc
mask = advance_spaced_mask{wv, mask, mask_sh}
}
b := selV{xv, ind} & mask
rv:= (b<<wv) - b
o := xo & W**0xff # overhang
output{rv | o}
}
fixed_loop{7}
}
}