diff --git a/src/singeli/src/replicate.singeli b/src/singeli/src/replicate.singeli index 2d584525..51b6ae65 100644 --- a/src/singeli/src/replicate.singeli +++ b/src/singeli/src/replicate.singeli @@ -293,12 +293,15 @@ exportT{'si_constrep', each{rep_const, dat_types}} fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = { def has_pdep = 0 # Obselete, kept here for descriptiveness if (wv > 32) return{0} - if (not has_pdep and wv <= 8) return{0} m:u64 = spaced_mask_of{wv} xw:u64 = 0 d := cast_i{usz, popc{m}} # == 64/wv nw := cdiv{rlen, 64} - if (m&1 != 0) { # Power of two + if (wv&1 != 0) { + rep_const_bool_generic_odd{wv, x, r, nw, m, d} + } else if (not has_pdep and wv <= 8) { + return{0} + } else if (m&1 != 0) { # Power of two i := -usz~~1 def expand = if (has_pdep) pdep{., m} else { mult:u64 = spaced_mask_of{wv-1} >> d @@ -335,6 +338,105 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = { 1 } +def rep_const_bool_generic_odd{k, xp, rp, nw, m, d} = { + # Every-k-bits mask + mask_sh := cast_i{usz, ctz{m}} # == 64%k + mask := m<<(k-mask_sh) | 1 + + # Transform sending bit i to k*i % 64 by pairwise swaps + # Swap data goes in a pre-computed table + def swdat{lw, k} = { + def i = iota{lw} + def bits = (k*iota{1<> merge{0, replicate{1<>1} + swsel:u64 = ~u64~~0 + def gsw{l} = { + swsel ^= swsel << l # Low l bits out of every 2*l + sm := swap_data &~ swsel + swap_data &= swsel; swap_data |= swap_data<>l) + } + def swap_step{l==32, m} = { # Use rotate + mm := m | m>>l + x = (x &~ mm) | ((x<>l) & mm) + } + each{swap_step, swap_lens, swap_masks} + x + } + + # Output + j:usz = 0 + def output{rw} = { + store{rp, j, rw} + ++j; if (j==nw) return{1} + } + o:u64 = 0 # carry + # Dedicated loop for 3, shared for other factors + if (k == 3) { + while (1) { + x := get_swap_x{} + @unroll (jj to 3) { + b := x & mask + mask = (mask << (3 - 64%3)) | (mask >> (64%3)) + def os = (64-3) + 1 + iota{select{tup{0,2,1}, jj}} + output{fold{|, (b<<3) - b, each{>>{o,.}, os}}} + o = b + } + } + } else { + # Fundamental operation: shifts act as order-k cyclic group on masks + def advance{m, sh} = m<<(k-sh) | m>>sh + sm0 := mask # starting mask + s1 := mask_sh # single iteration shift + # Get cumulative mask for 4 iterations, and shift to advance 4 + def double{{mc, s}} = { + def ss = s+s + tup{mc|advance{mc,s}, ss - (k &- (ss>k))} + } + {mc4, s4} := double{double{tup{sm0, s1}}} + # Submasks pick one mask out of a combination of 4 + def or_adv{m, s} = { m |= advance{m,s} } + @for (min{k/4 - 1, nw/4}) or_adv{sm0,s4} + submasks := scan{advance, tup{sm0, ...3**s1}} + mask_tail := advance{sm0, s4} &~ sm0 + while (1) { + x := get_swap_x{} + mask = mc4 + # Write result word given starting bits + def step{b} = { + r := (b< 0})} + o = b>>(64-k) + } + # Fast unrolled iterations + @for (k/4) { + xm := x & mask + each{{mm} => step{xm & mm}, submasks} + mask = advance{mask, s4} + } + # Single-step for tail + mask = mask_tail + @for (k%4) { + step{x & mask} + mask = advance{mask, s1} + } + } + } +} + fn rep_const_bool{if hasarch{'SSSE3'}}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = { if (wv > 32) return{0} if (wv&1 == 0) { @@ -539,7 +641,7 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15 mask_sh := d+d; if (mask_sh >= wv) mask_sh-= wv # State xv:V = V**0; o:=W**0 - i:usz = 0; q:usz = 1 + i:usz = 0 while (1) { # Load xv, send bit i to position wv*i % 128 xv = load{*V~~x, i}; ++i