Merge pull request #136 from mlochbaum/boolrep
Faster const/bool for 64<const≤256
This commit is contained in:
commit
e09fb53fb9
@ -45,6 +45,7 @@
|
|||||||
// COULD use pdep or similar to avoid overhead on small results
|
// COULD use pdep or similar to avoid overhead on small results
|
||||||
// Otherwise, factor into power of 2 times odd
|
// Otherwise, factor into power of 2 times odd
|
||||||
// COULD fuse 2×odd, since 2/odd/ has a larger intermediate
|
// COULD fuse 2×odd, since 2/odd/ has a larger intermediate
|
||||||
|
// 𝕨≤256, AVX2: Modular permutation with shift-based masks
|
||||||
// Other typed 𝕩 uses +`, or lots of Singeli
|
// Other typed 𝕩 uses +`, or lots of Singeli
|
||||||
// Fixed shuffles, factorization, partial shuffles, self-overlapping
|
// Fixed shuffles, factorization, partial shuffles, self-overlapping
|
||||||
// Otherwise, cell-by-cell copying
|
// Otherwise, cell-by-cell copying
|
||||||
@ -779,12 +780,29 @@ B slash_c2(B t, B w, B x) {
|
|||||||
if (xl == 0) {
|
if (xl == 0) {
|
||||||
u64* xp = bitany_ptr(x);
|
u64* xp = bitany_ptr(x);
|
||||||
u64* rp; r = m_bitarrv(&rp, s);
|
u64* rp; r = m_bitarrv(&rp, s);
|
||||||
#if SINGELI
|
#if SINGELI_AVX2
|
||||||
if (wv <= 64) si_constrep_bool(wv, xp, rp, s);
|
if (wv <= 256) si_constrep_bool(wv, xp, rp, s);
|
||||||
else
|
#elif SINGELI
|
||||||
|
if (wv <= 128) si_constrep_bool(wv, xp, rp, s);
|
||||||
|
#else
|
||||||
|
if (wv <= 64) { BOOL_REP_XOR_SCAN(wv) }
|
||||||
#endif
|
#endif
|
||||||
if (wv <= 256) { BOOL_REP_XOR_SCAN(wv) }
|
else {
|
||||||
else { BOOL_REP_OVER(wv, xlen) }
|
// Like BOOL_REP_OVER but predictable
|
||||||
|
u64 ri=0, c=0; usz j=0;
|
||||||
|
usz n=wv/64-1;
|
||||||
|
for (usz i = 0; i < xlen; i++) {
|
||||||
|
u64 v = -(1 & xp[i/64]>>(i%64));
|
||||||
|
u64 r0 = c ^ ((v^c) << (ri%64));
|
||||||
|
c = v;
|
||||||
|
ri += wv; usz e = ri/64;
|
||||||
|
rp[e-1] = v; // This allows the loop to be constant-length
|
||||||
|
rp[j] = r0;
|
||||||
|
for (usz k=0; k<n; k++) rp[j+1+k] = v;
|
||||||
|
j = e;
|
||||||
|
}
|
||||||
|
if (ri%64) rp[j] = c;
|
||||||
|
}
|
||||||
goto atmW_maybesh;
|
goto atmW_maybesh;
|
||||||
} else {
|
} else {
|
||||||
u8 xk = xl-3;
|
u8 xk = xl-3;
|
||||||
|
|||||||
@ -296,7 +296,7 @@ if_inline (hasarch{'AARCH64'}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = {
|
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = {
|
||||||
assert{wv >= 2}; assert{wv <= 64}
|
assert{wv >= 2}
|
||||||
nw := cdiv{rlen, 64}
|
nw := cdiv{rlen, 64}
|
||||||
if (wv&1 == 0) {
|
if (wv&1 == 0) {
|
||||||
p := ctz{wv | 8} # Power of two for second replicate
|
p := ctz{wv | 8} # Power of two for second replicate
|
||||||
@ -321,6 +321,23 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = {
|
|||||||
rep_const_bool{}(wf, x, t, tlen)
|
rep_const_bool{}(wf, x, t, tlen)
|
||||||
rep_const_bool{}(wq, t, r, rlen)
|
rep_const_bool{}(wq, t, r, rlen)
|
||||||
}
|
}
|
||||||
|
} else if (not hasarch{'AVX2'} and wv >= 64) {
|
||||||
|
# Identical to the non-Singeli case for large wv
|
||||||
|
# but we assume n==0 and drop the loop
|
||||||
|
assert{wv < 128}
|
||||||
|
i:usz = 0
|
||||||
|
ri:usz = 0; j:usz = 0 # Bit index in r; ri/64
|
||||||
|
c:u64 = 0
|
||||||
|
while (j < nw-1) { # e<=j+2 below, so e-1<nw
|
||||||
|
v := -(1 & load{x, i/64}>>(i%64)); ++i
|
||||||
|
r0:= c ^ ((v^c) << (ri%64))
|
||||||
|
c = v
|
||||||
|
ri+= wv; e:= ri/64
|
||||||
|
store{r, e-1, c}
|
||||||
|
store{r, j, r0}
|
||||||
|
j = e
|
||||||
|
}
|
||||||
|
if (ri%64 != 0) store{r, j, c}
|
||||||
} else {
|
} else {
|
||||||
rep_const_bool_odd{wv, x, r, nw}
|
rep_const_bool_odd{wv, x, r, nw}
|
||||||
}
|
}
|
||||||
@ -472,10 +489,8 @@ def modperm_dat{T, k} = {
|
|||||||
def w = width{T}
|
def w = width{T}
|
||||||
def i = iota{lb{w}}
|
def i = iota{lb{w}}
|
||||||
def bits = ~(1 & (k*iota{w} >> merge{0, replicate{1<<i, i}}))
|
def bits = ~(1 & (k*iota{w} >> merge{0, replicate{1<<i, i}}))
|
||||||
match (T) {
|
def E = match (T) { {[_]E} => E; {_} => T }
|
||||||
{[_]E} => make{T, each{base{2,.}, split{width{E}, bits}}}
|
each{{x} => E~~base{2,x}, split{width{E}, bits}}
|
||||||
{_} => T~~base{2, bits}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
# Permutation step evaluators
|
# Permutation step evaluators
|
||||||
# shift takes a top-half mask; others take both-halves
|
# shift takes a top-half mask; others take both-halves
|
||||||
@ -555,37 +570,41 @@ def proc_mod_dat{swap_data:W} = {
|
|||||||
else tup{8, get_byteperm{}}
|
else tup{8, get_byteperm{}}
|
||||||
}
|
}
|
||||||
# Fill in higher steps
|
# Fill in higher steps
|
||||||
def get_mod_permuter{} = {
|
def get_mod_permuter{width} = {
|
||||||
def get_swap{l} = {
|
def get_swap{l} = {
|
||||||
def mask = extract_modperm_mask{swap_data, swap_lane, l}
|
def mask = extract_modperm_mask{swap_data, swap_lane, l}
|
||||||
modperm_shuf_step{., l, mask}
|
modperm_shuf_step{., l, mask}
|
||||||
}
|
}
|
||||||
def swap_x = on_len_range{get_swap, partwidth, ww}
|
def swap_x = on_len_range{get_swap, partwidth, width}
|
||||||
{x} => partperm{swap_x{x}}
|
{x} => partperm{swap_x{x}}
|
||||||
}
|
}
|
||||||
|
def get_mod_permuter{} = get_mod_permuter{ww}
|
||||||
tup{partperm, get_mod_permuter}
|
tup{partperm, get_mod_permuter}
|
||||||
}
|
}
|
||||||
|
|
||||||
def rep_const_bool_odd{k, x, r, nw} = {
|
def MP = if (has_simd) [if (hasarch{'AVX2'}) 4 else 2]u64 else u64
|
||||||
def avx2 = hasarch{'AVX2'}
|
swtab:*u64 = join{each{modperm_dat{MP, .}, 1+2*iota{32}}}
|
||||||
def W = if (has_simd) [if (avx2) 4 else 2]u64 else u64
|
|
||||||
|
|
||||||
|
def rep_const_bool_odd{k, x, r, nw} = {
|
||||||
|
def W = MP
|
||||||
def {output, check_done, flush} = get_boolvec_writer{W, r, nw}
|
def {output, check_done, flush} = get_boolvec_writer{W, r, nw}
|
||||||
xp := *W~~x
|
xp := *W~~x
|
||||||
def getter{perm}{} = { check_done{}; xv := load{xp}; ++xp; perm{xv} }
|
def getter{perm}{} = { check_done{}; xv := load{xp}; ++xp; perm{xv} }
|
||||||
|
|
||||||
# Modular permutation: small-k cases may use a limited permutation
|
# Modular permutation: small-k cases may use a limited permutation
|
||||||
# on bytes or 32-bit ints; general case uses the whole thing
|
# on bytes or 32-bit ints; general case uses the whole thing
|
||||||
swtab:*W = each{modperm_dat{W, .}, 1+2*iota{32}}
|
swap_data := load{*W~~swtab, (k%64)>>1}
|
||||||
swap_data := load{swtab, k>>1}
|
|
||||||
def {partperm, get_full_permute} = proc_mod_dat{swap_data}
|
def {partperm, get_full_permute} = proc_mod_dat{swap_data}
|
||||||
|
|
||||||
def sp_max = if (any_sel) 8 else 4
|
def sp_max = if (any_sel) 8 else 4
|
||||||
if (k < sp_max) {
|
if (k < sp_max) {
|
||||||
rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output}
|
rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output}
|
||||||
} else {
|
} else if (not hasarch{'AVX2'} or k < 64) {
|
||||||
def get_swap_x = getter{get_full_permute{}}
|
def get_swap_x = getter{get_full_permute{}}
|
||||||
rep_const_bool_odd_mask4{W, k, get_swap_x, output, cdiv{nw, vcount{W}}}
|
rep_const_bool_odd_mask4{W, k, get_swap_x, output, cdiv{nw, vcount{W}}}
|
||||||
|
} else {
|
||||||
|
def get_swap_x = getter{get_full_permute{64}}
|
||||||
|
rep_const_bool_odd_loose_mask{W, k, get_swap_x, output}
|
||||||
}
|
}
|
||||||
flush{}
|
flush{}
|
||||||
}
|
}
|
||||||
@ -599,6 +618,7 @@ def rep_const_bool_odd_mask4{
|
|||||||
} = {
|
} = {
|
||||||
def ifvec{g} = match (M) { {[_](u64)} => g; {_} => ({v}=>v) }
|
def ifvec{g} = match (M) { {[_](u64)} => g; {_} => ({v}=>v) }
|
||||||
def scal = ifvec{{v} => M**v}
|
def scal = ifvec{{v} => M**v}
|
||||||
|
assert{k < 64}
|
||||||
|
|
||||||
# Fundamental operation: shifts act as order-k cyclic group on masks
|
# Fundamental operation: shifts act as order-k cyclic group on masks
|
||||||
def advance{m, sh} = advance_spaced_mask{k, m, sh}
|
def advance{m, sh} = advance_spaced_mask{k, m, sh}
|
||||||
@ -729,4 +749,55 @@ def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Odd factors larger than 64
|
||||||
|
# AVX2-only because scalar should be about as good otherwise
|
||||||
|
def rep_const_bool_odd_loose_mask{V=[vl==4](u64), k, get_modperm_x, output if hasarch{'AVX2'}} = {
|
||||||
|
assert{k > 64}
|
||||||
|
# Distance from end to previous row boundary (-k <= q < 0)
|
||||||
|
q := -make{V, 64*(1+iota{vl})}
|
||||||
|
def q_mod{} = { d:=q+V**k; q = blend_top{q,d, d} }
|
||||||
|
o:u64 = width{V}; while (o>k) { o-=k; q_mod{} }
|
||||||
|
|
||||||
|
km:= k%64
|
||||||
|
i:usz = 0; iv:usz = 0 # Words and vectors completed
|
||||||
|
def step{perm, diff} = {
|
||||||
|
# Mod-64 mask with 1 bit per word
|
||||||
|
m:= V**1 << (q & V**63)
|
||||||
|
# Indicator of which bits are actual boundaries
|
||||||
|
def S = ty_s{V}; a:= S**(-65) < S~~q
|
||||||
|
q-= V**o; q_mod{}
|
||||||
|
# Set to bit from perm, but below a&m xor with diff to get previous
|
||||||
|
base:= (m & perm) == m
|
||||||
|
md:= (a & m) & diff
|
||||||
|
output{base ^ (md + (md==m))}
|
||||||
|
}
|
||||||
|
{xp, xd, perm, diff}:= 4**(V**0)
|
||||||
|
while (1) {
|
||||||
|
# get_modperm_x permutes each 64-bit word
|
||||||
|
# Each iteration of this loop handles one permuted word
|
||||||
|
def first = shuf{., 4**0}
|
||||||
|
if (i%4 == 0) {
|
||||||
|
xp = get_modperm_x{}
|
||||||
|
# Shift by 1, or k%64 in mod-space
|
||||||
|
# Then the low bit of each word has to be moved to the next
|
||||||
|
# As before, first bit is wrong but unused
|
||||||
|
xl:= xp>>(64-km)
|
||||||
|
xo:= (xp<<km | (xl &~ V**1)) | (shuf{xl, 3,0,1,2} & V**1)
|
||||||
|
xd = xo ^ xp
|
||||||
|
perm = first{xp}; diff = first{xd}
|
||||||
|
} else {
|
||||||
|
def upd{xq, q} = {
|
||||||
|
xq = shuf{xq, 1,2,3,0} # Next word
|
||||||
|
qs:= q; q = first{xq}
|
||||||
|
blend_hom{q, qs, iota{V} < V**(i%4)}
|
||||||
|
}
|
||||||
|
step{upd{xp, perm}, upd{xd, diff}} # Do boundary between iterations
|
||||||
|
++iv
|
||||||
|
}
|
||||||
|
i+= k
|
||||||
|
ip:= iv; iv = i/4
|
||||||
|
@for (iv - ip) step{perm, diff}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export{'si_constrep_bool', rep_const_bool{}}
|
export{'si_constrep_bool', rep_const_bool{}}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user