SSE2 odd/bool implementation using [4]u32 shuffles

This commit is contained in:
Marshall Lochbaum 2024-08-13 11:42:12 -04:00
parent 36e9ca5814
commit ee27717c97

View File

@ -347,23 +347,38 @@ def get_boolvec_writer{V, r, nw} = {
def vwords = width{V}/64
rv := *V~~r
re := rv + nw / vwords
rs := rv + cdiv{nw, vwords}
# Avoid reading/processing an extra word past the end of input
def done = makelabel{}
def check_done{} = if (rv == rs) goto{done}
# If the last result is partial, jump to flush to write it
last_res:V = V**0
def end = makelabel{}
def l_flush = makelabel{}
def output{v:(V)} = {
last_res = v
if (rv==re) goto{end}
if (rv == re) goto{l_flush}
store{rv, 0, v}; ++rv
}
def flush{} = {
setlabel{end}
setlabel{l_flush}
q := nw & (vwords-1)
if (q != 0) homMaskStoreF{rv, V~~maskOf{re_el{u64,V}, q}, last_res}
setlabel{done}
}
tup{output, flush}
tup{output, check_done, flush}
}
def get_boolvec_writer{T=(u64), r:*T, nw} = {
def done = makelabel{}
j:usz = 0
def output{rw} = {
store{r, j, rw}
++j; if (j==nw) goto{done}
}
def flush{} = setlabel{done}
tup{output, {}=>{}, flush}
}
def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSSE3'}} = {
oper // ({a,b}=>floor{a/b}) infix left 40
def avx2 = hasarch{'AVX2'}
def vl = if (avx2) 32 else 16
def V = [vl]u8
@ -372,7 +387,7 @@ def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSSE3'}} = {
def selH = sel{[16]u8, ., .}
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
def id{xv} = xv
def {output, flush} = get_boolvec_writer{V, r, nw}
def {output, done, flush} = get_boolvec_writer{V, r, nw}
def run24{x, proc_xv, exh} = {
i:usz = 0; while (1) {
@ -409,6 +424,19 @@ def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSSE3'}} = {
flush{}
}
def sel_imm{V=[4]T, x, {...inds} if hasarch{'SSE2'} and (T==u32 or T==u64)} = {
assert{length{inds} == 4}
shuf{V, x, base{4, inds}}
}
def sel_imm{V=([16]u8), x:X, {...inds}} = {
def I = re_el{u8,X}; def [n]_ = I
def l = length{inds}
assert{l==16 or l==n}
sel{V, x, make{I, cycle{n, inds}}}
}
def advance_spaced_mask{k, m, sh} = m<<(k-sh) | m>>sh
# Data for the permutation that sends bit i to k*i % width{T}
def modperm_dat{T, k} = {
def w = width{T}
@ -419,45 +447,119 @@ def modperm_dat{T, k} = {
{_} => T~~base{2, bits}
}
}
def modperm_step{x, l, m} = {
# Permutation step evaluators
# shift takes a top-half mask; others take both-halves
def modperm_shift_step{x, l, m} = {
def d = (x ^ x<<l) & m
x ^ (d | d>>l)
}
def modperm_step{x:T, l=(width{T}/2), m} = {
def mm = m | m>>l
(x &~ mm) | ((x<<l | x>>l) & mm) # rotate
def modperm_rot_step{x:T, l=(width{T}/2), m} = {
(x &~ m) | ((x<<l | x>>l) & m)
}
def modperm_step{x:T=[_](u8), l, m} = {
def W = re_el{u64, T}
T~~modperm_step{W~~x, l, W~~m}
def modperm_shuf_step{x:V=[_]T, l, m if l%8==0} = {
(x &~ m) | (swap_elts{x, l/8} & m)
}
def swap_elts{x:V=[vl](u8), el_bytes} = { # Reverse each pair of elements
def swi{len, l} = { def i = iota{len}; i + (l - 2*(i&l)) }
# Reverse each pair of elements
def swap_elts{x:V=[_]_, el_bytes} = {
def selx{n, wd} = {
def l = el_bytes/wd
def i = iota{n}
def rev = i + (l - 2*(i&l))
sel_imm{[n]ty_u{wd*8}, x, rev}
}
if (el_bytes >= 4) {
def wd = max{4, el_bytes/2}
shuf{[4]ty_u{wd*8}, x, base{4, swi{4, el_bytes/wd}}}
selx{4, max{4, el_bytes/2}}
} else if (hasarch{'SSSE3'}) {
selx{16, 1}
} else {
def i = swi{16, el_bytes}
sel{[16]u8, x, make{V, cycle{vl, i}}}
def sh = el_bytes*8
xe := re_el{ty_u{2*sh}, V} ~~ x
V~~(xe<<sh | xe>>sh)
}
}
def modperm_step{x, l, m:V=[_]T if l%8==0} = {
(x & m) | (swap_elts{x, l/8} &~ m)
# Extract swap functions from modperm_dat
def extract_modperm_mask{full:W, lane, l} = {
def f = l == 128 # Use full width
def w = if (l<32) 8 else if (f) 64 else 32 # Shuffle element width
def n = (if (f) 256 else 128) / w # and number of elements
def o = l / w # Extract [o,2*o)
def i = o + iota{n}%o
sel_imm{[n]ty_u{w}, if (f) full else lane, i}
}
def modperm_get_byteperm{sw_bytes:V=[_](u8) if hasarch{'SSSE3'}} = {
def shW{op, v, s} = V~~op{re_el{u64,V}~~v, s}
m4 := V**0xf
t0 := fold{{v,a}=>modperm_step{v,...a}, make{V,iota{vcount{V}}%16}, tup{
tup{4, sw_bytes &~ m4},
tup{2, ({v} => v|shW{<<, v, 4}){sw_bytes&(V**0xc)}}
}}
t4 := shW{<<, t0&m4, 4} | shW{>>, t0&~m4, 4}
def selI = sel{[16]u8, ., .}
{xv} => selI{t0, xv & m4} | selI{t4, shW{>>, xv, 4} & m4}
def proc_mod_dat{swap_data:W} = {
def ww = width{W}; def avx2 = ww==256
def on_len_range{get_swap, lo, hi} = {
def lens = reverse{1 << slice{iota{lb{hi}}, lb{lo}}}
fold{{x,s}=>s{x}, ., each{get_swap, lens}}
}
swap_lane := if (avx2) shuf{W, swap_data, 4b1010} else swap_data
# Transform width-w units with shifts only
def get_shiftperm{data, w} = {
dat := data
bot:W = W**base{2, cycle{64, replicate{w, tup{1,0}}}}
def gsw{l} = {
bot ^= bot << l # Low l bits out of every 2*l
sm := dat &~ bot
dat &= bot; dat |= dat<<l
if (l<32) modperm_shift_step{., l, sm}
else modperm_rot_step {., l, sm|sm>>l}
}
on_len_range{gsw, 2, w}
}
# Within-byte transformation
def get_byteperm{} = {
def V = re_el{u8, W}
def sw_bytes = sel{[16]u8, swap_lane, V**0}
m4 := W~~V**0xf
t0 := fold{{v,a}=>modperm_shift_step{v,...a}, W~~make{V,iota{vcount{V}}%16}, tup{
tup{4, sw_bytes &~ m4},
tup{2, ({v} => v | v<<4){sw_bytes & W~~V**0xc}}
}}
t4 := (t0&m4)<<4 | (t0&~m4)>>4
def selI{v,i} = sel{[16]u8, v, V~~i}
{xv} => selI{t0, xv & m4} | selI{t4, (xv>>4) & m4}
}
def {partwidth, partperm} = {
def sh{w, d} = tup{w, get_shiftperm{d, w}}
if (ww==64) sh{64, swap_data}
else if (not hasarch{'SSSE3'}) sh{32, shuf{[4]u32, swap_data, 4b0000}}
else tup{8, get_byteperm{}}
}
# Fill in higher steps
def get_mod_permuter{} = {
def get_swap{l} = {
def mask = extract_modperm_mask{swap_data, swap_lane, l}
modperm_shuf_step{., l, mask}
}
def swap_x = on_len_range{get_swap, partwidth, ww}
{x} => partperm{swap_x{x}}
}
tup{partperm, get_mod_permuter}
}
def fold_multi{f, init, ...ls} = fold{{v,a}=>f{v,...a}, init, flip{ls}}
def advance_spaced_mask{k, m, sh} = m<<(k-sh) | m>>sh
def rep_const_bool_odd{k, x, r, nw} = {
def avx2 = hasarch{'AVX2'}
def W = if (hasarch{'SSE2'}) [if (avx2) 4 else 2]u64 else u64
def {output, check_done, flush} = get_boolvec_writer{W, r, nw}
xp := *W~~x
def getter{perm}{} = { check_done{}; xv := load{xp}; ++xp; perm{xv} }
# Modular permutation: small-k cases may use a limited permutation
# on bytes or 32-bit ints; general case uses the whole thing
swtab:*W = each{modperm_dat{W, .}, 1+2*iota{32}}
swap_data := load{swtab, k>>1}
def {partperm, get_full_permute} = proc_mod_dat{swap_data}
def sp_max = if (hasarch{'SSSE3'} and not avx2) 8 else 4
if (k < sp_max) {
rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output}
} else {
def get_swap_x = getter{get_full_permute{}}
rep_const_bool_odd_mask4{W, k, get_swap_x, output, cdiv{nw, vcount{W}}}
}
flush{}
}
# General-case loop for odd replication factors
def rep_const_bool_odd_mask4{
@ -465,8 +567,6 @@ def rep_const_bool_odd_mask4{
k, # replication factor
get_modperm_x, # permuted input
output, n, # output, number of writes
mask:(u64), # starting mask
mask_sh # single iteration shift
} = {
def ifvec{g} = match (M) { {[_](u64)} => g; {_} => ({v}=>v) }
def scal = ifvec{{v} => M**v}
@ -480,6 +580,8 @@ def rep_const_bool_odd_mask4{
def ss = s+s
tup{mn, ss - (k &- (ss>k))}
}
# Starting word mask and single-word shift
{mask, mask_sh} := unaligned_spaced_mask_mod{k}
# Mask and shift for one iteration
def fillmask{T} = match (T) {
{(u64)} => tup{mask, mask_sh}
@ -511,7 +613,7 @@ def rep_const_bool_odd_mask4{
while (1) {
x:M = get_modperm_x{}
def vrot1 = ifvec{{x} => if (w128{M}) vshl{x, x, vcount{type{x}}-1}
def vrot1 = ifvec{{x} => if (w128{M}) shuf{[4]u32, x, 4b1032}
else shuf{M, x, 4b2103}}
xo := x<<k | vrot1{x>>(64-k)}
# Write result word given starting bits
@ -532,131 +634,66 @@ def rep_const_bool_odd_mask4{
}
}
def rep_const_bool_odd{k, xp, rp, nw} = {
# Every-k-bits mask
{mask, mask_sh} := unaligned_spaced_mask_mod{k}
# Transform sending bit i to k*i % 64 by pairwise swaps
# Swap data goes in a pre-computed table
swtab:*u64 = each{modperm_dat{u64, .}, 1+2*iota{32}}
def swap_lens = reverse{2 << iota{5}}
swap_data := load{swtab, k>>1}
swsel:u64 = ~u64~~0
def gsw{l} = {
swsel ^= swsel << l # Low l bits out of every 2*l
sm := swap_data &~ swsel
swap_data &= swsel; swap_data |= swap_data<<l
sm
}
swap_masks := each{gsw, swap_lens}
i:usz = 0
# Load x, send bit i to position k*i % 64
def swap_x = fold_multi{modperm_step, ., swap_lens, swap_masks}
def get_swap_x{} = { x := swap_x{load{xp, i}}; ++i; x }
# Output
j:usz = 0
def output{rw} = {
store{rp, j, rw}
++j; if (j==nw) return{1}
}
# Dedicated loop for 3, shared for other factors
if (k == 3) while (1) {
def unrolled_iter{k} = {
def reps{f, n, init} = scan{{x,_}=>f{x}, n**init}
def masks = reps{advance_spaced_mask{k, ., 64%k}, k, mask}
def step{x}{o, m, n_over} = {
b := x & m
def os = (64-k) + 1 + iota{n_over}
output{fold{|, (b<<k) - b, each{>>{o,.}, os}}}
b
}
fold_multi{step{get_swap_x{}}, 0, masks, (-iota{k}*64)%k}
}
unrolled_iter{3}
} else {
rep_const_bool_odd_mask4{u64, k, get_swap_x, output, nw, mask, mask_sh}
def rep_const_bool_small_odd{(u64), 4, wv, get_swap_x, output} = {
def k = 3
def step{x}{o, n_over} = {
b := x & base{2, cycle{64, n_over == iota{k}}}
def os = (64-k) + 1 + iota{n_over}
output{fold{|, (b<<k) - b, each{>>{o,.}, os}}}
b
}
while (1) fold{step{get_swap_x{}}, 0, (-iota{k}*64)%k}
}
# For odd numbers:
# For small odd numbers:
# - permute each byte sending bit i to position k*i % 8
# - replicate each byte by k, making position k*i contain bit i
# - mask out those bits and spread over [ k*i, k*(i+1) )
# - ...except where it crosses words; handle this overhang separately
def rep_const_bool_odd{k, x, r, nw if hasarch{'SSSE3'}} = {
def avx2 = hasarch{'AVX2'}
def vl = if (avx2) 32 else 16; def V = [vl]u8
def iV = iota{vl}
def mkV = make{V, .}; def selV = sel{[16]u8, ., .}
def W = re_el{u64, V}
def {output, flush} = get_boolvec_writer{W, r, nw}
# Swap data goes in a pre-computed table
swtab:*V = each{modperm_dat{V, .}, 1+2*iota{32}}
swap_data := load{swtab, k>>1}
swap_lane := if (avx2) shuf{[4]u64, swap_data, 4b1010} else swap_data
# Within-byte transformation
def perm_x = modperm_get_byteperm{selV{swap_lane, V**0}}
def sp_max = if (avx2) 4 else 8
if (k < sp_max) {
rep_const_bool_odd_special{V, sp_max, k, x, perm_x, {v}=>output{W~~v}}
} else {
{m, d} := unaligned_spaced_mask_mod{k}
# General case
def get_mask{l} = selV{~swap_lane, mkV{l+iV%l}}
def get_mask{16} = shuf{[4]u64, ~swap_data, 4b3232}
def swap_lens = reverse{1 << iota{4 + avx2}}
swap_masks := each{get_mask, swap_lens}
def swap_x = fold_multi{modperm_step, ., 8*swap_lens, swap_masks}
i:usz = 0
def get_swap_x{} = { xv := W~~perm_x{swap_x{load{*V~~x, i}}}; ++i; xv }
rep_const_bool_odd_mask4{W, k, get_swap_x, output, cdiv{nw, vcount{W}}, m, d}
}
flush{}
}
def rep_const_bool_odd_special{V=[vl](u8), max_wv, wv, x, perm_x, output} = {
oper // ({a,b}=>floor{a/b}) infix left 40
def iV = iota{vl}
def mkV = make{V, .}; def selV = sel{[16]u8, ., .}
def W = re_el{u64, V}
# If 1-byte shuffle isn't available, use 4-byte units instead
def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = {
def ov_bytes{o} = { def V = re_el{u8, W}; v := V~~o; v += v > V**0; W~~v }
if (max_wv <= 4 or wv < 4) {
def ll = 16; def dup{v} = if (vl>ll) merge{v,v} else v
def ww = width{W}
def ew = if (hasarch{'SSSE3'}) 8 else 32 # width of shuffle-able elements
def ne = ww/ew; def se = 64/ew
def lanes = ww > 128
def dup{v} = if (lanes) merge{v,v} else v
# 3: dedicated loop
i:usz = 0; while (1) {
while (1) {
# 01234567 to 05316427 on each byte
xv := perm_x{load{*V~~x, i}}; ++i
xv := get_perm_x{}
# Overhang from previous 64-bit elements
def ix = 64*slice{iota{3},1} // 3 # bits that overhang within a word
def ib = ix // 8 # byte index
def io = 8*ib + 3*ix%8 # where they are in xv
def wi = split{vl/8, dup{tup{255, ...ib, 255, ...8+ib}}}
xo := V~~((W~~xv & W**fold{|, 1<<io}) >> (8-3))
xo += xo > V**0
def ib = ix // ew # byte index
def io = ew*ib + 3*ix%ew # where they are in xv
def wi = split{wl, dup{tup{255, ...ib, 255, ...se+ib}}}
xo := ov_bytes{(xv & W**fold{|, 1<<io}) >> (ew-3)}
# Permute and mask bytes
def step{jj, oi, ind, mask} = {
def getv = if (vl==ll or jj==1) ({x}=>x)
def getv = if (not lanes or jj==1) ({x}=>x)
else shuf{[4]u64, ., 4b1010 + 4b2222*(jj>1)}
def selx{x, i} = sel{[ll]u8, getv{x}, i}
b := W~~(selx{xv, ind} & mask)
r := V~~((b<<3) - b)
o := selx{xo, mkV{flat_table{max, oi, 255*(0<iota{8})}}}
def selx{x, i} = sel_imm{[128/ew]ty_u{ew}, getv{x}, i}
b := selx{xv, ind} & make{W, mask}
r := (b<<3) - b
def selx_nz{x, i} = { def nz = i!=255; selx{x, i * nz} & W~~make{[4]i32, -nz} }
o := (if (ew==8) selx else selx_nz){xo, flat_table{max, oi, 255*(0<iota{se})}}
output{r|o}
}
def make3V{vs} = each{make{V,.}, split{vl, dup{vs}}}
each{step,
iota{3}, wi,
make3V{replicate{3, iota{ll}}},
make3V{8w2b001 << ((-8)*iota{3*ll} % 3)}
split{ne, replicate{3, iota{ne}}},
split{wl, each{base{2,.}, split{64, cycle{3*ww, 0==iota{3}}}}}
}
}
} else {
assert{w128{V}}
def selV = sel{V, ., .}
assert{w128{W}}
def V = re_el{u8, W}; def [vl]_ = V
def mkV = make{V, .}; def selV = sel{V, ., .}
def iV = iota{vl}
# 5, 7: precompute constants, then shared loop
{xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{
W, V, V, V, V, V, usz }}
W, V, V, W, V, V, usz }}
def set_consts{k} = {
# Overhang from previous 64-bit elements
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
@ -667,22 +704,21 @@ def rep_const_bool_odd_special{V=[vl](u8), max_wv, wv, x, perm_x, output} = {
xse = mkV{join{flip{split{2, shiftright{wi, vl**255}}}}}
# Permutation to expand by k bytes, and every-k-bits mask
ind0 = mkV{iV // k}
mask0 = mkV{((1<<k|1) << ((-8)*iV % k)) % 256}
mask0 = W~~mkV{((1<<k|1) << ((-8)*iV % k)) % 256}
def iu = iV + vl%k; def ia = iu>=vl
ind_up = mkV{iu - k*ia}
ind_inc = mkV{vl//k + ia}
mask_sh = width{V} % k
}
if (wv == 5) set_consts{5} else set_consts{7}
xv:V = V**0; xo:=xv; ind:=xv; mask:=xv # state
i:usz = 0; q:usz = 1
xv:=W**0; xo:=xv; mask:=xv; ind:=V**0 # state
q:usz = 1
while (1) {
--q; if (q == 0) { q = wv
# Load and permute bytes
xv = perm_x{load{*V~~x, i}}; ++i
xv = get_perm_x{}
# Bytes for overhang
xo = V~~((W~~xv & xom) >> (8 - wv))
xo += xo > V**0
xo = ov_bytes{(xv & xom) >> (8 - wv)}
xo = selV{xo, xse}
# Initialize state vectors
ind = ind0
@ -691,11 +727,11 @@ def rep_const_bool_odd_special{V=[vl](u8), max_wv, wv, x, perm_x, output} = {
# Update state vectors
xo = shr{V, xo, 1}
ind = selV{ind, ind_up} + ind_inc
mask = V~~((W~~mask << (wv - mask_sh)) | (W~~mask >> mask_sh))
mask = advance_spaced_mask{wv, mask, mask_sh}
}
b := W~~(selV{xv, ind} & mask)
rv:= V~~((b<<wv) - b)
o := xo & mkV{255 * (iV%8 == 0)} # overhang
b := selV{xv, ind} & mask
rv:= (b<<wv) - b
o := xo & W**0xff # overhang
output{rv | o}
}
}