diff --git a/src/singeli/src/replicate.singeli b/src/singeli/src/replicate.singeli index 2cd0cc9e..be48bfb6 100644 --- a/src/singeli/src/replicate.singeli +++ b/src/singeli/src/replicate.singeli @@ -347,23 +347,38 @@ def get_boolvec_writer{V, r, nw} = { def vwords = width{V}/64 rv := *V~~r re := rv + nw / vwords + rs := rv + cdiv{nw, vwords} + # Avoid reading/processing an extra word past the end of input + def done = makelabel{} + def check_done{} = if (rv == rs) goto{done} + # If the last result is partial, jump to flush to write it last_res:V = V**0 - def end = makelabel{} + def l_flush = makelabel{} def output{v:(V)} = { last_res = v - if (rv==re) goto{end} + if (rv == re) goto{l_flush} store{rv, 0, v}; ++rv } def flush{} = { - setlabel{end} + setlabel{l_flush} q := nw & (vwords-1) if (q != 0) homMaskStoreF{rv, V~~maskOf{re_el{u64,V}, q}, last_res} + setlabel{done} } - tup{output, flush} + tup{output, check_done, flush} +} +def get_boolvec_writer{T=(u64), r:*T, nw} = { + def done = makelabel{} + j:usz = 0 + def output{rw} = { + store{r, j, rw} + ++j; if (j==nw) goto{done} + } + def flush{} = setlabel{done} + tup{output, {}=>{}, flush} } def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSSE3'}} = { - oper // ({a,b}=>floor{a/b}) infix left 40 def avx2 = hasarch{'AVX2'} def vl = if (avx2) 32 else 16 def V = [vl]u8 @@ -372,7 +387,7 @@ def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSSE3'}} = { def selH = sel{[16]u8, ., .} def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .} def id{xv} = xv - def {output, flush} = get_boolvec_writer{V, r, nw} + def {output, done, flush} = get_boolvec_writer{V, r, nw} def run24{x, proc_xv, exh} = { i:usz = 0; while (1) { @@ -409,6 +424,19 @@ def rep_const_bool_div8{wv, x, r, nw if hasarch{'SSSE3'}} = { flush{} } +def sel_imm{V=[4]T, x, {...inds} if hasarch{'SSE2'} and (T==u32 or T==u64)} = { + assert{length{inds} == 4} + shuf{V, x, base{4, inds}} +} +def sel_imm{V=([16]u8), x:X, {...inds}} = { + def I = re_el{u8,X}; def [n]_ = I + def l = length{inds} + assert{l==16 or l==n} + sel{V, x, make{I, cycle{n, inds}}} +} + +def advance_spaced_mask{k, m, sh} = m<<(k-sh) | m>>sh + # Data for the permutation that sends bit i to k*i % width{T} def modperm_dat{T, k} = { def w = width{T} @@ -419,45 +447,119 @@ def modperm_dat{T, k} = { {_} => T~~base{2, bits} } } -def modperm_step{x, l, m} = { +# Permutation step evaluators +# shift takes a top-half mask; others take both-halves +def modperm_shift_step{x, l, m} = { def d = (x ^ x<>l) } -def modperm_step{x:T, l=(width{T}/2), m} = { - def mm = m | m>>l - (x &~ mm) | ((x<>l) & mm) # rotate +def modperm_rot_step{x:T, l=(width{T}/2), m} = { + (x &~ m) | ((x<>l) & m) } -def modperm_step{x:T=[_](u8), l, m} = { - def W = re_el{u64, T} - T~~modperm_step{W~~x, l, W~~m} +def modperm_shuf_step{x:V=[_]T, l, m if l%8==0} = { + (x &~ m) | (swap_elts{x, l/8} & m) } -def swap_elts{x:V=[vl](u8), el_bytes} = { # Reverse each pair of elements - def swi{len, l} = { def i = iota{len}; i + (l - 2*(i&l)) } +# Reverse each pair of elements +def swap_elts{x:V=[_]_, el_bytes} = { + def selx{n, wd} = { + def l = el_bytes/wd + def i = iota{n} + def rev = i + (l - 2*(i&l)) + sel_imm{[n]ty_u{wd*8}, x, rev} + } if (el_bytes >= 4) { - def wd = max{4, el_bytes/2} - shuf{[4]ty_u{wd*8}, x, base{4, swi{4, el_bytes/wd}}} + selx{4, max{4, el_bytes/2}} + } else if (hasarch{'SSSE3'}) { + selx{16, 1} } else { - def i = swi{16, el_bytes} - sel{[16]u8, x, make{V, cycle{vl, i}}} + def sh = el_bytes*8 + xe := re_el{ty_u{2*sh}, V} ~~ x + V~~(xe<>sh) } } -def modperm_step{x, l, m:V=[_]T if l%8==0} = { - (x & m) | (swap_elts{x, l/8} &~ m) +# Extract swap functions from modperm_dat +def extract_modperm_mask{full:W, lane, l} = { + def f = l == 128 # Use full width + def w = if (l<32) 8 else if (f) 64 else 32 # Shuffle element width + def n = (if (f) 256 else 128) / w # and number of elements + def o = l / w # Extract [o,2*o) + def i = o + iota{n}%o + sel_imm{[n]ty_u{w}, if (f) full else lane, i} } -def modperm_get_byteperm{sw_bytes:V=[_](u8) if hasarch{'SSSE3'}} = { - def shW{op, v, s} = V~~op{re_el{u64,V}~~v, s} - m4 := V**0xf - t0 := fold{{v,a}=>modperm_step{v,...a}, make{V,iota{vcount{V}}%16}, tup{ - tup{4, sw_bytes &~ m4}, - tup{2, ({v} => v|shW{<<, v, 4}){sw_bytes&(V**0xc)}} - }} - t4 := shW{<<, t0&m4, 4} | shW{>>, t0&~m4, 4} - def selI = sel{[16]u8, ., .} - {xv} => selI{t0, xv & m4} | selI{t4, shW{>>, xv, 4} & m4} +def proc_mod_dat{swap_data:W} = { + def ww = width{W}; def avx2 = ww==256 + def on_len_range{get_swap, lo, hi} = { + def lens = reverse{1 << slice{iota{lb{hi}}, lb{lo}}} + fold{{x,s}=>s{x}, ., each{get_swap, lens}} + } + swap_lane := if (avx2) shuf{W, swap_data, 4b1010} else swap_data + # Transform width-w units with shifts only + def get_shiftperm{data, w} = { + dat := data + bot:W = W**base{2, cycle{64, replicate{w, tup{1,0}}}} + def gsw{l} = { + bot ^= bot << l # Low l bits out of every 2*l + sm := dat &~ bot + dat &= bot; dat |= dat<>l} + } + on_len_range{gsw, 2, w} + } + # Within-byte transformation + def get_byteperm{} = { + def V = re_el{u8, W} + def sw_bytes = sel{[16]u8, swap_lane, V**0} + m4 := W~~V**0xf + t0 := fold{{v,a}=>modperm_shift_step{v,...a}, W~~make{V,iota{vcount{V}}%16}, tup{ + tup{4, sw_bytes &~ m4}, + tup{2, ({v} => v | v<<4){sw_bytes & W~~V**0xc}} + }} + t4 := (t0&m4)<<4 | (t0&~m4)>>4 + def selI{v,i} = sel{[16]u8, v, V~~i} + {xv} => selI{t0, xv & m4} | selI{t4, (xv>>4) & m4} + } + def {partwidth, partperm} = { + def sh{w, d} = tup{w, get_shiftperm{d, w}} + if (ww==64) sh{64, swap_data} + else if (not hasarch{'SSSE3'}) sh{32, shuf{[4]u32, swap_data, 4b0000}} + else tup{8, get_byteperm{}} + } + # Fill in higher steps + def get_mod_permuter{} = { + def get_swap{l} = { + def mask = extract_modperm_mask{swap_data, swap_lane, l} + modperm_shuf_step{., l, mask} + } + def swap_x = on_len_range{get_swap, partwidth, ww} + {x} => partperm{swap_x{x}} + } + tup{partperm, get_mod_permuter} } -def fold_multi{f, init, ...ls} = fold{{v,a}=>f{v,...a}, init, flip{ls}} -def advance_spaced_mask{k, m, sh} = m<<(k-sh) | m>>sh +def rep_const_bool_odd{k, x, r, nw} = { + def avx2 = hasarch{'AVX2'} + def W = if (hasarch{'SSE2'}) [if (avx2) 4 else 2]u64 else u64 + + def {output, check_done, flush} = get_boolvec_writer{W, r, nw} + xp := *W~~x + def getter{perm}{} = { check_done{}; xv := load{xp}; ++xp; perm{xv} } + + # Modular permutation: small-k cases may use a limited permutation + # on bytes or 32-bit ints; general case uses the whole thing + swtab:*W = each{modperm_dat{W, .}, 1+2*iota{32}} + swap_data := load{swtab, k>>1} + def {partperm, get_full_permute} = proc_mod_dat{swap_data} + + def sp_max = if (hasarch{'SSSE3'} and not avx2) 8 else 4 + if (k < sp_max) { + rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output} + } else { + def get_swap_x = getter{get_full_permute{}} + rep_const_bool_odd_mask4{W, k, get_swap_x, output, cdiv{nw, vcount{W}}} + } + flush{} +} # General-case loop for odd replication factors def rep_const_bool_odd_mask4{ @@ -465,8 +567,6 @@ def rep_const_bool_odd_mask4{ k, # replication factor get_modperm_x, # permuted input output, n, # output, number of writes - mask:(u64), # starting mask - mask_sh # single iteration shift } = { def ifvec{g} = match (M) { {[_](u64)} => g; {_} => ({v}=>v) } def scal = ifvec{{v} => M**v} @@ -480,6 +580,8 @@ def rep_const_bool_odd_mask4{ def ss = s+s tup{mn, ss - (k &- (ss>k))} } + # Starting word mask and single-word shift + {mask, mask_sh} := unaligned_spaced_mask_mod{k} # Mask and shift for one iteration def fillmask{T} = match (T) { {(u64)} => tup{mask, mask_sh} @@ -511,7 +613,7 @@ def rep_const_bool_odd_mask4{ while (1) { x:M = get_modperm_x{} - def vrot1 = ifvec{{x} => if (w128{M}) vshl{x, x, vcount{type{x}}-1} + def vrot1 = ifvec{{x} => if (w128{M}) shuf{[4]u32, x, 4b1032} else shuf{M, x, 4b2103}} xo := x<>(64-k)} # Write result word given starting bits @@ -532,131 +634,66 @@ def rep_const_bool_odd_mask4{ } } -def rep_const_bool_odd{k, xp, rp, nw} = { - # Every-k-bits mask - {mask, mask_sh} := unaligned_spaced_mask_mod{k} - - # Transform sending bit i to k*i % 64 by pairwise swaps - # Swap data goes in a pre-computed table - swtab:*u64 = each{modperm_dat{u64, .}, 1+2*iota{32}} - def swap_lens = reverse{2 << iota{5}} - swap_data := load{swtab, k>>1} - swsel:u64 = ~u64~~0 - def gsw{l} = { - swsel ^= swsel << l # Low l bits out of every 2*l - sm := swap_data &~ swsel - swap_data &= swsel; swap_data |= swap_data<f{x}, n**init} - def masks = reps{advance_spaced_mask{k, ., 64%k}, k, mask} - def step{x}{o, m, n_over} = { - b := x & m - def os = (64-k) + 1 + iota{n_over} - output{fold{|, (b<>{o,.}, os}}} - b - } - fold_multi{step{get_swap_x{}}, 0, masks, (-iota{k}*64)%k} - } - unrolled_iter{3} - } else { - rep_const_bool_odd_mask4{u64, k, get_swap_x, output, nw, mask, mask_sh} +def rep_const_bool_small_odd{(u64), 4, wv, get_swap_x, output} = { + def k = 3 + def step{x}{o, n_over} = { + b := x & base{2, cycle{64, n_over == iota{k}}} + def os = (64-k) + 1 + iota{n_over} + output{fold{|, (b<>{o,.}, os}}} + b } + while (1) fold{step{get_swap_x{}}, 0, (-iota{k}*64)%k} } -# For odd numbers: +# For small odd numbers: # - permute each byte sending bit i to position k*i % 8 # - replicate each byte by k, making position k*i contain bit i # - mask out those bits and spread over [ k*i, k*(i+1) ) # - ...except where it crosses words; handle this overhang separately -def rep_const_bool_odd{k, x, r, nw if hasarch{'SSSE3'}} = { - def avx2 = hasarch{'AVX2'} - def vl = if (avx2) 32 else 16; def V = [vl]u8 - def iV = iota{vl} - def mkV = make{V, .}; def selV = sel{[16]u8, ., .} - def W = re_el{u64, V} - def {output, flush} = get_boolvec_writer{W, r, nw} - - # Swap data goes in a pre-computed table - swtab:*V = each{modperm_dat{V, .}, 1+2*iota{32}} - swap_data := load{swtab, k>>1} - swap_lane := if (avx2) shuf{[4]u64, swap_data, 4b1010} else swap_data - # Within-byte transformation - def perm_x = modperm_get_byteperm{selV{swap_lane, V**0}} - def sp_max = if (avx2) 4 else 8 - if (k < sp_max) { - rep_const_bool_odd_special{V, sp_max, k, x, perm_x, {v}=>output{W~~v}} - } else { - {m, d} := unaligned_spaced_mask_mod{k} - # General case - def get_mask{l} = selV{~swap_lane, mkV{l+iV%l}} - def get_mask{16} = shuf{[4]u64, ~swap_data, 4b3232} - def swap_lens = reverse{1 << iota{4 + avx2}} - swap_masks := each{get_mask, swap_lens} - def swap_x = fold_multi{modperm_step, ., 8*swap_lens, swap_masks} - i:usz = 0 - def get_swap_x{} = { xv := W~~perm_x{swap_x{load{*V~~x, i}}}; ++i; xv } - rep_const_bool_odd_mask4{W, k, get_swap_x, output, cdiv{nw, vcount{W}}, m, d} - } - flush{} -} - -def rep_const_bool_odd_special{V=[vl](u8), max_wv, wv, x, perm_x, output} = { - oper // ({a,b}=>floor{a/b}) infix left 40 - def iV = iota{vl} - def mkV = make{V, .}; def selV = sel{[16]u8, ., .} - def W = re_el{u64, V} +# If 1-byte shuffle isn't available, use 4-byte units instead +def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = { + def ov_bytes{o} = { def V = re_el{u8, W}; v := V~~o; v += v > V**0; W~~v } if (max_wv <= 4 or wv < 4) { - def ll = 16; def dup{v} = if (vl>ll) merge{v,v} else v + def ww = width{W} + def ew = if (hasarch{'SSSE3'}) 8 else 32 # width of shuffle-able elements + def ne = ww/ew; def se = 64/ew + def lanes = ww > 128 + def dup{v} = if (lanes) merge{v,v} else v # 3: dedicated loop - i:usz = 0; while (1) { + while (1) { # 01234567 to 05316427 on each byte - xv := perm_x{load{*V~~x, i}}; ++i + xv := get_perm_x{} # Overhang from previous 64-bit elements def ix = 64*slice{iota{3},1} // 3 # bits that overhang within a word - def ib = ix // 8 # byte index - def io = 8*ib + 3*ix%8 # where they are in xv - def wi = split{vl/8, dup{tup{255, ...ib, 255, ...8+ib}}} - xo := V~~((W~~xv & W**fold{|, 1<> (8-3)) - xo += xo > V**0 + def ib = ix // ew # byte index + def io = ew*ib + 3*ix%ew # where they are in xv + def wi = split{wl, dup{tup{255, ...ib, 255, ...se+ib}}} + xo := ov_bytes{(xv & W**fold{|, 1<> (ew-3)} # Permute and mask bytes def step{jj, oi, ind, mask} = { - def getv = if (vl==ll or jj==1) ({x}=>x) + def getv = if (not lanes or jj==1) ({x}=>x) else shuf{[4]u64, ., 4b1010 + 4b2222*(jj>1)} - def selx{x, i} = sel{[ll]u8, getv{x}, i} - b := W~~(selx{xv, ind} & mask) - r := V~~((b<<3) - b) - o := selx{xo, mkV{flat_table{max, oi, 255*(0=vl ind_up = mkV{iu - k*ia} ind_inc = mkV{vl//k + ia} mask_sh = width{V} % k } if (wv == 5) set_consts{5} else set_consts{7} - xv:V = V**0; xo:=xv; ind:=xv; mask:=xv # state - i:usz = 0; q:usz = 1 + xv:=W**0; xo:=xv; mask:=xv; ind:=V**0 # state + q:usz = 1 while (1) { --q; if (q == 0) { q = wv # Load and permute bytes - xv = perm_x{load{*V~~x, i}}; ++i + xv = get_perm_x{} # Bytes for overhang - xo = V~~((W~~xv & xom) >> (8 - wv)) - xo += xo > V**0 + xo = ov_bytes{(xv & xom) >> (8 - wv)} xo = selV{xo, xse} # Initialize state vectors ind = ind0 @@ -691,11 +727,11 @@ def rep_const_bool_odd_special{V=[vl](u8), max_wv, wv, x, perm_x, output} = { # Update state vectors xo = shr{V, xo, 1} ind = selV{ind, ind_up} + ind_inc - mask = V~~((W~~mask << (wv - mask_sh)) | (W~~mask >> mask_sh)) + mask = advance_spaced_mask{wv, mask, mask_sh} } - b := W~~(selV{xv, ind} & mask) - rv:= V~~((b<