Combine generic and SSSE3 general-case odd/bool, so SSSE3 uses mask unpacking
This commit is contained in:
parent
3e5dbdbf8d
commit
2be022921e
@ -300,7 +300,8 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
||||
} else {
|
||||
tlen := rlen>>p
|
||||
wq := usz~~1<<p
|
||||
if (p == 1 or (p == 2 and wv>=52)) { # Expanding odd second is faster
|
||||
if ((not hasarch{'SSSE3'} and (p == 1 or (p == 2 and wv>=52))) or (not hasarch{'AVX2'} and p == 1 and wv>=24)) {
|
||||
# Expanding odd second is faster
|
||||
tlen = rlen / wf
|
||||
t:=wf; wf=wq; wq=t
|
||||
}
|
||||
@ -309,14 +310,12 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
||||
rep_const_bool{}(wq, t, r, rlen)
|
||||
}
|
||||
} else {
|
||||
m:u64 = spaced_mask_of{wv}
|
||||
d := cast_i{usz, popc{m}} # == 64/wv
|
||||
rep_const_bool_generic_odd{wv, x, r, nw, m, d}
|
||||
rep_const_bool_odd{wv, x, r, nw}
|
||||
}
|
||||
1
|
||||
}
|
||||
|
||||
def rep_const_bool_div8{wv, x, r, nw} = {
|
||||
def rep_const_bool_div8{wv, x, r, nw} = { # wv in 2,4,8
|
||||
def run{k} = {
|
||||
# 2 -> 64w0x33, 12 -> 64w0x000f, etc.
|
||||
def getm{sh} = base{2, iota{64}&sh == 0}
|
||||
@ -344,133 +343,8 @@ def rep_const_bool_div8{wv, x, r, nw} = {
|
||||
cases{2}
|
||||
}
|
||||
|
||||
def rep_const_bool_generic_odd{k, xp, rp, nw, m, d} = {
|
||||
# Every-k-bits mask
|
||||
mask_sh := cast_i{usz, ctz{m}} # == 64%k
|
||||
mask := m<<(k-mask_sh) | 1
|
||||
|
||||
# Transform sending bit i to k*i % 64 by pairwise swaps
|
||||
# Swap data goes in a pre-computed table
|
||||
def swdat{lw, k} = {
|
||||
def i = iota{lw}
|
||||
def bits = (k*iota{1<<lw} >> merge{0, replicate{1<<i, i}}) & 1
|
||||
~u64~~base{2, bits}
|
||||
}
|
||||
swtab:*u64 = each{swdat{6,.}, 1+2*iota{32}}
|
||||
def swap_lens = reverse{2 << iota{5}}
|
||||
swap_data := load{swtab, k>>1}
|
||||
swsel:u64 = ~u64~~0
|
||||
def gsw{l} = {
|
||||
swsel ^= swsel << l # Low l bits out of every 2*l
|
||||
sm := swap_data &~ swsel
|
||||
swap_data &= swsel; swap_data |= swap_data<<l
|
||||
sm
|
||||
}
|
||||
swap_masks := each{gsw, swap_lens}
|
||||
i:usz = 0
|
||||
def get_swap_x{} = {
|
||||
# Load x, send bit i to position k*i % 64
|
||||
x := load{xp, i}; ++i
|
||||
def swap_step{l, m} = {
|
||||
xx := (x ^ x<<l) & m
|
||||
x ^= xx | (xx>>l)
|
||||
}
|
||||
def swap_step{l==32, m} = { # Use rotate
|
||||
mm := m | m>>l
|
||||
x = (x &~ mm) | ((x<<l | x>>l) & mm)
|
||||
}
|
||||
each{swap_step, swap_lens, swap_masks}
|
||||
x
|
||||
}
|
||||
|
||||
# Output
|
||||
j:usz = 0
|
||||
def output{rw} = {
|
||||
store{rp, j, rw}
|
||||
++j; if (j==nw) return{1}
|
||||
}
|
||||
o:u64 = 0 # carry
|
||||
# Dedicated loop for 3, shared for other factors
|
||||
if (k == 3) {
|
||||
while (1) {
|
||||
x := get_swap_x{}
|
||||
@unroll (jj to 3) {
|
||||
b := x & mask
|
||||
mask = (mask << (3 - 64%3)) | (mask >> (64%3))
|
||||
def os = (64-3) + 1 + iota{select{tup{0,2,1}, jj}}
|
||||
output{fold{|, (b<<3) - b, each{>>{o,.}, os}}}
|
||||
o = b
|
||||
}
|
||||
}
|
||||
} else {
|
||||
# Fundamental operation: shifts act as order-k cyclic group on masks
|
||||
def advance{m, sh} = m<<(k-sh) | m>>sh
|
||||
sm0 := mask # starting mask
|
||||
s1 := mask_sh # single iteration shift
|
||||
# Get cumulative mask for 4 iterations, and shift to advance 4
|
||||
def double{{mc, s}} = {
|
||||
def ss = s+s
|
||||
tup{mc|advance{mc,s}, ss - (k &- (ss>k))}
|
||||
}
|
||||
{mc4, s4} := double{double{tup{sm0, s1}}}
|
||||
# Submasks pick one mask out of a combination of 4
|
||||
def or_adv{m, s} = { m |= advance{m,s} }
|
||||
@for (min{k/4 - 1, nw/4}) or_adv{sm0,s4}
|
||||
submasks := scan{advance, tup{sm0, ...3**s1}}
|
||||
mask_tail := advance{sm0, s4} &~ sm0
|
||||
# Mask out carry bit
|
||||
mr := u64~~1<<k - 1
|
||||
while (1) {
|
||||
x := get_swap_x{}
|
||||
os:=o; xo:=x<<k|x>>(64-k); o=xo&1; xo=(xo&~1)|os
|
||||
mask = mc4
|
||||
# Write result word given starting bits
|
||||
def step{b, c} = output{c - b - promote{u64, c&mr != 0}}
|
||||
def step{b, c, m} = step{b&m, c&m}
|
||||
# Fast unrolled iterations
|
||||
@for (k/4) {
|
||||
each{step{x & mask, xo & mask, .}, submasks}
|
||||
mask = advance{mask, s4}
|
||||
}
|
||||
# Single-step for tail
|
||||
mask = mask_tail
|
||||
@for (k%4) {
|
||||
step{x, xo, mask}
|
||||
mask = advance{mask, s1}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rep_const_bool{if hasarch{'SSSE3'}}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
||||
if (wv > 32) return{0}
|
||||
if (wv&1 == 0) {
|
||||
p := ctz{wv | 8} # Power of two for second replicate
|
||||
if (wv>>p == 1) {
|
||||
rep_const_bool_div8{wv, x, r, rlen}
|
||||
} else {
|
||||
tlen := rlen>>p
|
||||
t := r + cdiv{rlen, 64} - cdiv{tlen, 64}
|
||||
rep_const_bool{}(wv >>p, x, t, tlen)
|
||||
rep_const_bool{}(usz~~1<<p, t, r, rlen)
|
||||
}
|
||||
} else {
|
||||
rep_const_bool_ssse3_odd{wv, x, r, rlen}
|
||||
}
|
||||
1
|
||||
}
|
||||
|
||||
# Generalized flat transpose of iota{1<<length{bs}}
|
||||
# select{tr_iota{bs}, x} sends bit i of x to position select{bs, i}
|
||||
def tr_iota{...bs} = {
|
||||
def axes = each{tup{0,.}, 1<<bs}
|
||||
fold{flat_table{|,...}, reverse{axes}}
|
||||
}
|
||||
def tr_iota{{...bs}} = tr_iota{...bs}
|
||||
|
||||
def get_boolvec_writer{V, r, rlen} = {
|
||||
def get_boolvec_writer{V, r, nw} = {
|
||||
def vwords = width{V}/64
|
||||
nw := cdiv{rlen, 64}
|
||||
rv := *V~~r
|
||||
re := rv + nw / vwords
|
||||
last_res:V = V**0
|
||||
@ -488,7 +362,7 @@ def get_boolvec_writer{V, r, rlen} = {
|
||||
tup{output, flush}
|
||||
}
|
||||
|
||||
def rep_const_bool_div8{wv, x, r, rlen if hasarch{'SSSE3'}} = { # wv in 2,4,8
|
||||
def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSSE3'}} = {
|
||||
oper // ({a,b}=>floor{a/b}) infix left 40
|
||||
def avx2 = hasarch{'AVX2'}
|
||||
def vl = if (avx2) 32 else 16
|
||||
@ -498,7 +372,7 @@ def rep_const_bool_div8{wv, x, r, rlen if hasarch{'SSSE3'}} = { # wv in 2,4,8
|
||||
def selH = sel{[16]u8, ., .}
|
||||
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
|
||||
def id{xv} = xv
|
||||
def {output, flush} = get_boolvec_writer{V, r, rlen}
|
||||
def {output, flush} = get_boolvec_writer{V, r, nw}
|
||||
|
||||
def run24{x, proc_xv, exh} = {
|
||||
i:usz = 0; while (1) {
|
||||
@ -535,29 +409,189 @@ def rep_const_bool_div8{wv, x, r, rlen if hasarch{'SSSE3'}} = { # wv in 2,4,8
|
||||
flush{}
|
||||
}
|
||||
|
||||
# Data for the permutation that sends bit i to k*i % width{T}
|
||||
def modperm_dat{T, k} = {
|
||||
def w = width{T}
|
||||
def lw= lb{w}
|
||||
def i = iota{lw}
|
||||
def bits = ~(1 & (k*iota{w} >> merge{0, replicate{1<<i, i}}))
|
||||
match (T) {
|
||||
{[_]E} => make{T, each{base{2,.}, split{width{E}, bits}}}
|
||||
{_} => T~~base{2, bits}
|
||||
}
|
||||
}
|
||||
def modperm_step{x, l, m} = {
|
||||
def d = (x ^ x<<l) & m
|
||||
x ^ (d | d>>l)
|
||||
}
|
||||
def modperm_step{x:T, l=(width{T}/2), m} = {
|
||||
def mm = m | m>>l
|
||||
(x &~ mm) | ((x<<l | x>>l) & mm) # rotate
|
||||
}
|
||||
def modperm_step{x:T=[_](u8), l, m} = {
|
||||
def W = re_el{u64, T}
|
||||
T~~modperm_step{W~~x, l, W~~m}
|
||||
}
|
||||
def swap_elts{x:V, el_bytes if w128{V}} = {
|
||||
def n = 16; def I = [n]u8
|
||||
def swi{len, l} = { def i = iota{len}; i + (l - 2*(i&l)) }
|
||||
if (el_bytes >= 4) shuf{[4]u32, x, base{4, swi{4, el_bytes/4}}}
|
||||
else sel{I, x, make{I, swi{16, el_bytes}}}
|
||||
}
|
||||
def modperm_step{x, l, m:V=[_]T if l%8==0} = {
|
||||
(x & m) | (swap_elts{x, l/8} &~ m)
|
||||
}
|
||||
def modperm_get_byteperm{sw_bytes:V=([16]u8) if hasarch{'SSSE3'}} = {
|
||||
def shW{op, v, s} = V~~op{re_el{u64,V}~~v, s}
|
||||
m4 := V**0xf
|
||||
t0 := fold{{v,a}=>modperm_step{v,...a}, iota{V}, tup{
|
||||
tup{4, sw_bytes &~ m4},
|
||||
tup{2, ({v} => v|shW{<<, v, 4}){sw_bytes&(V**0xc)}}
|
||||
}}
|
||||
t4 := shW{<<, t0&m4, 4} | shW{>>, t0&~m4, 4}
|
||||
{xv} => sel{V, t0, xv & m4} | sel{V, t4, shW{>>, xv, 4} & m4}
|
||||
}
|
||||
def fold_multi{f, init, ...ls} = fold{{v,a}=>f{v,...a}, init, flip{ls}}
|
||||
|
||||
def advance_spaced_mask{k, m, sh} = m<<(k-sh) | m>>sh
|
||||
|
||||
# General-case loop for odd replication factors
|
||||
def rep_const_bool_odd_mask4{
|
||||
M, # read/write type
|
||||
k, # replication factor
|
||||
get_modperm_x, # permuted input
|
||||
output, n, # output, number of writes
|
||||
mask:(u64), # starting mask
|
||||
mask_sh # single iteration shift
|
||||
} = {
|
||||
def ifvec{g} = match (M) { {[_](u64)} => g; {_} => ({v}=>v) }
|
||||
def scal = ifvec{{v} => M**v}
|
||||
|
||||
# Fundamental operation: shifts act as order-k cyclic group on masks
|
||||
def advance{m, sh} = advance_spaced_mask{k, m, sh}
|
||||
# Double a cumulative mask, shift combination
|
||||
# If s advances l iterations, mc combines iterations iota{l}
|
||||
def double_gen{comb}{{mc, s}} = {
|
||||
def mn = comb{mc, advance{mc,s}}
|
||||
def ss = s+s
|
||||
tup{mn, ss - (k &- (ss>k))}
|
||||
}
|
||||
# Mask and shift for one iteration
|
||||
{sm0, s1} := ifvec{double_gen{make{M,...}}}{tup{mask, mask_sh}}
|
||||
# Combined mask for 4 iterations, and shift to advance 4
|
||||
def double = double_gen{|}
|
||||
{mc4, s4} := double{double{tup{sm0, s1}}}
|
||||
# Submasks pick one mask out of a combination of 4
|
||||
def or_adv{m, s} = { m |= advance{m,s} }
|
||||
@for (min{k/4 - 1, n/4}) or_adv{sm0,s4}
|
||||
submasks := scan{advance, tup{sm0, ...3**s1}}
|
||||
mask_tail := advance{sm0, s4} &~ sm0
|
||||
|
||||
# Carry: shifting and word-crossing is done on the initial permuted x
|
||||
o:M = scal{0} # Carry for x
|
||||
mr := scal{u64~~1<<k - 1} # Mask out carry bit before output
|
||||
def sub_carry{a, c} = match (M) {
|
||||
{[l](u64)} => {
|
||||
ca := if (hasarch{'SSE4.2'}) { def S = [l]i64; S~~c > S**0 }
|
||||
else { def S = [2*l]i32; cm := S~~c; cm != shuf{S, cm, 4b2301} }
|
||||
a + M~~ca
|
||||
}
|
||||
{_} => a - promote{u64, c != 0}
|
||||
}
|
||||
|
||||
while (1) {
|
||||
x:M = get_modperm_x{}
|
||||
def vrot1 = ifvec{{x} => vshl{x, x, vcount{type{x}}-1}}
|
||||
k1:M = scal{1}
|
||||
os:=o; xo:=x<<k|vrot1{x>>(64-k)}; o=xo&k1; xo=(xo&~k1)|os
|
||||
# Write result word given starting bits
|
||||
def step{b, c} = output{sub_carry{c - b, c & mr}}
|
||||
def step{b, c, m} = step{b&m, c&m}
|
||||
# Fast unrolled iterations
|
||||
mask := mc4
|
||||
@for (k/4) {
|
||||
each{step{x & mask, xo & mask, .}, submasks}
|
||||
mask = advance{mask, s4}
|
||||
}
|
||||
# Single-step for tail
|
||||
mask = mask_tail
|
||||
@for (k%4) {
|
||||
step{x, xo, mask}
|
||||
mask = advance{mask, s1}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def rep_const_bool_odd{k, xp, rp, nw} = {
|
||||
# Every-k-bits mask
|
||||
m:u64 = spaced_mask_of{k}
|
||||
d := cast_i{usz, popc{m}} # == 64/k
|
||||
mask_sh := cast_i{usz, ctz{m}} # == 64%k
|
||||
mask := m<<(k-mask_sh) | 1
|
||||
|
||||
# Transform sending bit i to k*i % 64 by pairwise swaps
|
||||
# Swap data goes in a pre-computed table
|
||||
swtab:*u64 = each{modperm_dat{u64, .}, 1+2*iota{32}}
|
||||
def swap_lens = reverse{2 << iota{5}}
|
||||
swap_data := load{swtab, k>>1}
|
||||
swsel:u64 = ~u64~~0
|
||||
def gsw{l} = {
|
||||
swsel ^= swsel << l # Low l bits out of every 2*l
|
||||
sm := swap_data &~ swsel
|
||||
swap_data &= swsel; swap_data |= swap_data<<l
|
||||
sm
|
||||
}
|
||||
swap_masks := each{gsw, swap_lens}
|
||||
i:usz = 0
|
||||
# Load x, send bit i to position k*i % 64
|
||||
def swap_x = fold_multi{modperm_step, ., swap_lens, swap_masks}
|
||||
def get_swap_x{} = { x := swap_x{load{xp, i}}; ++i; x }
|
||||
|
||||
# Output
|
||||
j:usz = 0
|
||||
def output{rw} = {
|
||||
store{rp, j, rw}
|
||||
++j; if (j==nw) return{1}
|
||||
}
|
||||
o:u64 = 0 # carry
|
||||
# Dedicated loop for 3, shared for other factors
|
||||
if (k == 3) {
|
||||
def unrolled_iter{k} = {
|
||||
def reps{f, n, init} = scan{{x,_}=>f{x}, n**init}
|
||||
def masks = reps{advance_spaced_mask{k, ., 64%k}, k, mask}
|
||||
def step{x}{m, n_over} = {
|
||||
b := x & m
|
||||
def os = (64-k) + 1 + iota{n_over}
|
||||
output{fold{|, (b<<k) - b, each{>>{o,.}, os}}}
|
||||
o = b
|
||||
}
|
||||
while (1) each{step{get_swap_x{}}, masks, (-iota{k}*64)%k}
|
||||
}
|
||||
unrolled_iter{3}
|
||||
} else {
|
||||
rep_const_bool_odd_mask4{u64, k, get_swap_x, output, nw, mask, mask_sh}
|
||||
}
|
||||
}
|
||||
|
||||
# For odd numbers:
|
||||
# - permute each byte sending bit i to position k*i % 8
|
||||
# - replicate each byte by k, making position k*i contain bit i
|
||||
# - mask out those bits and spread over [ k*i, k*(i+1) )
|
||||
# - ...except where it crosses words; handle this overhang separately
|
||||
def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
def rep_const_bool_odd{wv, x, r, nw if hasarch{'SSSE3'}} = { # wv odd, wv<=15
|
||||
oper // ({a,b}=>floor{a/b}) infix left 40
|
||||
def vl = 16; def V = [vl]u8
|
||||
def iV = iota{vl}
|
||||
def mkV = make{V, .}; def selV = sel{V, ., .}
|
||||
def W = [2]u64
|
||||
def {output, flush} = get_boolvec_writer{V, r, rlen}
|
||||
def {output, flush} = get_boolvec_writer{V, r, nw}
|
||||
|
||||
# Swap data goes in a pre-computed table
|
||||
swtab:*V = each{modperm_dat{V, .}, 1+2*iota{32}}
|
||||
swap_data := load{swtab, wv>>1}
|
||||
# Within-byte transformation
|
||||
def get_ttab{k} = each{{is} => mkV{tr_iota{is}}, split{4, k*iota{8} % 8}}
|
||||
ttab:*V = join{each{get_ttab, 2*iota{4} + 1}}
|
||||
{t0, t4} := each{load{ttab + (wv & 6), .}, iota{2}}
|
||||
m4 := V**0xf
|
||||
def perm_x{xv} = {
|
||||
selV{t0, xv & m4} | selV{t4, V~~([8]u16~~xv>>4) & m4}
|
||||
}
|
||||
def perm_x = modperm_get_byteperm{selV{swap_data, V**0}}
|
||||
|
||||
# Cases are 3; 5 7; 9 11 13 15
|
||||
if (wv < 4) {
|
||||
# 3: dedicated loop
|
||||
i:usz = 0; while (1) {
|
||||
@ -629,47 +663,16 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||
o := xo & mkV{255 * (iV%8 == 0)} # overhang
|
||||
output{rv | o}
|
||||
}
|
||||
} else { # wv < 32
|
||||
# 9 to 31: extend k*i % 8 transform to k*i % 128 by pairwise swaps
|
||||
# Swap data goes in a pre-computed table
|
||||
def swdat{k} = {
|
||||
def bits = (k*iota{128} >> merge{0, replicate{1<<iota{7}, iota{7}}}) & 1
|
||||
mkV{each{base{2, .}, split{8, bits}}}
|
||||
}
|
||||
swtab:*V = each{swdat, 9+2*iota{12}}
|
||||
} else {
|
||||
# General case
|
||||
def swap_lens = reverse{1 << iota{4}}
|
||||
swap_data := load{swtab, (wv>>1)-4}
|
||||
swap_masks := each{{l} => selV{swap_data, mkV{l+iV%l}}, swap_lens}
|
||||
# Every-k-bits mask, same as before
|
||||
{m, d} := unaligned_spaced_mask_mod{wv}
|
||||
mask := make{W, m, m>>d|m<<(wv-d)}
|
||||
mask_sh := d+d; if (mask_sh >= wv) mask_sh-= wv
|
||||
# Mask out carry bit
|
||||
mr := [4]u32~~W**(u64~~1<<wv - 1)
|
||||
# State
|
||||
xv:V = V**0; o:=W**0
|
||||
swap_masks := each{{l} => selV{~swap_data, mkV{l+iV%l}}, swap_lens}
|
||||
def swap_x = fold_multi{modperm_step, ., 8*swap_lens, swap_masks}
|
||||
i:usz = 0
|
||||
while (1) {
|
||||
# Load xv, send bit i to position wv*i % 128
|
||||
xv = load{*V~~x, i}; ++i
|
||||
def swap_step{l, m} = {
|
||||
xv = (xv & m) | (selV{xv, mkV{iV + (l-2*(iV&l))}} &~ m)
|
||||
}
|
||||
each{swap_step, swap_lens, swap_masks}
|
||||
xv = perm_x{xv}
|
||||
xw := W~~xv
|
||||
def vrot1{x} = vshl{x, x, vcount{type{x}}-1}
|
||||
w1 := W**1
|
||||
os:=o; xo:=xw<<wv|vrot1{xw>>(64-wv)}; o=xo&w1; xo=(xo&~w1)|os
|
||||
# Write wv vectors based on that
|
||||
@for (wv) {
|
||||
b := xw & mask
|
||||
# Handle overhang here; won't fit in a single vector
|
||||
c := xo & mask; cu := [4]u32~~c
|
||||
output{V~~(W~~(cu + (mr&cu > [4]u32**0)) - b)}
|
||||
mask = (mask << (wv - mask_sh)) | (mask >> mask_sh)
|
||||
}
|
||||
}
|
||||
def get_swap_x{} = { xv := perm_x{swap_x{load{*V~~x, i}}}; ++i; xv }
|
||||
# Every-k-bits mask
|
||||
{m, d} := unaligned_spaced_mask_mod{wv}
|
||||
rep_const_bool_odd_mask4{W, wv, {}=>W~~get_swap_x{}, {v}=>output{V~~v}, cdiv{nw, vcount{W}}, m, d}
|
||||
}
|
||||
flush{}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user