Merge pull request #116 from mlochbaum/repbit
Fast constant replicate boolean
This commit is contained in:
commit
34ac49fd21
@ -38,8 +38,13 @@
|
||||
// COULD consolidate refcount updates for nested 𝕩
|
||||
|
||||
// Replicate by constant
|
||||
// Boolean uses pdep, ≠`, or overwriting
|
||||
// SHOULD make a shift/mask replacement for pdep
|
||||
// Boolean uses specialized small-𝕨 methods, ≠`, or overwriting
|
||||
// 𝕨≤64: Singeli generic and SIMD methods
|
||||
// 𝕨=2,4,8: Various shift, shuffle, and zip-based loops
|
||||
// odd 𝕨: Modular permutation
|
||||
// COULD use pdep or similar to avoid overhead on small results
|
||||
// Otherwise, factor into power of 2 times odd
|
||||
// COULD fuse 2×odd, since 2/odd/ has a larger intermediate
|
||||
// Other typed 𝕩 uses +`, or lots of Singeli
|
||||
// Fixed shuffles, factorization, partial shuffles, self-overlapping
|
||||
// Otherwise, cell-by-cell copying
|
||||
@ -82,6 +87,8 @@
|
||||
extern void (*const si_scan_max_i32)(int32_t* v0,int32_t* v1,uint64_t v2);
|
||||
#define SINGELI_FILE slash
|
||||
#include "../utils/includeSingeli.h"
|
||||
extern uint64_t* const si_spaced_masks;
|
||||
#define get_spaced_mask(i) si_spaced_masks[i-1]
|
||||
#define SINGELI_FILE replicate
|
||||
#include "../utils/includeSingeli.h"
|
||||
#endif
|
||||
@ -765,38 +772,9 @@ B slash_c2(B t, B w, B x) {
|
||||
if (xl == 0) {
|
||||
u64* xp = bitarr_ptr(x);
|
||||
u64* rp; r = m_bitarrv(&rp, s);
|
||||
#if FAST_PDEP
|
||||
if (wv <= 52) {
|
||||
#if SINGELI
|
||||
u64 m = si_spaced_masks[wv-1];
|
||||
#else
|
||||
u64 m = (u64)-1 / (((u64)1<<wv)-1);
|
||||
#endif
|
||||
u64 xw = 0;
|
||||
usz d = POPC(m); // == 64/wv
|
||||
if (m & 1) { // Power of two
|
||||
for (usz i=-1, j=0; j<BIT_N(s); j++) {
|
||||
xw >>= d;
|
||||
if ((j&(wv-1))==0) xw = xp[++i];
|
||||
u64 rw = _pdep_u64(xw, m);
|
||||
rp[j] = (rw<<wv)-rw;
|
||||
}
|
||||
} else {
|
||||
usz q = CTZ(m); // == 64%wv
|
||||
m = m<<(wv-q) | 1;
|
||||
u64 mt = (u64)1<<(d+1); // Bit d+1 may be needed, isn't pdep-ed
|
||||
usz tsh = d*wv-(d+1);
|
||||
for (usz xi=0, o=0, j=0; j<BIT_N(s); j++) {
|
||||
xw = loadu_u64((u64*)((u8*)xp+xi/8)) >> (xi%8);
|
||||
u64 ex = (xw&mt)<<tsh;
|
||||
u64 rw = _pdep_u64(xw, m);
|
||||
rp[j] = ((rw-ex)<<(wv-o))-(rw>>o|(xw&1));
|
||||
o += q;
|
||||
bool oo = o>=wv; xi+=d+oo; o-=wv&-oo;
|
||||
}
|
||||
}
|
||||
goto atmW_maybesh;
|
||||
}
|
||||
#if SINGELI
|
||||
if (wv <= 64) si_constrep_bool(wv, xp, rp, s);
|
||||
else
|
||||
#endif
|
||||
if (wv <= 256) { BOOL_REP_XOR_SCAN(wv) }
|
||||
else { BOOL_REP_OVER(wv, xlen) }
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
include './base'
|
||||
include './mask'
|
||||
include './spaced'
|
||||
|
||||
def ind_types = tup{i8, i16, i32}
|
||||
def dat_types = tup{...ind_types, u64}
|
||||
@ -58,7 +59,10 @@ exportT{'si_replicate_scan', flat_table{rep_by_scan, ind_types, dat_types}}
|
||||
|
||||
|
||||
# Constant replicate
|
||||
if_inline (not (hasarch{'AVX2'} or hasarch{'AARCH64'})) {
|
||||
def incl{a,b} = slice{iota{b+1},a}
|
||||
def basic_rep = incl{2, 7}
|
||||
|
||||
if_inline (not (hasarch{'SSSE3'} or hasarch{'AARCH64'})) {
|
||||
|
||||
fn rep_const{T}(wv:u64, x:*void, r:*void, n:u64) : void = {
|
||||
rep_by_scan{T, cast_i{usz,wv}, x, r, cast_i{usz, wv*n}}
|
||||
@ -66,21 +70,18 @@ fn rep_const{T}(wv:u64, x:*void, r:*void, n:u64) : void = {
|
||||
|
||||
} else {
|
||||
|
||||
def incl{a,b} = slice{iota{b+1},a}
|
||||
def has_bytesel_128 = not hasarch{'AVX2'}
|
||||
|
||||
# 1+˝∨`⌾⌽0=div|⌜range
|
||||
def makefact{divisor, range} = {
|
||||
def t = table{{a,b}=>0==b%a, divisor, range}
|
||||
fold{+, 1, reverse{scan{|, reverse{t}}}}
|
||||
}
|
||||
def basic_rep = incl{2, 7}
|
||||
def fact_size = 128
|
||||
def fact_inds = slice{iota{fact_size},8}
|
||||
def fact_tab = makefact{basic_rep, fact_inds}
|
||||
factors:*u8 = fact_tab
|
||||
|
||||
|
||||
|
||||
def sdtype = [arch_defvw/8]i8 # shuf data type
|
||||
def get_shufs{step, wv, len} = {
|
||||
def i = iota{len*step}
|
||||
@ -90,7 +91,7 @@ def get_shuf_data{wv, len} = get_shufs{vcount{sdtype}, wv, len} # [len] byte-sel
|
||||
def get_shuf_data{wv} = get_shuf_data{wv, wv}
|
||||
|
||||
# all shuffle vectors for 𝕨≤7
|
||||
def special_2 = ~hasarch{'AARCH64'} # handle 2 specially on x86-64
|
||||
def special_2 = not has_bytesel_128 # handle 2 specially on AVX2
|
||||
def rcsh_vals = slice{basic_rep, special_2}
|
||||
rcsh_offs:*u8 = shiftright{0, scan{+,rcsh_vals}}
|
||||
rcsh_data:*i8 = join{join{each{get_shuf_data, rcsh_vals}}}
|
||||
@ -106,7 +107,7 @@ def read_shuf_vecs{l, ellw:(u64), shp:*V} = { # tuple of byte selectors in 1<<el
|
||||
r:=each{bind{~~,[32]i8},mzip128{s, s + X**1}}
|
||||
r
|
||||
}
|
||||
def double{x:X if hasarch{'AARCH64'}} = {
|
||||
def double{x:X if has_bytesel_128} = {
|
||||
s:= x+x
|
||||
zip{s, s + X**1}
|
||||
}
|
||||
@ -167,7 +168,7 @@ if_inline (hasarch{'AVX2'}) {
|
||||
|
||||
def rep_const_shuffle{wv, xv:*V, rv:*V, n:(u64)} = rep_const_shuffle{wv, get_rep_iter{V, wv}, xv, rv, n}
|
||||
|
||||
} else if_inline (hasarch{'AARCH64'}) {
|
||||
} else { # has_bytesel_128
|
||||
|
||||
def rep_iter_from_sh{sh}{x, gen} = {
|
||||
each{{s} => gen{sel{[16]u8, x, s}}, sh}
|
||||
@ -191,7 +192,7 @@ fn rep_const_shuffle_partial4(wv:u64, ellw:u64, x:*i8, r:*i8, n:u64) : void = {
|
||||
def wvb = wv << ellw
|
||||
def hs = (h*step) / wvb # Actual step size in argument elements
|
||||
def shufbase{i if hasarch{'AVX2'}} = shuf{[4]u64, load{*V~~(x+i)}, 4b1010}
|
||||
def shufbase{i if hasarch{'AARCH64'}} = load{*V~~(x+i)}
|
||||
def shufbase{i if has_bytesel_128} = load{*V~~(x+i)}
|
||||
def shufrun{a, s} = sel{[16]i8, a, s} # happens to be the same across AVX2 & NEON
|
||||
|
||||
i:u64 = 0
|
||||
@ -284,3 +285,461 @@ fn rep_const{T}(wv:u64, x:*void, r:*void, n:u64) : void = {
|
||||
}
|
||||
|
||||
exportT{'si_constrep', each{rep_const, dat_types}}
|
||||
|
||||
|
||||
|
||||
# Constant replicate on boolean
|
||||
def any_sel = hasarch{'SSSE3'} or hasarch{'AARCH64'}
|
||||
if_inline (hasarch{'AARCH64'}) {
|
||||
def __shl{a:V=[_]T, b:U if not isvec{U}} = a << V**cast_i{T,b}
|
||||
def __shr{a:V=[_]T, b:U if not isvec{U}} = a << V**cast_i{T,-b}
|
||||
}
|
||||
|
||||
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = {
|
||||
assert{wv >= 2}; assert{wv <= 64}
|
||||
nw := cdiv{rlen, 64}
|
||||
if (wv&1 == 0) {
|
||||
p := ctz{wv | 8} # Power of two for second replicate
|
||||
wf := wv>>p
|
||||
if (wf == 1) {
|
||||
rep_const_bool_div8{wv, x, r, nw}
|
||||
} else if (any_sel and p == 3 and wv <= 8*select{basic_rep, -1}) {
|
||||
# (higher wv double-factors, which doesn't work with in-place pointers)
|
||||
tlen := rlen / wf
|
||||
t := r + cdiv{rlen, 64} - cdiv{tlen, 64}
|
||||
rep_const_bool{}(8, x, t, tlen)
|
||||
rep_const{select{dat_types,0}}(promote{u64,wf}, *void~~t, *void~~r, promote{u64,tlen/8})
|
||||
} else {
|
||||
tlen := rlen >> p
|
||||
wq := usz~~1 << p
|
||||
if (p == 1 and (not any_sel or wv>=24)) {
|
||||
# Expanding odd second is faster
|
||||
tlen = rlen / wf
|
||||
t:=wf; wf=wq; wq=t
|
||||
}
|
||||
t := r + cdiv{rlen, 64} - cdiv{tlen, 64}
|
||||
rep_const_bool{}(wf, x, t, tlen)
|
||||
rep_const_bool{}(wq, t, r, rlen)
|
||||
}
|
||||
} else {
|
||||
rep_const_bool_odd{wv, x, r, nw}
|
||||
}
|
||||
}
|
||||
|
||||
def rep_const_bool_div8{wv, x, r, nw} = { # wv in 2,4,8
|
||||
def run{k} = {
|
||||
# 2 -> 64w0x33, 12 -> 64w0x000f, etc.
|
||||
def getm{sh} = base{2, iota{64}&sh == 0}
|
||||
def osh{v, s} = v | v<<s
|
||||
def expand = match (k) {
|
||||
{2} => fold{
|
||||
{v, sh} => osh{v, sh} & getm{sh},
|
||||
., 1 << reverse{iota{5}}
|
||||
}
|
||||
{4} => fold{
|
||||
{v, sh} => osh{osh{v, sh}, 2*sh} & getm{sh},
|
||||
., tup{12, 3}
|
||||
}
|
||||
{8} => {
|
||||
def mult = base{1<<7, 8**1}
|
||||
{x} => (x | ((x&~1) * mult)) & 64w0x01
|
||||
}
|
||||
}
|
||||
@for (xt in *ty_u{64/k}~~x, r over nw) {
|
||||
def v = expand{promote{u64, xt}}
|
||||
r = v<<k - v
|
||||
}
|
||||
}
|
||||
def cases{k} = if (wv==k) run{k} else if (k<8) cases{2*k}
|
||||
cases{2}
|
||||
}
|
||||
|
||||
def get_boolvec_writer{V, r, nw} = {
|
||||
def vwords = width{V}/64
|
||||
rv := *V~~r
|
||||
re := rv + nw / vwords
|
||||
rs := rv + cdiv{nw, vwords}
|
||||
# Avoid reading/processing an extra word past the end of input
|
||||
def done = makelabel{}
|
||||
def check_done{} = if (rv == rs) goto{done}
|
||||
# If the last result is partial, jump to flush to write it
|
||||
last_res:V = V**0
|
||||
def l_flush = makelabel{}
|
||||
def output{v:(V)} = {
|
||||
last_res = v
|
||||
if (rv == re) goto{l_flush}
|
||||
store{rv, 0, v}; ++rv
|
||||
}
|
||||
def flush{} = {
|
||||
setlabel{l_flush}
|
||||
q := nw & (vwords-1)
|
||||
if (q != 0) homMaskStoreF{rv, V~~maskOf{re_el{u64,V}, q}, last_res}
|
||||
setlabel{done}
|
||||
}
|
||||
tup{output, check_done, flush}
|
||||
}
|
||||
def get_boolvec_writer{T=(u64), r:*T, nw} = {
|
||||
def done = makelabel{}
|
||||
j:usz = 0
|
||||
def output{rw} = {
|
||||
store{r, j, rw}
|
||||
++j; if (j==nw) goto{done}
|
||||
}
|
||||
def flush{} = setlabel{done}
|
||||
tup{output, {}=>{}, flush}
|
||||
}
|
||||
|
||||
def rep_const_bool_div8{wv, x, r, nw if has_simd} = {
|
||||
def avx2 = hasarch{'AVX2'}
|
||||
def vl = if (avx2) 32 else 16
|
||||
def V = [vl]u8
|
||||
def iV = iota{vl}
|
||||
def mkV = make{V, .}
|
||||
# Not the same as a 1-byte shift but extra bits are always masked away
|
||||
def H = re_el{u16, V}
|
||||
def __shl{x:(V), a} = V~~(H~~x << a)
|
||||
def __shr{x:(V), a} = V~~(H~~x >> a)
|
||||
|
||||
def {output, check_done, flush} = get_boolvec_writer{V, r, nw}
|
||||
def run24{x, get_halves} = {
|
||||
i:usz = 0; while (1) { check_done{}
|
||||
xv := load{*V~~(x+i)}; ++i
|
||||
def getr = zip128{...get_halves{xv}, .}
|
||||
output{V~~getr{0}}
|
||||
output{V~~getr{1}}
|
||||
}
|
||||
}
|
||||
def run24{x, pre, exh, sh} = run24{x,
|
||||
{xv} => { p := pre{xv}; tup{exh{p}, exh{p>>sh}} }
|
||||
}
|
||||
def run8{rep_bytes} = {
|
||||
i:usz = 0; while (1) { check_done{}
|
||||
xh := load{*[16]u8~~(*ty_u{vl}~~x + i)}; ++i
|
||||
xv := if (avx2) pair{xh, xh} else xh
|
||||
xe := rep_bytes{xv}
|
||||
output{(xe & mkV{1 << (iV % 8)}) > V**0}
|
||||
}
|
||||
}
|
||||
if (not any_sel) {
|
||||
if (wv == 2) {
|
||||
run24{*V~~x, {xv} => {
|
||||
def swap{x, l, m} = { def d = (x ^ x>>l) & V**m; x ^ (d|d<<l) }
|
||||
a := swap{swap{xv, 2, 0x0c}, 1, 0x22}
|
||||
m := V**0x55
|
||||
h0 := a & m; o0 := h0 | h0<<1
|
||||
h1 := a &~ m; o1 := h1 | h1>>1
|
||||
tup{o0, o1}
|
||||
}}
|
||||
} else if (wv == 4) {
|
||||
def pre{xv} = {
|
||||
a := xv ^ ((xv & V**0x55) << 1)
|
||||
e := zip128{a, a>>4, 0}
|
||||
}
|
||||
def out{h} = (-(h & V**1)) ^ (-(h<<3 & V**0x10))
|
||||
run24{*u64~~x, pre, out, 2}
|
||||
} else { # wv == 8
|
||||
def z{x} = zip{x,x,0}
|
||||
run8{{x} => z{z{z{x}}}}
|
||||
}
|
||||
} else { # any_sel
|
||||
def selH = sel{[16]u8, ., .}
|
||||
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
|
||||
def id{xv} = xv
|
||||
if (wv == 2) {
|
||||
def init = if (avx2) shuf{[4]u64, ., 4b3120} else id
|
||||
# Expander for half byte
|
||||
def tabr = makeTab{tr_iota{2*iota{4}} * 2b11}
|
||||
m4 := V**0xf
|
||||
run24{*V~~x, init, {x} => tabr{x & m4}, 4}
|
||||
} else if (wv == 4) {
|
||||
# Unzip 32-bit elements (result lanes) across AVX2 lanes
|
||||
def pre = if (avx2) sel{[8]u32, ., make{[8]u32,tr_iota{1,2,0}}} else id
|
||||
def init{xv} = { u:=pre{xv}; zip128{u,u,0} }
|
||||
# Expander for two bits in either bottom or next-to-bottom position
|
||||
def tabr = makeTab{tr_iota{0,4,0,4} * 2b1111}
|
||||
m2 := mkV{2b11 << (2*(iV%2))}
|
||||
def exh{x} = re_el{u16, V}~~tabr{x & m2}
|
||||
run24{*(if (avx2) [2]u64 else u64)~~x, init, exh, 4}
|
||||
} else { # wv == 8
|
||||
run8{selH{., mkV{iV // 8}}}
|
||||
}
|
||||
}
|
||||
flush{}
|
||||
}
|
||||
|
||||
def sel_imm{V=[4]T, x, {...inds} if hasarch{'SSE2'} and (T==u32 or T==u64)} = {
|
||||
assert{length{inds} == 4}
|
||||
shuf{V, x, base{4, inds}}
|
||||
}
|
||||
def sel_imm{V=([16]u8), x:X, {...inds}} = {
|
||||
def I = re_el{u8,X}; def [n]_ = I
|
||||
def l = length{inds}
|
||||
assert{l==16 or l==n}
|
||||
sel{V, x, make{I, cycle{n, inds}}}
|
||||
}
|
||||
|
||||
# Data for the permutation that sends bit i to k*i % width{T}
|
||||
def modperm_dat{T, k} = {
|
||||
def w = width{T}
|
||||
def i = iota{lb{w}}
|
||||
def bits = ~(1 & (k*iota{w} >> merge{0, replicate{1<<i, i}}))
|
||||
match (T) {
|
||||
{[_]E} => make{T, each{base{2,.}, split{width{E}, bits}}}
|
||||
{_} => T~~base{2, bits}
|
||||
}
|
||||
}
|
||||
# Permutation step evaluators
|
||||
# shift takes a top-half mask; others take both-halves
|
||||
def modperm_shift_step{x, l, m} = {
|
||||
def d = (x ^ x<<l) & m
|
||||
x ^ (d | d>>l)
|
||||
}
|
||||
def modperm_rot_step{x:T, l=(width{T}/2), m} = {
|
||||
(x &~ m) | ((x<<l | x>>l) & m)
|
||||
}
|
||||
def modperm_shuf_step{x:V=[_]T, l, m if l%8==0} = {
|
||||
(x &~ m) | (swap_elts{x, l/8} & m)
|
||||
}
|
||||
# Reverse each pair of elements
|
||||
def swap_elts{x:V=[_]_, el_bytes} = {
|
||||
def selx{n, wd} = {
|
||||
def l = el_bytes/wd
|
||||
def i = iota{n}
|
||||
def rev = i + (l - 2*(i&l))
|
||||
sel_imm{[n]ty_u{wd*8}, x, rev}
|
||||
}
|
||||
if (hasarch{'SSE2'} and el_bytes >= 4) {
|
||||
selx{4, max{4, el_bytes/2}}
|
||||
} else if (any_sel) {
|
||||
selx{16, 1}
|
||||
} else {
|
||||
def sh = el_bytes*8
|
||||
xe := re_el{ty_u{2*sh}, V} ~~ x
|
||||
V~~(xe<<sh | xe>>sh)
|
||||
}
|
||||
}
|
||||
# Extract swap functions from modperm_dat
|
||||
def extract_modperm_mask{full:W, lane, l} = {
|
||||
def f = l == 128 # Use full width
|
||||
def w = if (l<32 or not hasarch{'SSE2'}) 8
|
||||
else if (f) 64 else 32 # Shuffle element width
|
||||
def n = (if (f) 256 else 128) / w # and number of elements
|
||||
def o = l / w # Extract [o,2*o)
|
||||
def i = o + iota{n}%o
|
||||
sel_imm{[n]ty_u{w}, if (f) full else lane, i}
|
||||
}
|
||||
def proc_mod_dat{swap_data:W} = {
|
||||
def ww = width{W}; def avx2 = ww==256
|
||||
def on_len_range{get_swap, lo, hi} = {
|
||||
def lens = reverse{1 << slice{iota{lb{hi}}, lb{lo}}}
|
||||
fold{{x,s}=>s{x}, ., each{get_swap, lens}}
|
||||
}
|
||||
swap_lane := if (avx2) shuf{W, swap_data, 4b1010} else swap_data
|
||||
# Transform width-w units with shifts only
|
||||
def get_shiftperm{data, w} = {
|
||||
dat := data
|
||||
bot:W = W**base{2, cycle{64, replicate{w, tup{1,0}}}}
|
||||
def gsw{l} = {
|
||||
bot ^= bot << l # Low l bits out of every 2*l
|
||||
sm := dat &~ bot
|
||||
dat &= bot; dat |= dat<<l
|
||||
if (l<32) modperm_shift_step{., l, sm}
|
||||
else modperm_rot_step {., l, sm|sm>>l}
|
||||
}
|
||||
on_len_range{gsw, 2, w}
|
||||
}
|
||||
# Within-byte transformation
|
||||
def get_byteperm{} = {
|
||||
def V = re_el{u8, W}
|
||||
def sw_bytes = sel{[16]u8, swap_lane, V**0}
|
||||
m4 := W~~V**0xf
|
||||
t0 := fold{{v,a}=>modperm_shift_step{v,...a}, W~~make{V,iota{vcount{V}}%16}, tup{
|
||||
tup{4, sw_bytes &~ m4},
|
||||
tup{2, ({v} => v | v<<4){sw_bytes & W~~V**0xc}}
|
||||
}}
|
||||
t4 := (t0&m4)<<4 | (t0&~m4)>>4
|
||||
def selI{v,i} = sel{[16]u8, v, V~~i}
|
||||
{xv} => selI{t0, xv & m4} | selI{t4, (xv>>4) & m4}
|
||||
}
|
||||
def {partwidth, partperm} = {
|
||||
def sh{w, d} = tup{w, get_shiftperm{d, w}}
|
||||
if (ww==64) sh{64, swap_data}
|
||||
else if (not any_sel) sh{32, shuf{[4]u32, swap_data, 4b0000}}
|
||||
else tup{8, get_byteperm{}}
|
||||
}
|
||||
# Fill in higher steps
|
||||
def get_mod_permuter{} = {
|
||||
def get_swap{l} = {
|
||||
def mask = extract_modperm_mask{swap_data, swap_lane, l}
|
||||
modperm_shuf_step{., l, mask}
|
||||
}
|
||||
def swap_x = on_len_range{get_swap, partwidth, ww}
|
||||
{x} => partperm{swap_x{x}}
|
||||
}
|
||||
tup{partperm, get_mod_permuter}
|
||||
}
|
||||
|
||||
def rep_const_bool_odd{k, x, r, nw} = {
|
||||
def avx2 = hasarch{'AVX2'}
|
||||
def W = if (has_simd) [if (avx2) 4 else 2]u64 else u64
|
||||
|
||||
def {output, check_done, flush} = get_boolvec_writer{W, r, nw}
|
||||
xp := *W~~x
|
||||
def getter{perm}{} = { check_done{}; xv := load{xp}; ++xp; perm{xv} }
|
||||
|
||||
# Modular permutation: small-k cases may use a limited permutation
|
||||
# on bytes or 32-bit ints; general case uses the whole thing
|
||||
swtab:*W = each{modperm_dat{W, .}, 1+2*iota{32}}
|
||||
swap_data := load{swtab, k>>1}
|
||||
def {partperm, get_full_permute} = proc_mod_dat{swap_data}
|
||||
|
||||
def sp_max = if (any_sel) 8 else 4
|
||||
if (k < sp_max) {
|
||||
rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output}
|
||||
} else {
|
||||
def get_swap_x = getter{get_full_permute{}}
|
||||
rep_const_bool_odd_mask4{W, k, get_swap_x, output, cdiv{nw, vcount{W}}}
|
||||
}
|
||||
flush{}
|
||||
}
|
||||
|
||||
# General-case loop for odd replication factors
|
||||
def rep_const_bool_odd_mask4{
|
||||
M, # read/write type
|
||||
k, # replication factor
|
||||
get_modperm_x, # permuted input
|
||||
output, n, # output, number of writes
|
||||
} = {
|
||||
def ifvec{g} = match (M) { {[_](u64)} => g; {_} => ({v}=>v) }
|
||||
def scal = ifvec{{v} => M**v}
|
||||
|
||||
# Fundamental operation: shifts act as order-k cyclic group on masks
|
||||
def advance{m, sh} = advance_spaced_mask{k, m, sh}
|
||||
# Double a cumulative mask, shift combination
|
||||
# If s advances l iterations, mc combines iterations iota{l}
|
||||
def double_gen{comb}{{mc, s}} = {
|
||||
def mn = comb{mc, advance{mc,s}}
|
||||
def ss = s+s
|
||||
tup{mn, ss - (k &- (ss>k))}
|
||||
}
|
||||
# Starting word mask and single-word shift
|
||||
{mask, mask_sh} := unaligned_spaced_mask_mod{k}
|
||||
# Mask and shift for one iteration
|
||||
def fillmask{T} = match (T) {
|
||||
{(u64)} => tup{mask, mask_sh}
|
||||
{[2]E} => double_gen{make{T,...}}{fillmask{E}}
|
||||
{[4]E} => double_gen{pair}{fillmask{[2]E}}
|
||||
}
|
||||
{sm0, s1} := fillmask{M}
|
||||
# Combined mask for 4 iterations, and shift to advance 4
|
||||
def double = double_gen{|}
|
||||
{mc4, s4} := double{double{tup{sm0, s1}}}
|
||||
# Submasks pick one mask out of a combination of 4
|
||||
def or_adv{m, s} = { m |= advance{m,s} }
|
||||
@for (min{k/4 - 1, n/4}) or_adv{sm0,s4}
|
||||
submasks := scan{advance, tup{sm0, ...3**s1}}
|
||||
mask_tail := advance{sm0, s4} &~ sm0
|
||||
|
||||
# Carry: shifting and word-crossing is done on the initial permuted x
|
||||
# No need to carry across input words since they align with output words
|
||||
# First bit of each word in xo below is wrong, but it doesn't matter!
|
||||
mr := scal{u64~~1<<k - 1} # Mask out carry bit before output
|
||||
def sub_carry{a, c} = match (M) {
|
||||
{[l](u64)} => {
|
||||
ca := if (hasarch{'SSE4.2'} or hasarch{'AARCH64'}) { def S = [l]i64; S~~c > S**0 }
|
||||
else { def S = [2*l]i32; cm := S~~c; cm != shuf{S, cm, 4b2301} }
|
||||
a + M~~ca
|
||||
}
|
||||
{_} => a - promote{u64, c != 0}
|
||||
}
|
||||
|
||||
while (1) {
|
||||
x:M = get_modperm_x{}
|
||||
def vrot1 = ifvec{{x} => {
|
||||
if (w256{M}) shuf{M, x, 4b2103}
|
||||
else if (any_sel) vshl{x, x, vcount{type{x}}-1}
|
||||
else shuf{[4]u32, x, 4b1032}
|
||||
}}
|
||||
xo := x<<k | vrot1{x>>(64-k)}
|
||||
# Write result word given starting bits
|
||||
def step{b, c} = output{sub_carry{c - b, c & mr}}
|
||||
def step{b, c, m} = step{b&m, c&m}
|
||||
# Fast unrolled iterations
|
||||
mask := mc4
|
||||
@for (k/4) {
|
||||
each{step{x & mask, xo & mask, .}, submasks}
|
||||
mask = advance{mask, s4}
|
||||
}
|
||||
# Single-step for tail
|
||||
mask = mask_tail
|
||||
@for (k%4) {
|
||||
step{x, xo, mask}
|
||||
mask = advance{mask, s1}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def rep_const_bool_small_odd{(u64), 4, wv, get_swap_x, output} = {
|
||||
def k = 3
|
||||
def step{x}{o, n_over} = {
|
||||
b := x & base{2, cycle{64, n_over == iota{k}}}
|
||||
def os = (64-k) + 1 + iota{n_over}
|
||||
output{fold{|, (b<<k) - b, each{>>{o,.}, os}}}
|
||||
b
|
||||
}
|
||||
while (1) fold{step{get_swap_x{}}, 0, (-iota{k}*64)%k}
|
||||
}
|
||||
|
||||
# For small odd numbers:
|
||||
# - permute each byte sending bit i to position k*i % 8
|
||||
# - replicate each byte by k, making position k*i contain bit i
|
||||
# - mask out those bits and spread over [ k*i, k*(i+1) )
|
||||
# - ...except where it crosses words; handle this overhang separately
|
||||
# If 1-byte shuffle isn't available, use 4-byte units instead
|
||||
def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = {
|
||||
def ov_bytes{o} = { def V = re_el{u8, W}; v := V~~o; v += v > V**0; W~~v }
|
||||
def ww = width{W}
|
||||
def ew = if (any_sel) 8 else 32 # width of shuffle-able elements
|
||||
def ne = ww/ew; def se = 64/ew
|
||||
def lanes = ww > 128
|
||||
def dup{v} = if (lanes) merge{v,v} else v
|
||||
def fixed_loop{k} = {
|
||||
assert{wv == k}
|
||||
while (1) {
|
||||
# e.g. 01234567 to 05316427 on each byte for k==3, ew==8
|
||||
xv := get_perm_x{}
|
||||
# Overhang from previous 64-bit elements
|
||||
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
|
||||
def ib = ix // ew # byte index
|
||||
def io = ew*ib + k*ix%ew # where they are in xv
|
||||
def wi = split{wl, dup{tup{255, ...ib, 255, ...se+ib}}}
|
||||
xo := ov_bytes{(xv & W**fold{|, 1<<io}) >> (ew-k)}
|
||||
# Permute and mask bytes
|
||||
def step{jj, oi, ind, mask} = {
|
||||
def hk = (k-1) / 2
|
||||
def getv = if (not lanes or jj==hk) ({x}=>x)
|
||||
else sel_imm{[4]u64, ., 2*(jj>hk) + iota{4}%2}
|
||||
def selx{x, i} = sel_imm{[128/ew]ty_u{ew}, getv{x}, i}
|
||||
b := selx{xv, ind} & make{W, mask}
|
||||
r := (b<<k) - b
|
||||
def selx_nz{x, i} = { def nz = i!=255; selx{x, i * nz} & W~~make{[4]i32, -nz} }
|
||||
o := (if (ew==8) selx else selx_nz){xo, flat_table{max, oi, 255*(0<iota{se})}}
|
||||
output{r|o}
|
||||
}
|
||||
each{step,
|
||||
iota{k}, wi,
|
||||
split{ne, replicate{k, iota{ne}}},
|
||||
split{wl, each{base{2,.}, split{64, cycle{k*ww, 0==iota{k}}}}}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (max_wv <= 4 or wv < 4) {
|
||||
fixed_loop{3}
|
||||
} else if (wv < 6) {
|
||||
fixed_loop{5}
|
||||
} else {
|
||||
fixed_loop{7}
|
||||
}
|
||||
}
|
||||
|
||||
export{'si_constrep_bool', rep_const_bool{}}
|
||||
|
||||
@ -12,3 +12,5 @@ def unaligned_spaced_mask_mod{l:T} = {
|
||||
def d = cast_i{T, ctz{m}} # = 64%l
|
||||
tup{m>>d | m<<(l-d), d}
|
||||
}
|
||||
|
||||
def advance_spaced_mask{k, m, sh} = m<<(k-sh) | m>>sh
|
||||
|
||||
Loading…
Reference in New Issue
Block a user