diff --git a/src/singeli/src/replicate.singeli b/src/singeli/src/replicate.singeli
index 6eec5b65..351a3644 100644
--- a/src/singeli/src/replicate.singeli
+++ b/src/singeli/src/replicate.singeli
@@ -300,7 +300,8 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
} else {
tlen := rlen>>p
wq := usz~~1<
=52)) { # Expanding odd second is faster
+ if ((not hasarch{'SSSE3'} and (p == 1 or (p == 2 and wv>=52))) or (not hasarch{'AVX2'} and p == 1 and wv>=24)) {
+ # Expanding odd second is faster
tlen = rlen / wf
t:=wf; wf=wq; wq=t
}
@@ -309,14 +310,12 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
rep_const_bool{}(wq, t, r, rlen)
}
} else {
- m:u64 = spaced_mask_of{wv}
- d := cast_i{usz, popc{m}} # == 64/wv
- rep_const_bool_generic_odd{wv, x, r, nw, m, d}
+ rep_const_bool_odd{wv, x, r, nw}
}
1
}
-def rep_const_bool_div8{wv, x, r, nw} = {
+def rep_const_bool_div8{wv, x, r, nw} = { # wv in 2,4,8
def run{k} = {
# 2 -> 64w0x33, 12 -> 64w0x000f, etc.
def getm{sh} = base{2, iota{64}&sh == 0}
@@ -344,133 +343,8 @@ def rep_const_bool_div8{wv, x, r, nw} = {
cases{2}
}
-def rep_const_bool_generic_odd{k, xp, rp, nw, m, d} = {
- # Every-k-bits mask
- mask_sh := cast_i{usz, ctz{m}} # == 64%k
- mask := m<<(k-mask_sh) | 1
-
- # Transform sending bit i to k*i % 64 by pairwise swaps
- # Swap data goes in a pre-computed table
- def swdat{lw, k} = {
- def i = iota{lw}
- def bits = (k*iota{1<> merge{0, replicate{1<>1}
- swsel:u64 = ~u64~~0
- def gsw{l} = {
- swsel ^= swsel << l # Low l bits out of every 2*l
- sm := swap_data &~ swsel
- swap_data &= swsel; swap_data |= swap_data<>l)
- }
- def swap_step{l==32, m} = { # Use rotate
- mm := m | m>>l
- x = (x &~ mm) | ((x<>l) & mm)
- }
- each{swap_step, swap_lens, swap_masks}
- x
- }
-
- # Output
- j:usz = 0
- def output{rw} = {
- store{rp, j, rw}
- ++j; if (j==nw) return{1}
- }
- o:u64 = 0 # carry
- # Dedicated loop for 3, shared for other factors
- if (k == 3) {
- while (1) {
- x := get_swap_x{}
- @unroll (jj to 3) {
- b := x & mask
- mask = (mask << (3 - 64%3)) | (mask >> (64%3))
- def os = (64-3) + 1 + iota{select{tup{0,2,1}, jj}}
- output{fold{|, (b<<3) - b, each{>>{o,.}, os}}}
- o = b
- }
- }
- } else {
- # Fundamental operation: shifts act as order-k cyclic group on masks
- def advance{m, sh} = m<<(k-sh) | m>>sh
- sm0 := mask # starting mask
- s1 := mask_sh # single iteration shift
- # Get cumulative mask for 4 iterations, and shift to advance 4
- def double{{mc, s}} = {
- def ss = s+s
- tup{mc|advance{mc,s}, ss - (k &- (ss>k))}
- }
- {mc4, s4} := double{double{tup{sm0, s1}}}
- # Submasks pick one mask out of a combination of 4
- def or_adv{m, s} = { m |= advance{m,s} }
- @for (min{k/4 - 1, nw/4}) or_adv{sm0,s4}
- submasks := scan{advance, tup{sm0, ...3**s1}}
- mask_tail := advance{sm0, s4} &~ sm0
- # Mask out carry bit
- mr := u64~~1<>(64-k); o=xo&1; xo=(xo&~1)|os
- mask = mc4
- # Write result word given starting bits
- def step{b, c} = output{c - b - promote{u64, c&mr != 0}}
- def step{b, c, m} = step{b&m, c&m}
- # Fast unrolled iterations
- @for (k/4) {
- each{step{x & mask, xo & mask, .}, submasks}
- mask = advance{mask, s4}
- }
- # Single-step for tail
- mask = mask_tail
- @for (k%4) {
- step{x, xo, mask}
- mask = advance{mask, s1}
- }
- }
- }
-}
-
-fn rep_const_bool{if hasarch{'SSSE3'}}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
- if (wv > 32) return{0}
- if (wv&1 == 0) {
- p := ctz{wv | 8} # Power of two for second replicate
- if (wv>>p == 1) {
- rep_const_bool_div8{wv, x, r, rlen}
- } else {
- tlen := rlen>>p
- t := r + cdiv{rlen, 64} - cdiv{tlen, 64}
- rep_const_bool{}(wv >>p, x, t, tlen)
- rep_const_bool{}(usz~~1<floor{a/b}) infix left 40
def avx2 = hasarch{'AVX2'}
def vl = if (avx2) 32 else 16
@@ -498,7 +372,7 @@ def rep_const_bool_div8{wv, x, r, rlen if hasarch{'SSSE3'}} = { # wv in 2,4,8
def selH = sel{[16]u8, ., .}
def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .}
def id{xv} = xv
- def {output, flush} = get_boolvec_writer{V, r, rlen}
+ def {output, flush} = get_boolvec_writer{V, r, nw}
def run24{x, proc_xv, exh} = {
i:usz = 0; while (1) {
@@ -535,29 +409,189 @@ def rep_const_bool_div8{wv, x, r, rlen if hasarch{'SSSE3'}} = { # wv in 2,4,8
flush{}
}
+# Data for the permutation that sends bit i to k*i % width{T}
+def modperm_dat{T, k} = {
+ def w = width{T}
+ def lw= lb{w}
+ def i = iota{lw}
+ def bits = ~(1 & (k*iota{w} >> merge{0, replicate{1< make{T, each{base{2,.}, split{width{E}, bits}}}
+ {_} => T~~base{2, bits}
+ }
+}
+def modperm_step{x, l, m} = {
+ def d = (x ^ x<>l)
+}
+def modperm_step{x:T, l=(width{T}/2), m} = {
+ def mm = m | m>>l
+ (x &~ mm) | ((x<>l) & mm) # rotate
+}
+def modperm_step{x:T=[_](u8), l, m} = {
+ def W = re_el{u64, T}
+ T~~modperm_step{W~~x, l, W~~m}
+}
+def swap_elts{x:V, el_bytes if w128{V}} = {
+ def n = 16; def I = [n]u8
+ def swi{len, l} = { def i = iota{len}; i + (l - 2*(i&l)) }
+ if (el_bytes >= 4) shuf{[4]u32, x, base{4, swi{4, el_bytes/4}}}
+ else sel{I, x, make{I, swi{16, el_bytes}}}
+}
+def modperm_step{x, l, m:V=[_]T if l%8==0} = {
+ (x & m) | (swap_elts{x, l/8} &~ m)
+}
+def modperm_get_byteperm{sw_bytes:V=([16]u8) if hasarch{'SSSE3'}} = {
+ def shW{op, v, s} = V~~op{re_el{u64,V}~~v, s}
+ m4 := V**0xf
+ t0 := fold{{v,a}=>modperm_step{v,...a}, iota{V}, tup{
+ tup{4, sw_bytes &~ m4},
+ tup{2, ({v} => v|shW{<<, v, 4}){sw_bytes&(V**0xc)}}
+ }}
+ t4 := shW{<<, t0&m4, 4} | shW{>>, t0&~m4, 4}
+ {xv} => sel{V, t0, xv & m4} | sel{V, t4, shW{>>, xv, 4} & m4}
+}
+def fold_multi{f, init, ...ls} = fold{{v,a}=>f{v,...a}, init, flip{ls}}
+
+def advance_spaced_mask{k, m, sh} = m<<(k-sh) | m>>sh
+
+# General-case loop for odd replication factors
+def rep_const_bool_odd_mask4{
+ M, # read/write type
+ k, # replication factor
+ get_modperm_x, # permuted input
+ output, n, # output, number of writes
+ mask:(u64), # starting mask
+ mask_sh # single iteration shift
+} = {
+ def ifvec{g} = match (M) { {[_](u64)} => g; {_} => ({v}=>v) }
+ def scal = ifvec{{v} => M**v}
+
+ # Fundamental operation: shifts act as order-k cyclic group on masks
+ def advance{m, sh} = advance_spaced_mask{k, m, sh}
+ # Double a cumulative mask, shift combination
+ # If s advances l iterations, mc combines iterations iota{l}
+ def double_gen{comb}{{mc, s}} = {
+ def mn = comb{mc, advance{mc,s}}
+ def ss = s+s
+ tup{mn, ss - (k &- (ss>k))}
+ }
+ # Mask and shift for one iteration
+ {sm0, s1} := ifvec{double_gen{make{M,...}}}{tup{mask, mask_sh}}
+ # Combined mask for 4 iterations, and shift to advance 4
+ def double = double_gen{|}
+ {mc4, s4} := double{double{tup{sm0, s1}}}
+ # Submasks pick one mask out of a combination of 4
+ def or_adv{m, s} = { m |= advance{m,s} }
+ @for (min{k/4 - 1, n/4}) or_adv{sm0,s4}
+ submasks := scan{advance, tup{sm0, ...3**s1}}
+ mask_tail := advance{sm0, s4} &~ sm0
+
+ # Carry: shifting and word-crossing is done on the initial permuted x
+ o:M = scal{0} # Carry for x
+ mr := scal{u64~~1< {
+ ca := if (hasarch{'SSE4.2'}) { def S = [l]i64; S~~c > S**0 }
+ else { def S = [2*l]i32; cm := S~~c; cm != shuf{S, cm, 4b2301} }
+ a + M~~ca
+ }
+ {_} => a - promote{u64, c != 0}
+ }
+
+ while (1) {
+ x:M = get_modperm_x{}
+ def vrot1 = ifvec{{x} => vshl{x, x, vcount{type{x}}-1}}
+ k1:M = scal{1}
+ os:=o; xo:=x<>(64-k)}; o=xo&k1; xo=(xo&~k1)|os
+ # Write result word given starting bits
+ def step{b, c} = output{sub_carry{c - b, c & mr}}
+ def step{b, c, m} = step{b&m, c&m}
+ # Fast unrolled iterations
+ mask := mc4
+ @for (k/4) {
+ each{step{x & mask, xo & mask, .}, submasks}
+ mask = advance{mask, s4}
+ }
+ # Single-step for tail
+ mask = mask_tail
+ @for (k%4) {
+ step{x, xo, mask}
+ mask = advance{mask, s1}
+ }
+ }
+}
+
+def rep_const_bool_odd{k, xp, rp, nw} = {
+ # Every-k-bits mask
+ m:u64 = spaced_mask_of{k}
+ d := cast_i{usz, popc{m}} # == 64/k
+ mask_sh := cast_i{usz, ctz{m}} # == 64%k
+ mask := m<<(k-mask_sh) | 1
+
+ # Transform sending bit i to k*i % 64 by pairwise swaps
+ # Swap data goes in a pre-computed table
+ swtab:*u64 = each{modperm_dat{u64, .}, 1+2*iota{32}}
+ def swap_lens = reverse{2 << iota{5}}
+ swap_data := load{swtab, k>>1}
+ swsel:u64 = ~u64~~0
+ def gsw{l} = {
+ swsel ^= swsel << l # Low l bits out of every 2*l
+ sm := swap_data &~ swsel
+ swap_data &= swsel; swap_data |= swap_data<f{x}, n**init}
+ def masks = reps{advance_spaced_mask{k, ., 64%k}, k, mask}
+ def step{x}{m, n_over} = {
+ b := x & m
+ def os = (64-k) + 1 + iota{n_over}
+ output{fold{|, (b<>{o,.}, os}}}
+ o = b
+ }
+ while (1) each{step{get_swap_x{}}, masks, (-iota{k}*64)%k}
+ }
+ unrolled_iter{3}
+ } else {
+ rep_const_bool_odd_mask4{u64, k, get_swap_x, output, nw, mask, mask_sh}
+ }
+}
+
# For odd numbers:
# - permute each byte sending bit i to position k*i % 8
# - replicate each byte by k, making position k*i contain bit i
# - mask out those bits and spread over [ k*i, k*(i+1) )
# - ...except where it crosses words; handle this overhang separately
-def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
+def rep_const_bool_odd{wv, x, r, nw if hasarch{'SSSE3'}} = { # wv odd, wv<=15
oper // ({a,b}=>floor{a/b}) infix left 40
def vl = 16; def V = [vl]u8
def iV = iota{vl}
def mkV = make{V, .}; def selV = sel{V, ., .}
def W = [2]u64
- def {output, flush} = get_boolvec_writer{V, r, rlen}
+ def {output, flush} = get_boolvec_writer{V, r, nw}
+ # Swap data goes in a pre-computed table
+ swtab:*V = each{modperm_dat{V, .}, 1+2*iota{32}}
+ swap_data := load{swtab, wv>>1}
# Within-byte transformation
- def get_ttab{k} = each{{is} => mkV{tr_iota{is}}, split{4, k*iota{8} % 8}}
- ttab:*V = join{each{get_ttab, 2*iota{4} + 1}}
- {t0, t4} := each{load{ttab + (wv & 6), .}, iota{2}}
- m4 := V**0xf
- def perm_x{xv} = {
- selV{t0, xv & m4} | selV{t4, V~~([8]u16~~xv>>4) & m4}
- }
+ def perm_x = modperm_get_byteperm{selV{swap_data, V**0}}
- # Cases are 3; 5 7; 9 11 13 15
if (wv < 4) {
# 3: dedicated loop
i:usz = 0; while (1) {
@@ -629,47 +663,16 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
o := xo & mkV{255 * (iV%8 == 0)} # overhang
output{rv | o}
}
- } else { # wv < 32
- # 9 to 31: extend k*i % 8 transform to k*i % 128 by pairwise swaps
- # Swap data goes in a pre-computed table
- def swdat{k} = {
- def bits = (k*iota{128} >> merge{0, replicate{1<>1)-4}
- swap_masks := each{{l} => selV{swap_data, mkV{l+iV%l}}, swap_lens}
- # Every-k-bits mask, same as before
- {m, d} := unaligned_spaced_mask_mod{wv}
- mask := make{W, m, m>>d|m<<(wv-d)}
- mask_sh := d+d; if (mask_sh >= wv) mask_sh-= wv
- # Mask out carry bit
- mr := [4]u32~~W**(u64~~1< selV{~swap_data, mkV{l+iV%l}}, swap_lens}
+ def swap_x = fold_multi{modperm_step, ., 8*swap_lens, swap_masks}
i:usz = 0
- while (1) {
- # Load xv, send bit i to position wv*i % 128
- xv = load{*V~~x, i}; ++i
- def swap_step{l, m} = {
- xv = (xv & m) | (selV{xv, mkV{iV + (l-2*(iV&l))}} &~ m)
- }
- each{swap_step, swap_lens, swap_masks}
- xv = perm_x{xv}
- xw := W~~xv
- def vrot1{x} = vshl{x, x, vcount{type{x}}-1}
- w1 := W**1
- os:=o; xo:=xw<>(64-wv)}; o=xo&w1; xo=(xo&~w1)|os
- # Write wv vectors based on that
- @for (wv) {
- b := xw & mask
- # Handle overhang here; won't fit in a single vector
- c := xo & mask; cu := [4]u32~~c
- output{V~~(W~~(cu + (mr&cu > [4]u32**0)) - b)}
- mask = (mask << (wv - mask_sh)) | (mask >> mask_sh)
- }
- }
+ def get_swap_x{} = { xv := perm_x{swap_x{load{*V~~x, i}}}; ++i; xv }
+ # Every-k-bits mask
+ {m, d} := unaligned_spaced_mask_mod{wv}
+ rep_const_bool_odd_mask4{W, wv, {}=>W~~get_swap_x{}, {v}=>output{V~~v}, cdiv{nw, vcount{W}}, m, d}
}
flush{}
}