From 2be022921e858bd70a3ec9b8fb6bbdc15f3e3e25 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Sat, 10 Aug 2024 14:20:22 -0400 Subject: [PATCH] Combine generic and SSSE3 general-case odd/bool, so SSSE3 uses mask unpacking --- src/singeli/src/replicate.singeli | 367 +++++++++++++++--------------- 1 file changed, 185 insertions(+), 182 deletions(-) diff --git a/src/singeli/src/replicate.singeli b/src/singeli/src/replicate.singeli index 6eec5b65..351a3644 100644 --- a/src/singeli/src/replicate.singeli +++ b/src/singeli/src/replicate.singeli @@ -300,7 +300,8 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = { } else { tlen := rlen>>p wq := usz~~1<

=52)) { # Expanding odd second is faster + if ((not hasarch{'SSSE3'} and (p == 1 or (p == 2 and wv>=52))) or (not hasarch{'AVX2'} and p == 1 and wv>=24)) { + # Expanding odd second is faster tlen = rlen / wf t:=wf; wf=wq; wq=t } @@ -309,14 +310,12 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = { rep_const_bool{}(wq, t, r, rlen) } } else { - m:u64 = spaced_mask_of{wv} - d := cast_i{usz, popc{m}} # == 64/wv - rep_const_bool_generic_odd{wv, x, r, nw, m, d} + rep_const_bool_odd{wv, x, r, nw} } 1 } -def rep_const_bool_div8{wv, x, r, nw} = { +def rep_const_bool_div8{wv, x, r, nw} = { # wv in 2,4,8 def run{k} = { # 2 -> 64w0x33, 12 -> 64w0x000f, etc. def getm{sh} = base{2, iota{64}&sh == 0} @@ -344,133 +343,8 @@ def rep_const_bool_div8{wv, x, r, nw} = { cases{2} } -def rep_const_bool_generic_odd{k, xp, rp, nw, m, d} = { - # Every-k-bits mask - mask_sh := cast_i{usz, ctz{m}} # == 64%k - mask := m<<(k-mask_sh) | 1 - - # Transform sending bit i to k*i % 64 by pairwise swaps - # Swap data goes in a pre-computed table - def swdat{lw, k} = { - def i = iota{lw} - def bits = (k*iota{1<> merge{0, replicate{1<>1} - swsel:u64 = ~u64~~0 - def gsw{l} = { - swsel ^= swsel << l # Low l bits out of every 2*l - sm := swap_data &~ swsel - swap_data &= swsel; swap_data |= swap_data<>l) - } - def swap_step{l==32, m} = { # Use rotate - mm := m | m>>l - x = (x &~ mm) | ((x<>l) & mm) - } - each{swap_step, swap_lens, swap_masks} - x - } - - # Output - j:usz = 0 - def output{rw} = { - store{rp, j, rw} - ++j; if (j==nw) return{1} - } - o:u64 = 0 # carry - # Dedicated loop for 3, shared for other factors - if (k == 3) { - while (1) { - x := get_swap_x{} - @unroll (jj to 3) { - b := x & mask - mask = (mask << (3 - 64%3)) | (mask >> (64%3)) - def os = (64-3) + 1 + iota{select{tup{0,2,1}, jj}} - output{fold{|, (b<<3) - b, each{>>{o,.}, os}}} - o = b - } - } - } else { - # Fundamental operation: shifts act as order-k cyclic group on masks - def advance{m, sh} = m<<(k-sh) | m>>sh - sm0 := mask # starting mask - s1 := mask_sh # single iteration shift - # Get cumulative mask for 4 iterations, and shift to advance 4 - def double{{mc, s}} = { - def ss = s+s - tup{mc|advance{mc,s}, ss - (k &- (ss>k))} - } - {mc4, s4} := double{double{tup{sm0, s1}}} - # Submasks pick one mask out of a combination of 4 - def or_adv{m, s} = { m |= advance{m,s} } - @for (min{k/4 - 1, nw/4}) or_adv{sm0,s4} - submasks := scan{advance, tup{sm0, ...3**s1}} - mask_tail := advance{sm0, s4} &~ sm0 - # Mask out carry bit - mr := u64~~1<>(64-k); o=xo&1; xo=(xo&~1)|os - mask = mc4 - # Write result word given starting bits - def step{b, c} = output{c - b - promote{u64, c&mr != 0}} - def step{b, c, m} = step{b&m, c&m} - # Fast unrolled iterations - @for (k/4) { - each{step{x & mask, xo & mask, .}, submasks} - mask = advance{mask, s4} - } - # Single-step for tail - mask = mask_tail - @for (k%4) { - step{x, xo, mask} - mask = advance{mask, s1} - } - } - } -} - -fn rep_const_bool{if hasarch{'SSSE3'}}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = { - if (wv > 32) return{0} - if (wv&1 == 0) { - p := ctz{wv | 8} # Power of two for second replicate - if (wv>>p == 1) { - rep_const_bool_div8{wv, x, r, rlen} - } else { - tlen := rlen>>p - t := r + cdiv{rlen, 64} - cdiv{tlen, 64} - rep_const_bool{}(wv >>p, x, t, tlen) - rep_const_bool{}(usz~~1<floor{a/b}) infix left 40 def avx2 = hasarch{'AVX2'} def vl = if (avx2) 32 else 16 @@ -498,7 +372,7 @@ def rep_const_bool_div8{wv, x, r, rlen if hasarch{'SSSE3'}} = { # wv in 2,4,8 def selH = sel{[16]u8, ., .} def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .} def id{xv} = xv - def {output, flush} = get_boolvec_writer{V, r, rlen} + def {output, flush} = get_boolvec_writer{V, r, nw} def run24{x, proc_xv, exh} = { i:usz = 0; while (1) { @@ -535,29 +409,189 @@ def rep_const_bool_div8{wv, x, r, rlen if hasarch{'SSSE3'}} = { # wv in 2,4,8 flush{} } +# Data for the permutation that sends bit i to k*i % width{T} +def modperm_dat{T, k} = { + def w = width{T} + def lw= lb{w} + def i = iota{lw} + def bits = ~(1 & (k*iota{w} >> merge{0, replicate{1< make{T, each{base{2,.}, split{width{E}, bits}}} + {_} => T~~base{2, bits} + } +} +def modperm_step{x, l, m} = { + def d = (x ^ x<>l) +} +def modperm_step{x:T, l=(width{T}/2), m} = { + def mm = m | m>>l + (x &~ mm) | ((x<>l) & mm) # rotate +} +def modperm_step{x:T=[_](u8), l, m} = { + def W = re_el{u64, T} + T~~modperm_step{W~~x, l, W~~m} +} +def swap_elts{x:V, el_bytes if w128{V}} = { + def n = 16; def I = [n]u8 + def swi{len, l} = { def i = iota{len}; i + (l - 2*(i&l)) } + if (el_bytes >= 4) shuf{[4]u32, x, base{4, swi{4, el_bytes/4}}} + else sel{I, x, make{I, swi{16, el_bytes}}} +} +def modperm_step{x, l, m:V=[_]T if l%8==0} = { + (x & m) | (swap_elts{x, l/8} &~ m) +} +def modperm_get_byteperm{sw_bytes:V=([16]u8) if hasarch{'SSSE3'}} = { + def shW{op, v, s} = V~~op{re_el{u64,V}~~v, s} + m4 := V**0xf + t0 := fold{{v,a}=>modperm_step{v,...a}, iota{V}, tup{ + tup{4, sw_bytes &~ m4}, + tup{2, ({v} => v|shW{<<, v, 4}){sw_bytes&(V**0xc)}} + }} + t4 := shW{<<, t0&m4, 4} | shW{>>, t0&~m4, 4} + {xv} => sel{V, t0, xv & m4} | sel{V, t4, shW{>>, xv, 4} & m4} +} +def fold_multi{f, init, ...ls} = fold{{v,a}=>f{v,...a}, init, flip{ls}} + +def advance_spaced_mask{k, m, sh} = m<<(k-sh) | m>>sh + +# General-case loop for odd replication factors +def rep_const_bool_odd_mask4{ + M, # read/write type + k, # replication factor + get_modperm_x, # permuted input + output, n, # output, number of writes + mask:(u64), # starting mask + mask_sh # single iteration shift +} = { + def ifvec{g} = match (M) { {[_](u64)} => g; {_} => ({v}=>v) } + def scal = ifvec{{v} => M**v} + + # Fundamental operation: shifts act as order-k cyclic group on masks + def advance{m, sh} = advance_spaced_mask{k, m, sh} + # Double a cumulative mask, shift combination + # If s advances l iterations, mc combines iterations iota{l} + def double_gen{comb}{{mc, s}} = { + def mn = comb{mc, advance{mc,s}} + def ss = s+s + tup{mn, ss - (k &- (ss>k))} + } + # Mask and shift for one iteration + {sm0, s1} := ifvec{double_gen{make{M,...}}}{tup{mask, mask_sh}} + # Combined mask for 4 iterations, and shift to advance 4 + def double = double_gen{|} + {mc4, s4} := double{double{tup{sm0, s1}}} + # Submasks pick one mask out of a combination of 4 + def or_adv{m, s} = { m |= advance{m,s} } + @for (min{k/4 - 1, n/4}) or_adv{sm0,s4} + submasks := scan{advance, tup{sm0, ...3**s1}} + mask_tail := advance{sm0, s4} &~ sm0 + + # Carry: shifting and word-crossing is done on the initial permuted x + o:M = scal{0} # Carry for x + mr := scal{u64~~1< { + ca := if (hasarch{'SSE4.2'}) { def S = [l]i64; S~~c > S**0 } + else { def S = [2*l]i32; cm := S~~c; cm != shuf{S, cm, 4b2301} } + a + M~~ca + } + {_} => a - promote{u64, c != 0} + } + + while (1) { + x:M = get_modperm_x{} + def vrot1 = ifvec{{x} => vshl{x, x, vcount{type{x}}-1}} + k1:M = scal{1} + os:=o; xo:=x<>(64-k)}; o=xo&k1; xo=(xo&~k1)|os + # Write result word given starting bits + def step{b, c} = output{sub_carry{c - b, c & mr}} + def step{b, c, m} = step{b&m, c&m} + # Fast unrolled iterations + mask := mc4 + @for (k/4) { + each{step{x & mask, xo & mask, .}, submasks} + mask = advance{mask, s4} + } + # Single-step for tail + mask = mask_tail + @for (k%4) { + step{x, xo, mask} + mask = advance{mask, s1} + } + } +} + +def rep_const_bool_odd{k, xp, rp, nw} = { + # Every-k-bits mask + m:u64 = spaced_mask_of{k} + d := cast_i{usz, popc{m}} # == 64/k + mask_sh := cast_i{usz, ctz{m}} # == 64%k + mask := m<<(k-mask_sh) | 1 + + # Transform sending bit i to k*i % 64 by pairwise swaps + # Swap data goes in a pre-computed table + swtab:*u64 = each{modperm_dat{u64, .}, 1+2*iota{32}} + def swap_lens = reverse{2 << iota{5}} + swap_data := load{swtab, k>>1} + swsel:u64 = ~u64~~0 + def gsw{l} = { + swsel ^= swsel << l # Low l bits out of every 2*l + sm := swap_data &~ swsel + swap_data &= swsel; swap_data |= swap_data<f{x}, n**init} + def masks = reps{advance_spaced_mask{k, ., 64%k}, k, mask} + def step{x}{m, n_over} = { + b := x & m + def os = (64-k) + 1 + iota{n_over} + output{fold{|, (b<>{o,.}, os}}} + o = b + } + while (1) each{step{get_swap_x{}}, masks, (-iota{k}*64)%k} + } + unrolled_iter{3} + } else { + rep_const_bool_odd_mask4{u64, k, get_swap_x, output, nw, mask, mask_sh} + } +} + # For odd numbers: # - permute each byte sending bit i to position k*i % 8 # - replicate each byte by k, making position k*i contain bit i # - mask out those bits and spread over [ k*i, k*(i+1) ) # - ...except where it crosses words; handle this overhang separately -def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15 +def rep_const_bool_odd{wv, x, r, nw if hasarch{'SSSE3'}} = { # wv odd, wv<=15 oper // ({a,b}=>floor{a/b}) infix left 40 def vl = 16; def V = [vl]u8 def iV = iota{vl} def mkV = make{V, .}; def selV = sel{V, ., .} def W = [2]u64 - def {output, flush} = get_boolvec_writer{V, r, rlen} + def {output, flush} = get_boolvec_writer{V, r, nw} + # Swap data goes in a pre-computed table + swtab:*V = each{modperm_dat{V, .}, 1+2*iota{32}} + swap_data := load{swtab, wv>>1} # Within-byte transformation - def get_ttab{k} = each{{is} => mkV{tr_iota{is}}, split{4, k*iota{8} % 8}} - ttab:*V = join{each{get_ttab, 2*iota{4} + 1}} - {t0, t4} := each{load{ttab + (wv & 6), .}, iota{2}} - m4 := V**0xf - def perm_x{xv} = { - selV{t0, xv & m4} | selV{t4, V~~([8]u16~~xv>>4) & m4} - } + def perm_x = modperm_get_byteperm{selV{swap_data, V**0}} - # Cases are 3; 5 7; 9 11 13 15 if (wv < 4) { # 3: dedicated loop i:usz = 0; while (1) { @@ -629,47 +663,16 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15 o := xo & mkV{255 * (iV%8 == 0)} # overhang output{rv | o} } - } else { # wv < 32 - # 9 to 31: extend k*i % 8 transform to k*i % 128 by pairwise swaps - # Swap data goes in a pre-computed table - def swdat{k} = { - def bits = (k*iota{128} >> merge{0, replicate{1<>1)-4} - swap_masks := each{{l} => selV{swap_data, mkV{l+iV%l}}, swap_lens} - # Every-k-bits mask, same as before - {m, d} := unaligned_spaced_mask_mod{wv} - mask := make{W, m, m>>d|m<<(wv-d)} - mask_sh := d+d; if (mask_sh >= wv) mask_sh-= wv - # Mask out carry bit - mr := [4]u32~~W**(u64~~1< selV{~swap_data, mkV{l+iV%l}}, swap_lens} + def swap_x = fold_multi{modperm_step, ., 8*swap_lens, swap_masks} i:usz = 0 - while (1) { - # Load xv, send bit i to position wv*i % 128 - xv = load{*V~~x, i}; ++i - def swap_step{l, m} = { - xv = (xv & m) | (selV{xv, mkV{iV + (l-2*(iV&l))}} &~ m) - } - each{swap_step, swap_lens, swap_masks} - xv = perm_x{xv} - xw := W~~xv - def vrot1{x} = vshl{x, x, vcount{type{x}}-1} - w1 := W**1 - os:=o; xo:=xw<>(64-wv)}; o=xo&w1; xo=(xo&~w1)|os - # Write wv vectors based on that - @for (wv) { - b := xw & mask - # Handle overhang here; won't fit in a single vector - c := xo & mask; cu := [4]u32~~c - output{V~~(W~~(cu + (mr&cu > [4]u32**0)) - b)} - mask = (mask << (wv - mask_sh)) | (mask >> mask_sh) - } - } + def get_swap_x{} = { xv := perm_x{swap_x{load{*V~~x, i}}}; ++i; xv } + # Every-k-bits mask + {m, d} := unaligned_spaced_mask_mod{wv} + rep_const_bool_odd_mask4{W, wv, {}=>W~~get_swap_x{}, {v}=>output{V~~v}, cdiv{nw, vcount{W}}, m, d} } flush{} }