# uCBQN/src/singeli/src/replicate.singeli

include './base'
include './mask'
include './spaced'
def ind_types = tup{u8, u16, u32}
def dat_types = tup{...ind_types, u64}
# Indices and Replicate using plus- or max-scan
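# The idea: write values only at group boundaries and let a scan fill the
# gaps. E.g. for 𝕨 = 2‿0‿3 (so / 𝕨 is 0‿0‿2‿2‿2, result length 5):
#   after boundary writes:  0 0 2 0 0   (Indices: +1 per started group)
#   after plus-scan:        0 0 2 2 2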
def scan_core{upd, set, scan, rp:*T, wp:W, s:(usz)} = {
def getw{j} = match (W) { {*_} => cast_i{usz,load{wp,j}}; {_} => wp }
b:usz = 1<<10
k:usz = 0; j:usz = 0; ij:=getw{j}
while (1) {
e := tern{b<s-k, k+b, s}
@for (rp over i from k to e) rp = 0
if (set) store{rp, k, cast_i{T,j}}
while (ij<e) { ++j; upd{rp, j, ij}; ij+=getw{j} }
scan{rp+k, e-k}
if (e==s) return{}
k = e
}
}
def indrep_by_sum{T, rp:*T, wp, s:(usz), js, inc} = {
def scan{ptr, len} = @for (ptr over len) js=ptr+=js
def scan{ptr:*E, len if width{T}<=32} = {
def scanfn = merge{'si_scan_pluswrap_u',fmtnat{width{T}}}
p := *ty_u{E}~~ptr
emit{void, scanfn, p, p, len, js}; js=load{ptr,len-1}
}
def upd{rp, j, ij} = store{rp, ij, load{rp,ij}+inc{j}}
scan_core{upd, 0, scan, rp, wp, s}
}
fn ind_by_scan_i32{W}(xv:*void, rp:*i32, s:usz) : void = {
xp := *W~~xv
if (hasarch{'X86_64'} and not hasarch{'SSE4.1'}) { # no min instruction
js:i32 = 0
indrep_by_sum{i32, rp, xp, s, js, {j}=>1}
} else {
scan_core{
{rp,j,ij} => store{rp,ij,cast_i{i32,j}}, 1,
{ptr,len} => emit{void, 'si_scan_max_i32', ptr,ptr,len},
rp, xp, s
}
}
}
def rep_by_scan{T, wp, xv:(*void), rv:(*void), s} = {
xp := *T~~xv; js := load{xp}; px := js
def inc{j} = {sx:=px; px=load{xp,j}; px-sx}
indrep_by_sum{T, *T~~rv, wp, s, js, inc}
}
fn rep_by_scan{W, T}(wp:*void, xv:*void, rv:*void, s:usz) : void = {
rep_by_scan{T, *W~~wp, xv, rv, s}
}
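# E.g. 2‿0‿3/10‿20‿30: js seeds the scan with 10; the boundary writes leave
# 0 0 20 0 0 (x1-x0 and x2-x1 both land at index 2), and the seeded plus-scan
# gives 10 10 30 30 30.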
export_tab{'si_indices_scan_i32', each{ind_by_scan_i32, ind_types}}
export_tab{'si_replicate_scan', flat_table{rep_by_scan, ind_types, dat_types}}
# Constant replicate
def incl{a,b} = slice{iota{b+1},a}
def basic_rep = incl{2, 7}
if_inline (not (hasarch{'SSSE3'} or hasarch{'AARCH64'})) {
fn rep_const{T}(wv:u64, x:*void, r:*void, n:u64) : void = {
assert{wv>=2}
if (width{T} < 32 and wv < 128/width{T}) {
rep_by_scan{T, cast_i{usz, wv}, x, r, cast_i{usz, wv*n}}
} else {
rw := *T~~r
@for (v in *T~~x over n) { @for (rw over wv) rw = v; rw += wv }
}
}
} else {
def has_bytesel_128 = not hasarch{'AVX2'}
# 1+˝∨`⌾⌽0=div|⌜range
def makefact{divisor, range} = {
def t = table{{a,b}=>0==b%a, divisor, range}
fold{+, 1, reverse{scan{|, reverse{t}}}}
}
def fact_size = 128
def fact_inds = slice{iota{fact_size+1},8}
def fact_tab = makefact{basic_rep, fact_inds}
factors:*u8 = fact_tab
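# factors[n-8] is the largest element of basic_rep dividing n, or 1 if none:
# e.g. 12 -> 6 and 11 -> 1. rep_const below uses it to split one large
# constant replicate into two shuffle-sized passes.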
def sdtype = [arch_defvw{}/8]i8 # shuf data type
def get_shufs{step, wv, len} = {
def i = iota{len*step}
split{step, (i - i%wv)/wv}
}
def get_shuf_data{wv, len} = get_shufs{vcount{sdtype}, wv, len} # [len] byte-selector vectors for wv/sdtype (expanded to wider types by read_shuf_vecs)
def get_shuf_data{wv} = get_shuf_data{wv, wv}
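# e.g. get_shufs{16, 3, 3} gives 0 0 0 1 1 1 2 2 2 3 3 3 4 4 4 5, …:
# output byte i selects input byte (i - i%wv)/wv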
# all shuffle vectors for 𝕨≤7
def special_2 = not has_bytesel_128 # handle 2 specially on AVX2
def rcsh_vals = slice{basic_rep, special_2}
rcsh_offs:*u8 = shiftright{0, scan{+,rcsh_vals}}
rcsh_data:*i8 = join{join{each{get_shuf_data, rcsh_vals}}}
# first 4 shuffle vectors for 11≤𝕨≤61; only uses the low half of the input
def rcsh4_dom = replicate{bind{>=,64}, replicate{fact_tab==1, fact_inds}}
rcsh4_dat:*i8 = join{join{each{get_shuf_data{., 4}, rcsh4_dom}}}
rcsh4_lkup:*i8 = shiftright{0, scan{+, fold{|, table{==, rcsh4_dom, iota{64}}}}}
def read_shuf_vecs{l, ellw:(u64), shp:*V} = { # tuple of l byte selectors for elements of 1<<ellw bytes
def double{x:X if hasarch{'AVX2'}} = {
s:=shuf{u64, x, 0,2,1,3}; s+=s
r:=each{~~{[32]i8,.}, mzip128{s, s + X**1}}
r
}
def double{x:X if has_bytesel_128} = {
s:= x+x
zip{s, s + X**1}
}
def doubles{n,tup} = slice{join{each{double,tup}}, 0, n}
def sh = each{{v}=>{r:=v}, l**V**0}
def tlen{e} = cdiv{l, e} # Selectors needed at e-byte granularity, rounded up
def set{i} = { select{sh,i} = each{load{shp,.}, i} }
def ext{e} = {
def m = tlen{2*e}; def n = tlen{e} # m<n
if (ellw <= lb{e}) set{slice{iota{n},m}}
else slice{sh,0,n} = doubles{n,slice{sh,0,m}}
}
set{iota{tlen{8}}}; ext{4}; ext{2}; ext{1}
sh
}
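# Selector doubling: a byte selector a b … for e-byte elements becomes
# 2a 2a+1 2b 2b+1 … for 2e-byte ones (x+x zipped with x+x+1), so only
# byte-granularity tables are stored and wider element widths are derived.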
def rep_const_shuffle{wv, onreps, xv:*V=[step]T, rv:*V, n:(u64)} = { # onreps{inputVector, {nextOutputVector} => ...}
nv := n / step
j:u64 = 0
def write{v} = { store{rv, j, v}; ++j }
@for (xv over nv) onreps{xv, write}
if (nv*step < n) {
nr := n * wv
e := nr / step
s := V**0
def end = makelabel{}
onreps{load{xv,nv}, {v} => {
s = v
if (j == e) goto{end}
write{s}
}}
setlabel{end}
q := nr & (step-1)
if (q!=0) store_blended_hom{rv+e, mask_of_first{V, q}, s}
}
}
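# The tail replays onreps on the last partial input vector, stops once e full
# output vectors exist, and covers the final q elements with one masked store.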
if_inline (hasarch{'AVX2'}) {
def rep_iter_from_sh{sh}{x, gen} = {
def l = length{sh}
def h = l>>1
def fs{v, s} = gen{shuf{[16]i8, v, s}}
a := shuf{u64, x, 0,1,0,1}; each{fs{a,.}, slice{sh,0,h}}
if (l%2) fs{x, select{sh, h}}
b := shuf{u64, x, 2,3,2,3}; each{fs{b,.}, slice{sh,-h}}
}
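# Byte shuffles only select within 128-bit lanes, so the low input lane is
# duplicated for the first half of the selectors and the high lane for the
# second; an odd middle vector straddles the halves and uses x unchanged.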
def get_rep_iter{V, wv==2}{x, gen} = {
def s = shuf{u64, x, 0,2,1,3}
each{{q}=>gen{V~~q}, mzip128{s, s}}
}
def get_rep_iter{V==[4]u64, wv} = {
def step = 4
def sh = get_shufs{step, wv, wv}
{x, gen} => each{{s}=>gen{shuf{V, x, s}}, sh}
}
def rep_const_shuffle{wv, xv:*V, rv:*V, n:(u64)} = rep_const_shuffle{wv, get_rep_iter{V, wv}, xv, rv, n}
} else { # has_bytesel_128
def rep_iter_from_sh{sh}{x, gen} = {
each{{s} => gen{shuf{[16]u8, x, s}}, sh}
}
def rep_const_shuffle{wv==2, xv0:*V=[_]T, rv0:*V, n:(u64)} = {
def E = ty_u{T}
rv:= *E~~rv0
@for (x in *E~~xv0 over i to n) { # autovectorized well enough, probably
store{rv, i*2, x}
store{rv, i*2+1, x}
}
}
}
fn rep_const_shuffle_partial4(wv:u64, ellw:u64, x:*i8, r:*i8, n:u64) : void = {
def h = 4
def V = sdtype
def sh = read_shuf_vecs{h, ellw, *V~~rcsh4_dat + h*load{rcsh4_lkup,wv}}
def [step]_ = V # Bytes written
def wvb = wv << ellw
def hs = (h*step) / wvb # Actual step size in argument elements
def shufbase{i if hasarch{'AVX2'}} = shuf{u64, load{*V~~(x+i)}, 0,1,0,1}
def shufbase{i if has_bytesel_128} = load{*V~~(x+i)}
def shufrun{a, s} = shuf{[16]i8, a, s} # happens to be the same across AVX2 & NEON
i:u64 = 0
re := r + n*wvb - h*step
while (r <= re) {
a := shufbase{i}
@unroll (j to h) store{*V~~r, j, shufrun{a, select{sh,j}}}
i += hs << ellw
r += hs*wvb
}
re+= (h-1)*step
a:= shufbase{i}
s:= V**0
def end = makelabel{}
@unroll (j to h) {
s = shufrun{a, select{sh,j}}
if (r > re) goto{end}
store{*V~~r, 0, s}
r+= step
}
setlabel{end}
q := (re+step) - r
if (q!=0) store_blended_hom{r, mask_of_first{V, q}, s}
}
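# Only h=4 selector vectors are stored per factor-free wv: each iteration
# shuffles h output vectors from one base load but advances by just hs whole
# input elements (hs*wvb bytes), so r stays element-aligned and the same
# selectors apply again; the overlap is simply rewritten next round.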
fn rcsh_sub{wv, V}(ellw:u64, x:*i8, r:*i8, n:u64, sh:*V) : void = {
def st = read_shuf_vecs{wv, ellw, sh}
rep_const_shuffle{wv, rep_iter_from_sh{st}, *V~~x, *V~~r, n}
}
fn rep_const_shuffle_any(wv:u64, ellw:u64, x:*i8, r:*i8, n:u64) : void = {
if (wv > select{rcsh_vals,-1}) {
return{rep_const_shuffle_partial4(wv, ellw, x, r, n)}
}
n <<= ellw
ri := wv - select{rcsh_vals,0}
sh := *sdtype~~rcsh_data + load{rcsh_offs,ri}
def try{k} = { if (wv==k) rcsh_sub{k, sdtype}(ellw, x, r, n, sh) }
each{try, rcsh_vals}
}
def rep_const_broadcast{T, kv, loop, wv:(u64), x:*T, r:*T, n:(u64)} = {
assert{kv > 0}
def V = [arch_defvw{}/width{T}]T
@for (x over n) {
v := V**x
@loop (j to kv) store{*V~~r, j, v}
r += wv
store{*V~~r, -1, v}
}
}
fn rep_const_broadcast{T, kv }(wv:u64, x:*T, r:*T, n:u64) : void = rep_const_broadcast{T, kv, unroll, wv, x, r, n}
fn rep_const_broadcast{T}(kv:u64, wv:u64, x:*T, r:*T, n:u64) : void = rep_const_broadcast{T, kv, for , wv, x, r, n}
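# The index -1 store handles wv%vn by overlap: e.g. vn=4, wv=9 stores vectors
# at elements 0 and 4, then (after r+=wv) at 5, covering the whole group
# without a scalar tail.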
fn rep_const{T}(wv:u64, x:*void, r:*void, n:u64) : void = {
assert{wv>=2}
if (wv>=8 and wv<=fact_size) {
fa := promote{u64, load{factors,wv-8}}
if (fa > 1) {
fi := wv / fa
def t = *void~~(*T~~r + (promote{u64,wv}-fi)*n)
rep_const{T}(fi,x,t,n)
rep_const{T}(fa,t,r,fi*n)
return{}
}
}
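# e.g. wv=12 -> fa=6, fi=2: first 2/x into the tail slice t of r, then 6/t in
# place; reads from t stay ahead of the writes to r, so one buffer suffices.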
def wT = width{T}
def vn = arch_defvw{}/wT
def V = [vn]T
def max_shuffle = 2*vn
if (wv <= max_shuffle) {
def specialize{k} = {
if (wv==k) return{rep_const_shuffle{k, *V~~x, *V~~r, n}}
}
if (special_2) specialize{2}
rep_const_shuffle_any(wv, lb{wT/8}, *i8~~x, *i8~~r, n)
} else {
kv := wv / vn
@unroll (k from (max_shuffle/vn) to 4) {
if (kv == k) return{rep_const_broadcast{T, k}(wv, *T~~x, *T~~r, n)}
}
rep_const_broadcast{T}(kv, wv, *T~~x, *T~~r, n)
}
}
}
export_tab{'si_constrep', each{rep_const, dat_types}}
# Constant replicate on boolean
def any_sel = hasarch{'SSSE3'} or hasarch{'AARCH64'}
if_inline (hasarch{'AARCH64'}) {
def __shl{a:V=[_]T, b:U if not vect{U}} = a << V**cast_i{T,b}
def __shr{a:V=[_]T, b:U if not vect{U}} = a << V**cast_i{T,-b}
}
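# NEON shifts vectors by signed per-lane counts, so a right shift is a left
# shift by a negated amount; these overloads just splat the scalar count.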
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = {
assert{wv >= 2}
nw := cdiv{rlen, 64}
if (wv&1 == 0) {
p := ctz{wv | 8} # Power of two for second replicate
wf := wv>>p
if (wf == 1) {
rep_const_bool_div8{wv, x, r, nw}
} else if (any_sel and p == 3 and wv <= 8*select{basic_rep, -1}) {
# (higher wv double-factors, which doesn't work with in-place pointers)
tlen := rlen / wf
t := r + cdiv{rlen, 64} - cdiv{tlen, 64}
rep_const_bool{}(8, x, t, tlen)
rep_const{select{dat_types,0}}(promote{u64,wf}, *void~~t, *void~~r, promote{u64,tlen/8})
} else {
tlen := rlen >> p
wq := usz~~1 << p
if (p == 1 and (not any_sel or wv>=24)) {
# Expanding odd second is faster
tlen = rlen / wf
t:=wf; wf=wq; wq=t
}
t := r + cdiv{rlen, 64} - cdiv{tlen, 64}
rep_const_bool{}(wf, x, t, tlen)
rep_const_bool{}(wq, t, r, rlen)
}
} else if (not hasarch{'AVX2'} and wv >= 64) {
# Identical to the non-Singeli case for large wv
# but we assume n==0 and drop the loop
assert{wv < 128}
i:usz = 0
ri:usz = 0; j:usz = 0 # Bit index in r; ri/64
c:u64 = 0
while (j < nw-1) { # e<=j+2 below, so e-1<nw
v := -(1 & load{x, i/64}>>(i%64)); ++i
r0:= c ^ ((v^c) << (ri%64))
c = v
ri+= wv; e:= ri/64
store{r, e-1, c}
store{r, j, r0}
j = e
}
if (ri%64 != 0) store{r, j, c}
} else {
rep_const_bool_odd{wv, x, r, nw}
}
}
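# e.g. wv=12: p=2, wf=3, so 3/𝕩 goes into a tail slice of r and the div8
# kernel then expands it 4×; wv=40 (with byte shuffles) takes the p==3
# branch: 8/𝕩 bitwise, then a plain 5× byte replicate.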
def rep_const_bool_div8{wv, x, r, nw} = { # wv in 2,4,8
def run{k} = {
# 2 -> 64w0x33, 12 -> 64w0x000f, etc.
def getm{sh} = base{2, iota{64}&sh == 0}
def osh{v, s} = v | v<<s
def expand = match (k) {
{2} => fold{
{v, sh} => osh{v, sh} & getm{sh},
., 1 << reverse{iota{5}}
}
{4} => fold{
{v, sh} => osh{osh{v, sh}, 2*sh} & getm{sh},
., tup{12, 3}
}
{8} => {
def mult = base{1<<7, 8**1}
{x} => (x | ((x&~1) * mult)) & 64w0x01
}
}
@for (xt in *ty_u{64/k}~~x, r over nw) {
def v = expand{promote{u64, xt}}
r = v<<k - v
}
}
def cases{k} = if (wv==k) run{k} else if (k<8) cases{2*k}
cases{2}
}
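# k==2 spreads bit i to position 2i in log steps, and v<<k - v (= v*(2^k-1))
# fills each k-wide run: e.g. byte 2b1011 -> 2b11001111. k==8 multiplies by
# bits at positions 7i so bit i lands at 8i; bit 0 is OR'd in separately to
# dodge the one colliding partial product.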
def get_boolvec_writer{V, r, nw} = {
def vwords = width{V}/64
rv := *V~~r
re := rv + nw / vwords
rs := rv + cdiv{nw, vwords}
# Avoid reading/processing an extra word past the end of input
def done = makelabel{}
def check_done{} = if (rv == rs) goto{done}
# If the last result is partial, jump to flush to write it
last_res:V = V**0
def l_flush = makelabel{}
def output{v:(V)} = {
last_res = v
if (rv == re) goto{l_flush}
store{rv, 0, v}; ++rv
}
def flush{} = {
setlabel{l_flush}
q := nw & (vwords-1)
if (q != 0) store_blended_hom{rv, V~~mask_of_first{re_el{u64,V}, q}, last_res}
setlabel{done}
}
tup{output, check_done, flush}
}
def get_boolvec_writer{T==u64, r:*T, nw} = {
def done = makelabel{}
j:usz = 0
def output{rw} = {
store{r, j, rw}
++j; if (j==nw) goto{done}
}
def flush{} = setlabel{done}
tup{output, {}=>{}, flush}
}
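# Both writers expose {output, check_done, flush}: output jumps out once the
# last (possibly partial) vector is reached, check_done exits at an exact
# end, and flush finishes any masked tail store; callers loop on while (1)
# and rely on these jumps to terminate.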
def rep_const_bool_div8{wv, x, r, nw if has_simd} = {
def avx2 = hasarch{'AVX2'}
def vl = if (avx2) 32 else 16
def V = [vl]u8
def iV = iota{vl}
def mkV = make{V, .}
# Not the same as a 1-byte shift but extra bits are always masked away
def H = re_el{u16, V}
def __shl{x:(V), a} = V~~(H~~x << a)
def __shr{x:(V), a} = V~~(H~~x >> a)
def {output, check_done, flush} = get_boolvec_writer{V, r, nw}
def run24{x, get_halves} = {
i:usz = 0; while (1) { check_done{}
xv := load{*V~~(x+i)}; ++i
def getr = zip128{...get_halves{xv}, .}
output{V~~getr{0}}
output{V~~getr{1}}
}
}
def run24{x, pre, exh, sh} = run24{x,
{xv} => { p := pre{xv}; tup{exh{p}, exh{p>>sh}} }
}
def run8{rep_bytes} = {
i:usz = 0; while (1) { check_done{}
xv := V**load{*[16]u8 ~~ (*ty_u{vl}~~x + i)}; ++i
xe := rep_bytes{xv}
output{(xe & mkV{1 << (iV % 8)}) > V**0}
}
}
if (not any_sel) {
if (wv == 2) {
run24{*V~~x, {xv} => {
def swap{x, l, m} = { def d = (x ^ x>>l) & V**m; x ^ (d|d<<l) }
a := swap{swap{xv, 2, 0x0c}, 1, 0x22}
m := V**0x55
h0 := a & m; o0 := h0 | h0<<1
h1 := a &~ m; o1 := h1 | h1>>1
tup{o0, o1}
}}
} else if (wv == 4) {
def pre{xv} = {
a := xv ^ ((xv & V**0x55) << 1)
e := zip128{a, a>>4, 0}
}
def out{h} = (-(h & V**1)) ^ (-(h<<3 & V**0x10))
run24{*u64~~x, pre, out, 2}
} else { # wv == 8
def z{x} = zip{x,x,0}
run8{{x} => z{z{z{x}}}}
}
} else { # any_sel
def selH = shuf{[16]u8, ., .}
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
def id{xv} = xv
if (wv == 2) {
def init = if (avx2) shuf{u64, ., 0,2,1,3} else id
# Expander for half byte
def tabr = makeTab{tr_iota{2*iota{4}} * 2b11}
m4 := V**0xf
run24{*V~~x, init, {x} => tabr{x & m4}, 4}
} else if (wv == 4) {
# Unzip 32-bit elements (result lanes) across AVX2 lanes
def pre = if (avx2) shuf{[8]u32, ., make{[8]u32,tr_iota{1,2,0}}} else id
def init{xv} = { u:=pre{xv}; zip128{u,u,0} }
# Expander for two bits in either bottom or next-to-bottom position
def tabr = makeTab{tr_iota{0,4,0,4} * 2b1111}
m2 := mkV{2b11 << (2*(iV%2))}
def exh{x} = re_el{u16, V}~~tabr{x & m2}
run24{*(if (avx2) [2]u64 else u64)~~x, init, exh, 4}
} else { # wv == 8
run8{selH{., mkV{iV // 8}}}
}
}
flush{}
}
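# run24 emits two output vectors per loaded chunk (wv∊2‿4); run8 expands each
# input byte to eight mask-tested bytes, one output vector per vl input bits.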
# Data for the permutation that sends bit i to k*i % width{T}
def modperm_dat{T, k} = {
def w = width{T}
def i = iota{lb{w}}
def bits = ~(1 & (k*iota{w} >> merge{0, replicate{1<<i, i}}))
def E = match (T) { {[_]E} => E; {_} => T }
each{{x} => E~~base{2,x}, split{width{E}, bits}}
}
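# e.g. k=3 on width 8 sends bits 0…7 to 0 3 6 1 4 7 2 5. One packed mask per
# k is stored in swtab below; the step evaluators unpack it into
# butterfly-like shift/rotate/shuffle stages that realize the permutation.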
# Permutation step evaluators
# shift takes a top-half mask; others take both-halves
def modperm_shift_step{x, l, m} = {
def d = (x ^ x<<l) & m
x ^ (d | d>>l)
}
def modperm_rot_step{x:T, l=(width{T}/2), m} = {
(x &~ m) | ((x<<l | x>>l) & m)
}
def modperm_shuf_step{x:V=[_]T, l, m if l%8==0} = {
(x &~ m) | (swap_elts{x, l/8} & m)
}
# Reverse each pair of elements
def swap_elts{x:V=[k]_, el_bytes} = {
if (hasarch{'AARCH64'}) {
if (el_bytes == 8) vec_merge_shift_right{x, x, k/2}
else V~~reverse_units{2, re_el{ty_u{8*el_bytes}, x}}
} else if (any_sel or (hasarch{'SSE2'} and el_bytes >= 2)) {
def wd = __min{el_bytes, 8}
def l = el_bytes/wd
def i = iota{l}
shuf{ty_u{wd*8}, x, merge{i+l, i}}
} else {
def sh = el_bytes*8
xe := re_el{ty_u{2*sh}, V} ~~ x
V~~(xe<<sh | xe>>sh)
}
}
# Extract swap functions from modperm_dat
def extract_modperm_mask{full:W, lane, l} = {
def f = l == 128 # Use full width
def w = if (l<32 or not hasarch{'SSE2'}) 8
else if (f) 64 else 32 # Shuffle element width
def n = (if (f) 256 else 128) / w # and number of elements
def o = l / w # Extract [o,2*o)
def i = o + iota{n}%o
shuf{ty_u{w}, if (f) full else lane, i}
}
def proc_mod_dat{swap_data:W} = {
def ww = width{W}; def avx2 = ww==256
def on_len_range{get_swap, lo, hi} = {
def lens = reverse{1 << slice{iota{lb{hi}}, lb{lo}}}
fold{{x,s}=>s{x}, ., each{get_swap, lens}}
}
swap_lane := if (avx2) shuf{swap_data, 0,1,0,1} else swap_data
# Transform width-w units with shifts only
def get_shiftperm{data, w} = {
dat := data
bot:W = W**base{2, cycle{64, replicate{w, tup{1,0}}}}
def gsw{l} = {
bot ^= bot << l # Low l bits out of every 2*l
sm := dat &~ bot
dat &= bot; dat |= dat<<l
if (l<32) modperm_shift_step{., l, sm}
else modperm_rot_step {., l, sm|sm>>l}
}
on_len_range{gsw, 2, w}
}
# Within-byte transformation
def get_byteperm{} = {
def V = re_el{u8, W}
def sw_bytes = shuf{u8, swap_lane, 16**0}
m4 := W~~V**0xf
t0 := fold{{v,a}=>modperm_shift_step{v,...a}, W~~make{V,iota{vcount{V}}%16}, tup{
tup{4, sw_bytes &~ m4},
tup{2, ({v} => v | v<<4){sw_bytes & W~~V**0xc}}
}}
t4 := (t0&m4)<<4 | (t0&~m4)>>4
def selI{v,i} = shuf{[16]u8, v, V~~i}
{xv} => selI{t0, xv & m4} | selI{t4, (xv>>4) & m4}
}
def {partwidth, partperm} = {
def sh{w, d} = tup{w, get_shiftperm{d, w}}
if (ww==64) sh{64, swap_data}
else if (not any_sel) sh{32, shuf{u32, swap_data, 0,0,0,0}}
else tup{8, get_byteperm{}}
}
# Fill in higher steps
def get_mod_permuter{width} = {
def get_swap{l} = {
def mask = extract_modperm_mask{swap_data, swap_lane, l}
modperm_shuf_step{., l, mask}
}
def swap_x = on_len_range{get_swap, partwidth, width}
{x} => partperm{swap_x{x}}
}
def get_mod_permuter{} = get_mod_permuter{ww}
tup{partperm, get_mod_permuter}
}
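# Net effect: element shuffles cover the levels from partwidth up to the
# vector width, and partperm (shifts, or the byte table) finishes the levels
# below partwidth.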
def MP = if (has_simd) [if (hasarch{'AVX2'}) 4 else 2]u64 else u64
swtab:*u64 = join{each{modperm_dat{MP, .}, 1+2*iota{32}}}
def rep_const_bool_odd{k, x, r, nw} = {
def W = MP
def {output, check_done, flush} = get_boolvec_writer{W, r, nw}
xp := *W~~x
def getter{perm}{} = { check_done{}; xv := load{xp}; ++xp; perm{xv} }
# Modular permutation: small-k cases may use a limited permutation
# on bytes or 32-bit ints; general case uses the whole thing
swap_data := load{*W~~swtab, (k%64)>>1}
def {partperm, get_full_permute} = proc_mod_dat{swap_data}
def sp_max = if (any_sel) 8 else 4
if (k < sp_max) {
rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output}
} else if (not hasarch{'AVX2'} or k < 64) {
def get_swap_x = getter{get_full_permute{}}
rep_const_bool_odd_mask4{W, k, get_swap_x, output, cdiv{nw, vcount{W}}}
} else {
def get_swap_x = getter{get_full_permute{64}}
rep_const_bool_odd_loose_mask{W, k, get_swap_x, output}
}
flush{}
}
# General-case loop for odd replication factors
def rep_const_bool_odd_mask4{
M, # read/write type
k, # replication factor
get_modperm_x, # permuted input
output, n, # output, number of writes
} = {
def ifvec{g} = match (M) { {[_](u64)} => g; {_} => ({v}=>v) }
def scal = ifvec{{v} => M**v}
assert{k < 64}
# Fundamental operation: shifts act as order-k cyclic group on masks
def advance{m, sh} = advance_spaced_mask{k, m, sh}
# Double a cumulative mask, shift combination
# If s advances l iterations, mc combines iterations iota{l}
def double_gen{comb}{{mc, s}} = {
def mn = comb{mc, advance{mc,s}}
def ss = s+s
tup{mn, ss - (k &- (ss>k))}
}
# Starting word mask and single-word shift
{mask, mask_sh} := unaligned_spaced_mask_mod{k}
# Mask and shift for one iteration
def fillmask{T} = match (T) {
{(u64)} => tup{mask, mask_sh}
{[2]E} => double_gen{make{T,...}}{fillmask{E}}
{[4]E} => double_gen{pair}{fillmask{[2]E}}
}
{sm0, s1} := fillmask{M}
# Combined mask for 4 iterations, and shift to advance 4
def double = double_gen{|}
{mc4, s4} := double{double{tup{sm0, s1}}}
# Submasks pick one mask out of a combination of 4
def or_adv{m, s} = { m |= advance{m,s} }
@for (__min{k/4 - 1, n/4}) or_adv{sm0,s4}
submasks := scan{advance, tup{sm0, ...3**s1}}
mask_tail := advance{sm0, s4} &~ sm0
# Carry: shifting and word-crossing is done on the initial permuted x
# No need to carry across input words since they align with output words
# First bit of each word in xo below is wrong, but it doesn't matter!
mr := scal{tail{u64, k}} # Mask out carry bit before output
def sub_carry{a, c} = match (M) {
{[l](u64)} => {
ca := if (hasarch{'SSE4.2'} or hasarch{'AARCH64'}) { def S = [l]i64; S~~c > S**0 }
else { cm := [2*l]i32~~c; cm != shuf{cm, 1,0} }
a + M~~ca
}
{_} => a - promote{u64, c != 0}
}
while (1) {
x:M = get_modperm_x{}
def vrot1 = ifvec{{x} => {
if (w256{M}) shuf{x, 3,0,1,2}
else if (any_sel) vec_merge_shift_right{x, x, vcount{type{x}}-1}
else shuf{x, 1,0}
}}
xo := x<<k | vrot1{x>>(64-k)}
# Write result word given starting bits
def step{b, c} = output{sub_carry{c - b, c & mr}}
def step{b, c, m} = step{b&m, c&m}
# Fast unrolled iterations
mask := mc4
@for (k/4) {
each{step{x & mask, xo & mask, .}, submasks}
mask = advance{mask, s4}
}
# Single-step for tail
mask = mask_tail
@for (k%4) {
step{x, xo, mask}
mask = advance{mask, s1}
}
}
}
def rep_const_bool_small_odd{(u64), 4, wv, get_swap_x, output} = {
def k = 3
def step{x}{o, n_over} = {
b := x & base{2, cycle{64, n_over == iota{k}}}
def os = (64-k) + 1 + iota{n_over}
output{fold{|, (b<<k) - b, each{>>{o,.}, os}}}
b
}
while (1) fold{step{get_swap_x{}}, 0, (-iota{k}*64)%k}
}
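# wv==3 scalar: each permuted input word expands to three output words;
# n_over cycles 0 2 1, counting the replicated runs that spill across each
# 64-bit boundary and are patched in from the previous word's boundary bits.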
# For small odd numbers:
# - permute each byte sending bit i to position k*i % 8
# - replicate each byte by k, making position k*i contain bit i
# - mask out those bits and spread over [ k*i, k*(i+1) )
# - ...except where it crosses words; handle this overhang separately
# If 1-byte shuffle isn't available, use 4-byte units instead
def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = {
def ov_bytes{o} = { def V = re_el{u8, W}; v := V~~o; v += v > V**0; W~~v }
def ww = width{W}
def ew = if (any_sel) 8 else 32 # width of shuffle-able elements
def ne = ww/ew; def se = 64/ew
def lanes = ww > 128
def dup{v} = if (lanes) merge{v,v} else v
def fixed_loop{k} = {
assert{wv == k}
while (1) {
# e.g. 01234567 to 03614725 on each byte for k==3, ew==8
xv := get_perm_x{}
# Overhang from previous 64-bit elements
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
def ib = ix // ew # byte index
def io = ew*ib + k*ix%ew # where they are in xv
def wi = split{wl, dup{tup{-1, ...ib, -1, ...se+ib}}}
xo := ov_bytes{(xv & W**fold{|, 1<<io}) >> (ew-k)}
# Permute and mask bytes
def step{jj, oi, ind, mask} = {
def hk = (k-1) / 2
def getv = if (not lanes or jj==hk) ({x}=>x)
else shuf{[4]u64, ., 2*(jj>hk) + iota{4}%2}
def selx{x, i} = shuf{[128/ew]ty_u{ew}, getv{x}, i}
b := selx{xv, ind} & make{W, mask}
r := (b<<k) - b
def selx_nz{x, i} = { def nz = i != -1; selx{x, i * nz} & W~~make{[4]i32, -nz} }
o := (if (ew==8) selx else selx_nz){xo, flat_table{__min, oi, 255-256*(0<iota{se})}}
output{r|o}
}
each{step,
iota{k}, wi,
split{ne, replicate{k, dup{iota{ne/(1+lanes)}}}},
split{wl, each{base{2,.}, split{64, cycle{k*ww, 0==iota{k}}}}}
}
}
}
if (max_wv <= 4 or wv < 4) {
fixed_loop{3}
} else if (wv < 6) {
fixed_loop{5}
} else {
fixed_loop{7}
}
}
# Odd factors larger than 64
# AVX2-only because scalar should be about as good otherwise
def rep_const_bool_odd_loose_mask{V=[vl==4](u64), k, get_modperm_x, output if hasarch{'AVX2'}} = {
assert{k > 64}
# Distance from end to previous row boundary (-k <= q < 0)
q := -make{V, 64*(1+iota{vl})}
def q_mod{} = { d:=q+V**k; q = blend_top{q,d, d} }
o:usz = width{V}; while (o>k) { o-=k; q_mod{} }
km:= k%64
i:usz = 0; iv:usz = 0 # Words and vectors completed
def step{perm, diff} = {
# Mod-64 mask with 1 bit per word
m:= V**1 << (q & V**63)
# Indicator of which bits are actual boundaries
def S = ty_s{V}; a:= S**(-65) < S~~q
q-= V**o; q_mod{}
# Set to bit from perm, but below a&m xor with diff to get previous
base:= (m & perm) == m
md:= (a & m) & diff
output{base ^ (md + (md==m))}
}
{xp, xd, perm, diff}:= 4**(V**0)
while (1) {
# get_modperm_x permutes each 64-bit word
# Each iteration of this loop handles one permuted word
def first = shuf{., 4**0}
if (i%4 == 0) {
xp = get_modperm_x{}
# Shift by 1, or k%64 in mod-space
# Then the low bit of each word has to be moved to the next
# As before, first bit is wrong but unused
xl:= xp>>(64-km)
xo:= (xp<<km | (xl &~ V**1)) | (shuf{xl, 3,0,1,2} & V**1)
xd = xo ^ xp
perm = first{xp}; diff = first{xd}
} else {
def upd{xq, q} = {
xq = shuf{xq, 1,2,3,0} # Next word
qs:= q; q = first{xq}
blend_hom{q, qs, iota{V} < V**(i%4)}
}
step{upd{xp, perm}, upd{xd, diff}} # Do boundary between iterations
++iv
}
i+= k
ip:= iv; iv = i/4
@for (iv - ip) step{perm, diff}
}
}
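# Since k>64, each output word holds at most one run boundary; q tracks the
# backwards distance to it. A word is first filled with the source bit
# governing its end, then md-1 flips everything below the boundary when the
# neighbouring source bits differ.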
export{'si_constrep_bool', rep_const_bool{}}