SSSE3 k/bool for 2≤k≤8, k≠6

This commit is contained in:
Marshall Lochbaum 2024-08-06 10:02:27 -04:00
parent a4b6d8d827
commit a439a0a430

View File

@ -291,46 +291,12 @@ exportT{'si_constrep', each{rep_const, dat_types}}
# Constant replicate on boolean
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
if (hasarch{'SSSE3'} and wv==3) {
def vl = 16; def V = [vl]u8
nv := cdiv{rlen, width{V}}
i:usz = 0; j:usz = 0
rv := *V~~r
def end = makelabel{}; while (j < nv) {
xv := load{*V~~x, i}; ++i
# 01234567 to 05316427 on each byte
def bv{bs} = fold{flat_table{+,...}, reverse{each{tup{0,.}, 1<<bs}}}
{t0, t4} := each{{is} => make{V, bv{is}}, split{4, 3*iota{8} % 8}}
m4 := V**0xf
xv = sel{V, t0, xv & m4} | sel{V, t4, V~~([8]u16~~xv>>4) & m4}
# Overhang from previous 64-bit elements
def os = 8-3 # right shift
def ix = __floor{64*slice{iota{3},1} / 3} # bits that overhang within a word
def ib = __floor{ix / 8} # byte index
def io = 8*ib + 3*ix%8 # where they are in xv
def wi = split{2, tup{255, ...ib, 255, ...8+ib}}
def W = [2]u64
xo := V~~((W~~xv & W**fold{|, 1<<io}) >> os)
xo += xo > V**0
# Permute and mask bytes
def step{jj, oi, ind, mask} = {
b := W~~(sel{V, xv, ind} & mask)
r := V~~((b<<3) - b)
o := sel{V, xo, make{V, flat_table{max, oi, 255*(0<iota{8})}}}
store{rv, j, r|o}; ++j; if (jj<2 and j==nv) goto{end}
}
def make3V{vs} = each{make{V,.}, split{vl, vs}}
each{step,
iota{3}, wi,
make3V{replicate{3, iota{vl}}},
make3V{8w2b001 << ((-8)*iota{3*vl} % 3)}
}
}
setlabel{end}
return{1}
}
def has_pdep = hasarch{'BMI2'}
if (wv > 32) return{0}
if (hasarch{'SSSE3'} and wv<=8 and wv!=6) {
rep_const_bool_ssse3{wv, x, r, rlen}
return{1}
}
if (not has_pdep and wv <= 8) return{0}
m:u64 = spaced_mask_of{wv}
xw:u64 = 0
@ -372,4 +338,142 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
}
1
}
def rep_const_bool_ssse3{wv, x, r, rlen} = { # 2<=wv<=8, wv!=6
oper // ({a,b}=>floor{a/b}) infix left 40
def bv{bs} = fold{flat_table{|,...}, reverse{each{tup{0,.}, 1<<bs}}}
def vl = 16; def V = [vl]u8
def iV = iota{vl}
def mkV = make{V, .}; def selV = sel{V, ., .}
def W = [2]u64
nv := cdiv{rlen, width{V}}
def run24{x, proc_xv, exh} = {
i:usz = 0; j:usz = 0
rv := *V~~r
def end = makelabel{}; while (j < nv) {
xv := proc_xv{load{*V~~(x+i)}}; ++i
# Store 1 or 2 result vectors
def getr = zip{exh{xv}, exh{V~~([8]u16~~xv>>4)}, .}
store{rv, j, V~~getr{0}}; ++j; if (j==nv) goto{end}
store{rv, j, V~~getr{1}}; ++j
}
setlabel{end}
}
def run{2} = {
# Expander for half byte
tabr := mkV{bv{2*iota{4}} * 2b11}
m4 := V**0xf
def exh{x} = selV{tabr, x & m4}
run24{*V~~x, {xv}=>xv, exh}
}
def run{4} = {
dup := mkV{iV//2} # Double each byte
# Expander for two bits in either bottom or next-to-bottom position
tabr := mkV{bv{tup{0,4,0,4}} * 2b1111}
m2 := mkV{2b11 << (2*(iV%2))}
def exh{x} = re_el{u16, V}~~selV{tabr, x & m2}
run24{*u64~~x, selV{., dup}, exh}
}
def run{8} = {
@for (r in *V~~r over i to nv) {
xv := load{*V~~(*u16~~x + i)}
xe := selV{xv, mkV{iV//8}}
r = (xe & mkV{1 << (iV % 8)}) > V**0
}
}
# For odd numbers:
# - permute each byte sending bit i to position k*i % 8
# - replicate each byte by k, making position k*i contain bit i
# - mask out those bits and spread over [ k*i, k*(i+1) )
# - ...except where it crosses words; handle this overhang separately
# 3 is handled separately and unrolled; 5 and 7 share a loop
def run{3} = {
i:usz = 0; j:usz = 0
rv := *V~~r
def end = makelabel{}; while (j < nv) {
xv := load{*V~~x, i}; ++i
# 01234567 to 05316427 on each byte
{t0, t4} := each{{is} => mkV{bv{is}}, split{4, 3*iota{8} % 8}}
m4 := V**0xf
xv = selV{t0, xv & m4} | selV{t4, V~~([8]u16~~xv>>4) & m4}
# Overhang from previous 64-bit elements
def ix = 64*slice{iota{3},1} // 3 # bits that overhang within a word
def ib = ix // 8 # byte index
def io = 8*ib + 3*ix%8 # where they are in xv
def wi = split{2, tup{255, ...ib, 255, ...8+ib}}
xo := V~~((W~~xv & W**fold{|, 1<<io}) >> (8-3))
xo += xo > V**0
# Permute and mask bytes
def step{jj, oi, ind, mask} = {
b := W~~(selV{xv, ind} & mask)
r := V~~((b<<3) - b)
o := selV{xo, mkV{flat_table{max, oi, 255*(0<iota{8})}}}
store{rv, j, r|o}; ++j; if (jj<2 and j==nv) goto{end}
}
def make3V{vs} = each{make{V,.}, split{vl, vs}}
each{step,
iota{3}, wi,
make3V{replicate{3, iota{vl}}},
make3V{8w2b001 << ((-8)*iota{3*vl} % 3)}
}
}
setlabel{end}
}
def run{{5,7}} = {
{t0, t4, xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{
V, V, W, V, V, V, V, V, usz }}
def set_consts{k} = {
# Within-byte transformation
tup{t0, t4} = each{{is} => mkV{bv{is}}, split{4, k*iota{8} % 8}}
# Overhang from previous 64-bit elements
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
def ib = ix // 8 # byte index
def io = 8*ib + k*ix%8 # where they are in xv
def wi = tup{255, ...ib, 255, ...8+ib}
xom = W**fold{|, 1<<io}
xse = mkV{join{flip{split{2, shiftright{wi, vl**255}}}}}
# Permutation to expand by k bytes, and every-k-bits mask
ind0 = mkV{iV // k}
mask0 = mkV{((1<<k|1) << ((-8)*iV % k)) % 256}
def iu = iV + vl%k; def ia = iu>=vl
ind_up = mkV{iu - k*ia}
ind_inc = mkV{vl//k + ia}
mask_sh = width{V} % k
}
if (wv == 5) set_consts{5} else set_consts{7}
xv:V = V**0; xo:=xv; ind:=xv; mask:=xv # state
i:usz = 0; q:usz = 1
@for (r in *V~~r over j to nv) {
--q; if (q == 0) { q = wv
# Load and permute bytes
xv = load{*V~~x, i}; ++i
m4 := V**0xf
xv = selV{t0, xv & m4} | selV{t4, V~~([8]u16~~xv>>4) & m4}
# Bytes for overhang
xo = V~~((W~~xv & xom) >> (8 - wv))
xo += xo > V**0
xo = selV{xo, xse}
# Initialize state vectors
ind = ind0
mask = mask0
} else {
# Update state vectors
xo = shr{V, xo, 1}
ind = selV{ind, ind_up} + ind_inc
mask = V~~((W~~mask << (wv - mask_sh)) | (W~~mask >> mask_sh))
}
b := W~~(selV{xv, ind} & mask)
r = V~~((b<<wv) - b)
r |= xo & mkV{255 * (iV%8 == 0)} # overhang
}
}
if (wv==2) run{2}
else if (wv==3) run{3}
else if (wv==4) run{4}
else if (wv==8) run{8}
else run{tup{5,7}}
}
export{'si_constrep_bool', rep_const_bool{}}