Move to dedicated k/bool loops for k=5,7 and at least SSSE3
This commit is contained in:
parent
b2758d355c
commit
b801c7c186
@ -586,7 +586,7 @@ def rep_const_bool_odd{k, x, r, nw} = {
|
||||
swap_data := load{swtab, k>>1}
|
||||
def {partperm, get_full_permute} = proc_mod_dat{swap_data}
|
||||
|
||||
def sp_max = if (hasarch{'SSSE3'} and not avx2) 8 else 4
|
||||
def sp_max = if (hasarch{'SSSE3'}) 8 else 4
|
||||
if (k < sp_max) {
|
||||
rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output}
|
||||
} else {
|
||||
@ -688,87 +688,47 @@ def rep_const_bool_small_odd{(u64), 4, wv, get_swap_x, output} = {
|
||||
# If 1-byte shuffle isn't available, use 4-byte units instead
|
||||
def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = {
|
||||
def ov_bytes{o} = { def V = re_el{u8, W}; v := V~~o; v += v > V**0; W~~v }
|
||||
if (max_wv <= 4 or wv < 4) {
|
||||
def ww = width{W}
|
||||
def ew = if (hasarch{'SSSE3'}) 8 else 32 # width of shuffle-able elements
|
||||
def ne = ww/ew; def se = 64/ew
|
||||
def lanes = ww > 128
|
||||
def dup{v} = if (lanes) merge{v,v} else v
|
||||
# 3: dedicated loop
|
||||
def ww = width{W}
|
||||
def ew = if (hasarch{'SSSE3'}) 8 else 32 # width of shuffle-able elements
|
||||
def ne = ww/ew; def se = 64/ew
|
||||
def lanes = ww > 128
|
||||
def dup{v} = if (lanes) merge{v,v} else v
|
||||
def fixed_loop{k} = {
|
||||
assert{wv == k}
|
||||
while (1) {
|
||||
# 01234567 to 05316427 on each byte
|
||||
# e.g. 01234567 to 05316427 on each byte for k==3, ew==8
|
||||
xv := get_perm_x{}
|
||||
# Overhang from previous 64-bit elements
|
||||
def ix = 64*slice{iota{3},1} // 3 # bits that overhang within a word
|
||||
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
|
||||
def ib = ix // ew # byte index
|
||||
def io = ew*ib + 3*ix%ew # where they are in xv
|
||||
def io = ew*ib + k*ix%ew # where they are in xv
|
||||
def wi = split{wl, dup{tup{255, ...ib, 255, ...se+ib}}}
|
||||
xo := ov_bytes{(xv & W**fold{|, 1<<io}) >> (ew-3)}
|
||||
xo := ov_bytes{(xv & W**fold{|, 1<<io}) >> (ew-k)}
|
||||
# Permute and mask bytes
|
||||
def step{jj, oi, ind, mask} = {
|
||||
def getv = if (not lanes or jj==1) ({x}=>x)
|
||||
else shuf{[4]u64, ., 4b1010 + 4b2222*(jj>1)}
|
||||
def hk = (k-1) / 2
|
||||
def getv = if (not lanes or jj==hk) ({x}=>x)
|
||||
else sel_imm{[4]u64, ., 2*(jj>hk) + iota{4}%2}
|
||||
def selx{x, i} = sel_imm{[128/ew]ty_u{ew}, getv{x}, i}
|
||||
b := selx{xv, ind} & make{W, mask}
|
||||
r := (b<<3) - b
|
||||
r := (b<<k) - b
|
||||
def selx_nz{x, i} = { def nz = i!=255; selx{x, i * nz} & W~~make{[4]i32, -nz} }
|
||||
o := (if (ew==8) selx else selx_nz){xo, flat_table{max, oi, 255*(0<iota{se})}}
|
||||
output{r|o}
|
||||
}
|
||||
each{step,
|
||||
iota{3}, wi,
|
||||
split{ne, replicate{3, iota{ne}}},
|
||||
split{wl, each{base{2,.}, split{64, cycle{3*ww, 0==iota{3}}}}}
|
||||
iota{k}, wi,
|
||||
split{ne, replicate{k, iota{ne}}},
|
||||
split{wl, each{base{2,.}, split{64, cycle{k*ww, 0==iota{k}}}}}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (max_wv <= 4 or wv < 4) {
|
||||
fixed_loop{3}
|
||||
} else if (wv < 6) {
|
||||
fixed_loop{5}
|
||||
} else {
|
||||
assert{w128{W}}
|
||||
def V = re_el{u8, W}; def [vl]_ = V
|
||||
def mkV = make{V, .}; def selV = sel{V, ., .}
|
||||
def iV = iota{vl}
|
||||
# 5, 7: precompute constants, then shared loop
|
||||
{xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{
|
||||
W, V, V, W, V, V, usz }}
|
||||
def set_consts{k} = {
|
||||
# Overhang from previous 64-bit elements
|
||||
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
|
||||
def ib = ix // 8 # byte index
|
||||
def io = 8*ib + k*ix%8 # where they are in xv
|
||||
def wi = tup{255, ...ib, 255, ...8+ib}
|
||||
xom = W**fold{|, 1<<io}
|
||||
xse = mkV{join{flip{split{2, shiftright{wi, vl**255}}}}}
|
||||
# Permutation to expand by k bytes, and every-k-bits mask
|
||||
ind0 = mkV{iV // k}
|
||||
mask0 = W~~mkV{((1<<k|1) << ((-8)*iV % k)) % 256}
|
||||
def iu = iV + vl%k; def ia = iu>=vl
|
||||
ind_up = mkV{iu - k*ia}
|
||||
ind_inc = mkV{vl//k + ia}
|
||||
mask_sh = width{V} % k
|
||||
}
|
||||
if (wv == 5) set_consts{5} else set_consts{7}
|
||||
xv:=W**0; xo:=xv; mask:=xv; ind:=V**0 # state
|
||||
q:usz = 1
|
||||
while (1) {
|
||||
--q; if (q == 0) { q = wv
|
||||
# Load and permute bytes
|
||||
xv = get_perm_x{}
|
||||
# Bytes for overhang
|
||||
xo = ov_bytes{(xv & xom) >> (8 - wv)}
|
||||
xo = selV{xo, xse}
|
||||
# Initialize state vectors
|
||||
ind = ind0
|
||||
mask = mask0
|
||||
} else {
|
||||
# Update state vectors
|
||||
xo = shr{V, xo, 1}
|
||||
ind = selV{ind, ind_up} + ind_inc
|
||||
mask = advance_spaced_mask{wv, mask, mask_sh}
|
||||
}
|
||||
b := selV{xv, ind} & mask
|
||||
rv:= (b<<wv) - b
|
||||
o := xo & W**0xff # overhang
|
||||
output{rv | o}
|
||||
}
|
||||
fixed_loop{7}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user