diff --git a/src/singeli/src/replicate.singeli b/src/singeli/src/replicate.singeli index 351a3644..2cd0cc9e 100644 --- a/src/singeli/src/replicate.singeli +++ b/src/singeli/src/replicate.singeli @@ -300,7 +300,7 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = { } else { tlen := rlen>>p wq := usz~~1<

=52))) or (not hasarch{'AVX2'} and p == 1 and wv>=24)) { + if ((not hasarch{'SSSE3'} and (p == 1 or (p == 2 and wv>=52))) or (p == 1 and wv>=24)) { # Expanding odd second is faster tlen = rlen / wf t:=wf; wf=wq; wq=t @@ -412,8 +412,7 @@ def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSSE3'}} = { # Data for the permutation that sends bit i to k*i % width{T} def modperm_dat{T, k} = { def w = width{T} - def lw= lb{w} - def i = iota{lw} + def i = iota{lb{w}} def bits = ~(1 & (k*iota{w} >> merge{0, replicate{1< make{T, each{base{2,.}, split{width{E}, bits}}} @@ -432,24 +431,29 @@ def modperm_step{x:T=[_](u8), l, m} = { def W = re_el{u64, T} T~~modperm_step{W~~x, l, W~~m} } -def swap_elts{x:V, el_bytes if w128{V}} = { - def n = 16; def I = [n]u8 +def swap_elts{x:V=[vl](u8), el_bytes} = { # Reverse each pair of elements def swi{len, l} = { def i = iota{len}; i + (l - 2*(i&l)) } - if (el_bytes >= 4) shuf{[4]u32, x, base{4, swi{4, el_bytes/4}}} - else sel{I, x, make{I, swi{16, el_bytes}}} + if (el_bytes >= 4) { + def wd = max{4, el_bytes/2} + shuf{[4]ty_u{wd*8}, x, base{4, swi{4, el_bytes/wd}}} + } else { + def i = swi{16, el_bytes} + sel{[16]u8, x, make{V, cycle{vl, i}}} + } } def modperm_step{x, l, m:V=[_]T if l%8==0} = { (x & m) | (swap_elts{x, l/8} &~ m) } -def modperm_get_byteperm{sw_bytes:V=([16]u8) if hasarch{'SSSE3'}} = { +def modperm_get_byteperm{sw_bytes:V=[_](u8) if hasarch{'SSSE3'}} = { def shW{op, v, s} = V~~op{re_el{u64,V}~~v, s} m4 := V**0xf - t0 := fold{{v,a}=>modperm_step{v,...a}, iota{V}, tup{ + t0 := fold{{v,a}=>modperm_step{v,...a}, make{V,iota{vcount{V}}%16}, tup{ tup{4, sw_bytes &~ m4}, tup{2, ({v} => v|shW{<<, v, 4}){sw_bytes&(V**0xc)}} }} t4 := shW{<<, t0&m4, 4} | shW{>>, t0&~m4, 4} - {xv} => sel{V, t0, xv & m4} | sel{V, t4, shW{>>, xv, 4} & m4} + def selI = sel{[16]u8, ., .} + {xv} => selI{t0, xv & m4} | selI{t4, shW{>>, xv, 4} & m4} } def fold_multi{f, init, ...ls} = fold{{v,a}=>f{v,...a}, init, flip{ls}} @@ -477,7 +481,12 @@ def rep_const_bool_odd_mask4{ tup{mn, ss - (k &- (ss>k))} } # Mask and shift for one iteration - {sm0, s1} := ifvec{double_gen{make{M,...}}}{tup{mask, mask_sh}} + def fillmask{T} = match (T) { + {(u64)} => tup{mask, mask_sh} + {[2]E} => double_gen{make{T,...}}{fillmask{E}} + {[4]E} => double_gen{pair}{fillmask{[2]E}} + } + {sm0, s1} := fillmask{M} # Combined mask for 4 iterations, and shift to advance 4 def double = double_gen{|} {mc4, s4} := double{double{tup{sm0, s1}}} @@ -488,7 +497,8 @@ def rep_const_bool_odd_mask4{ mask_tail := advance{sm0, s4} &~ sm0 # Carry: shifting and word-crossing is done on the initial permuted x - o:M = scal{0} # Carry for x + # No need to carry across input words since they align with output words + # First bit of each word in xo below is wrong, but it doesn't matter! mr := scal{u64~~1< { @@ -501,9 +511,9 @@ def rep_const_bool_odd_mask4{ while (1) { x:M = get_modperm_x{} - def vrot1 = ifvec{{x} => vshl{x, x, vcount{type{x}}-1}} - k1:M = scal{1} - os:=o; xo:=x<>(64-k)}; o=xo&k1; xo=(xo&~k1)|os + def vrot1 = ifvec{{x} => if (w128{M}) vshl{x, x, vcount{type{x}}-1} + else shuf{M, x, 4b2103}} + xo := x<>(64-k)} # Write result word given starting bits def step{b, c} = output{sub_carry{c - b, c & mr}} def step{b, c, m} = step{b&m, c&m} @@ -524,10 +534,7 @@ def rep_const_bool_odd_mask4{ def rep_const_bool_odd{k, xp, rp, nw} = { # Every-k-bits mask - m:u64 = spaced_mask_of{k} - d := cast_i{usz, popc{m}} # == 64/k - mask_sh := cast_i{usz, ctz{m}} # == 64%k - mask := m<<(k-mask_sh) | 1 + {mask, mask_sh} := unaligned_spaced_mask_mod{k} # Transform sending bit i to k*i % 64 by pairwise swaps # Swap data goes in a pre-computed table @@ -553,19 +560,18 @@ def rep_const_bool_odd{k, xp, rp, nw} = { store{rp, j, rw} ++j; if (j==nw) return{1} } - o:u64 = 0 # carry # Dedicated loop for 3, shared for other factors - if (k == 3) { + if (k == 3) while (1) { def unrolled_iter{k} = { def reps{f, n, init} = scan{{x,_}=>f{x}, n**init} def masks = reps{advance_spaced_mask{k, ., 64%k}, k, mask} - def step{x}{m, n_over} = { + def step{x}{o, m, n_over} = { b := x & m def os = (64-k) + 1 + iota{n_over} output{fold{|, (b<>{o,.}, os}}} - o = b + b } - while (1) each{step{get_swap_x{}}, masks, (-iota{k}*64)%k} + fold_multi{step{get_swap_x{}}, 0, masks, (-iota{k}*64)%k} } unrolled_iter{3} } else { @@ -578,21 +584,45 @@ def rep_const_bool_odd{k, xp, rp, nw} = { # - replicate each byte by k, making position k*i contain bit i # - mask out those bits and spread over [ k*i, k*(i+1) ) # - ...except where it crosses words; handle this overhang separately -def rep_const_bool_odd{wv, x, r, nw if hasarch{'SSSE3'}} = { # wv odd, wv<=15 - oper // ({a,b}=>floor{a/b}) infix left 40 - def vl = 16; def V = [vl]u8 +def rep_const_bool_odd{k, x, r, nw if hasarch{'SSSE3'}} = { + def avx2 = hasarch{'AVX2'} + def vl = if (avx2) 32 else 16; def V = [vl]u8 def iV = iota{vl} - def mkV = make{V, .}; def selV = sel{V, ., .} - def W = [2]u64 - def {output, flush} = get_boolvec_writer{V, r, nw} + def mkV = make{V, .}; def selV = sel{[16]u8, ., .} + def W = re_el{u64, V} + def {output, flush} = get_boolvec_writer{W, r, nw} # Swap data goes in a pre-computed table swtab:*V = each{modperm_dat{V, .}, 1+2*iota{32}} - swap_data := load{swtab, wv>>1} + swap_data := load{swtab, k>>1} + swap_lane := if (avx2) shuf{[4]u64, swap_data, 4b1010} else swap_data # Within-byte transformation - def perm_x = modperm_get_byteperm{selV{swap_data, V**0}} + def perm_x = modperm_get_byteperm{selV{swap_lane, V**0}} + def sp_max = if (avx2) 4 else 8 + if (k < sp_max) { + rep_const_bool_odd_special{V, sp_max, k, x, perm_x, {v}=>output{W~~v}} + } else { + {m, d} := unaligned_spaced_mask_mod{k} + # General case + def get_mask{l} = selV{~swap_lane, mkV{l+iV%l}} + def get_mask{16} = shuf{[4]u64, ~swap_data, 4b3232} + def swap_lens = reverse{1 << iota{4 + avx2}} + swap_masks := each{get_mask, swap_lens} + def swap_x = fold_multi{modperm_step, ., 8*swap_lens, swap_masks} + i:usz = 0 + def get_swap_x{} = { xv := W~~perm_x{swap_x{load{*V~~x, i}}}; ++i; xv } + rep_const_bool_odd_mask4{W, k, get_swap_x, output, cdiv{nw, vcount{W}}, m, d} + } + flush{} +} - if (wv < 4) { +def rep_const_bool_odd_special{V=[vl](u8), max_wv, wv, x, perm_x, output} = { + oper // ({a,b}=>floor{a/b}) infix left 40 + def iV = iota{vl} + def mkV = make{V, .}; def selV = sel{[16]u8, ., .} + def W = re_el{u64, V} + if (max_wv <= 4 or wv < 4) { + def ll = 16; def dup{v} = if (vl>ll) merge{v,v} else v # 3: dedicated loop i:usz = 0; while (1) { # 01234567 to 05316427 on each byte @@ -601,24 +631,29 @@ def rep_const_bool_odd{wv, x, r, nw if hasarch{'SSSE3'}} = { # wv odd, wv<=15 def ix = 64*slice{iota{3},1} // 3 # bits that overhang within a word def ib = ix // 8 # byte index def io = 8*ib + 3*ix%8 # where they are in xv - def wi = split{2, tup{255, ...ib, 255, ...8+ib}} + def wi = split{vl/8, dup{tup{255, ...ib, 255, ...8+ib}}} xo := V~~((W~~xv & W**fold{|, 1<> (8-3)) xo += xo > V**0 # Permute and mask bytes def step{jj, oi, ind, mask} = { - b := W~~(selV{xv, ind} & mask) + def getv = if (vl==ll or jj==1) ({x}=>x) + else shuf{[4]u64, ., 4b1010 + 4b2222*(jj>1)} + def selx{x, i} = sel{[ll]u8, getv{x}, i} + b := W~~(selx{xv, ind} & mask) r := V~~((b<<3) - b) - o := selV{xo, mkV{flat_table{max, oi, 255*(0 selV{~swap_data, mkV{l+iV%l}}, swap_lens} - def swap_x = fold_multi{modperm_step, ., 8*swap_lens, swap_masks} - i:usz = 0 - def get_swap_x{} = { xv := perm_x{swap_x{load{*V~~x, i}}}; ++i; xv } - # Every-k-bits mask - {m, d} := unaligned_spaced_mask_mod{wv} - rep_const_bool_odd_mask4{W, wv, {}=>W~~get_swap_x{}, {v}=>output{V~~v}, cdiv{nw, vcount{W}}, m, d} } - flush{} } + export{'si_constrep_bool', rep_const_bool{}}