diff --git a/src/builtins/slash.c b/src/builtins/slash.c
index bad784ba..5e883113 100644
--- a/src/builtins/slash.c
+++ b/src/builtins/slash.c
@@ -38,8 +38,13 @@
 // COULD consolidate refcount updates for nested 𝕩
 
 // Replicate by constant
-// Boolean uses pdep, ≠`, or overwriting
-//   SHOULD make a shift/mask replacement for pdep
+// Boolean uses specialized small-𝕨 methods, ≠`, or overwriting
+//   𝕨≤64: Singeli generic and SIMD methods
+//     𝕨=2,4,8: Various shift, shuffle, and zip-based loops
+//     odd 𝕨: Modular permutation
+//     COULD use pdep or similar to avoid overhead on small results
+//   Otherwise, factor into power of 2 times odd
+//     COULD fuse 2×odd, since 2/odd/ has a larger intermediate
 // Other typed 𝕩 uses +`, or lots of Singeli
 //   Fixed shuffles, factorization, partial shuffles, self-overlapping
 //   Otherwise, cell-by-cell copying
@@ -82,6 +87,8 @@ extern void (*const si_scan_max_i32)(int32_t* v0,int32_t* v1,uint64_t v2);
 #define SINGELI_FILE slash
 #include "../utils/includeSingeli.h"
 
+extern uint64_t* const si_spaced_masks;
+#define get_spaced_mask(i) si_spaced_masks[i-1]
 #define SINGELI_FILE replicate
 #include "../utils/includeSingeli.h"
 #endif
@@ -765,38 +772,9 @@ B slash_c2(B t, B w, B x) {
       if (xl == 0) {
         u64* xp = bitarr_ptr(x);
         u64* rp; r = m_bitarrv(&rp, s);
-        #if FAST_PDEP
-        if (wv <= 52) {
-          #if SINGELI
-          u64 m = si_spaced_masks[wv-1];
-          #else
-          u64 m = (u64)-1 / (((u64)1<>= d;
-          if ((j&(wv-1))==0) xw = xp[++i]; u64 rw = _pdep_u64(xw, m); rp[j] = (rw<> (xi%8);
-          u64 ex = (xw&mt)<>o|(xw&1)); o += q;
-          bool oo = o>=wv; xi+=d+oo; o-=wv&-oo;
-          }
-        }
-        goto atmW_maybesh;
-        }
+        #if SINGELI
+        if (wv <= 64) si_constrep_bool(wv, xp, rp, s);
+        else
+        #endif
         if (wv <= 256) { BOOL_REP_XOR_SCAN(wv) }
         else { BOOL_REP_OVER(wv, xlen) }
diff --git a/src/singeli/src/replicate.singeli b/src/singeli/src/replicate.singeli
index d4fe0aba..8b80c07a 100644
--- a/src/singeli/src/replicate.singeli
+++ b/src/singeli/src/replicate.singeli
@@ -1,5 +1,6 @@
 include './base'
 include './mask'
+include './spaced'
 
 def ind_types = tup{i8, i16, i32}
 def dat_types = tup{...ind_types, u64}
@@ -58,7 +59,10 @@ exportT{'si_replicate_scan', flat_table{rep_by_scan, ind_types, dat_types}}
 
 
 # Constant replicate
-if_inline (not (hasarch{'AVX2'} or hasarch{'AARCH64'})) {
+def incl{a,b} = slice{iota{b+1},a}
+def basic_rep = incl{2, 7}
+
+if_inline (not (hasarch{'SSSE3'} or hasarch{'AARCH64'})) {
 
 fn rep_const{T}(wv:u64, x:*void, r:*void, n:u64) : void = {
   rep_by_scan{T, cast_i{usz,wv}, x, r, cast_i{usz, wv*n}}
@@ -66,21 +70,18 @@ fn rep_const{T}(wv:u64, x:*void, r:*void, n:u64) : void = {
 
 } else {
 
-def incl{a,b} = slice{iota{b+1},a}
+def has_bytesel_128 = not hasarch{'AVX2'}
 
 # 1+˝∨`⌾⌽0=div|⌜range
 def makefact{divisor, range} = {
   def t = table{{a,b}=>0==b%a, divisor, range}
   fold{+, 1, reverse{scan{|, reverse{t}}}}
 }
-def basic_rep = incl{2, 7}
 
 def fact_size = 128
 def fact_inds = slice{iota{fact_size},8}
 def fact_tab = makefact{basic_rep, fact_inds}
 factors:*u8 = fact_tab
-
-
 def sdtype = [arch_defvw/8]i8 # shuf data type
 def get_shufs{step, wv, len} = {
   def i = iota{len*step}
@@ -90,7 +91,7 @@ def get_shuf_data{wv, len} = get_shufs{vcount{sdtype}, wv, len} # [len] byte-sel
 def get_shuf_data{wv} = get_shuf_data{wv, wv} # all shuffle vectors for 𝕨≤7
 
-def special_2 = ~hasarch{'AARCH64'} # handle 2 specially on x86-64
+def special_2 = not has_bytesel_128 # handle 2 specially on AVX2
 def rcsh_vals = slice{basic_rep, special_2}
 rcsh_offs:*u8 = shiftright{0, scan{+,rcsh_vals}}
 rcsh_data:*i8 = join{join{each{get_shuf_data, 
rcsh_vals}}} @@ -106,7 +107,7 @@ def read_shuf_vecs{l, ellw:(u64), shp:*V} = { # tuple of byte selectors in 1< gen{sel{[16]u8, x, s}}, sh} @@ -191,7 +192,7 @@ fn rep_const_shuffle_partial4(wv:u64, ellw:u64, x:*i8, r:*i8, n:u64) : void = { def wvb = wv << ellw def hs = (h*step) / wvb # Actual step size in argument elements def shufbase{i if hasarch{'AVX2'}} = shuf{[4]u64, load{*V~~(x+i)}, 4b1010} - def shufbase{i if hasarch{'AARCH64'}} = load{*V~~(x+i)} + def shufbase{i if has_bytesel_128} = load{*V~~(x+i)} def shufrun{a, s} = sel{[16]i8, a, s} # happens to be the same across AVX2 & NEON i:u64 = 0 @@ -284,3 +285,461 @@ fn rep_const{T}(wv:u64, x:*void, r:*void, n:u64) : void = { } exportT{'si_constrep', each{rep_const, dat_types}} + + + +# Constant replicate on boolean +def any_sel = hasarch{'SSSE3'} or hasarch{'AARCH64'} +if_inline (hasarch{'AARCH64'}) { + def __shl{a:V=[_]T, b:U if not isvec{U}} = a << V**cast_i{T,b} + def __shr{a:V=[_]T, b:U if not isvec{U}} = a << V**cast_i{T,-b} +} + +fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : void = { + assert{wv >= 2}; assert{wv <= 64} + nw := cdiv{rlen, 64} + if (wv&1 == 0) { + p := ctz{wv | 8} # Power of two for second replicate + wf := wv>>p + if (wf == 1) { + rep_const_bool_div8{wv, x, r, nw} + } else if (any_sel and p == 3 and wv <= 8*select{basic_rep, -1}) { + # (higher wv double-factors, which doesn't work with in-place pointers) + tlen := rlen / wf + t := r + cdiv{rlen, 64} - cdiv{tlen, 64} + rep_const_bool{}(8, x, t, tlen) + rep_const{select{dat_types,0}}(promote{u64,wf}, *void~~t, *void~~r, promote{u64,tlen/8}) + } else { + tlen := rlen >> p + wq := usz~~1 << p + if (p == 1 and (not any_sel or wv>=24)) { + # Expanding odd second is faster + tlen = rlen / wf + t:=wf; wf=wq; wq=t + } + t := r + cdiv{rlen, 64} - cdiv{tlen, 64} + rep_const_bool{}(wf, x, t, tlen) + rep_const_bool{}(wq, t, r, rlen) + } + } else { + rep_const_bool_odd{wv, x, r, nw} + } +} + +def rep_const_bool_div8{wv, x, r, nw} = { # wv in 2,4,8 + def run{k} = { + # 2 -> 64w0x33, 12 -> 64w0x000f, etc. 
+ def getm{sh} = base{2, iota{64}&sh == 0} + def osh{v, s} = v | v< fold{ + {v, sh} => osh{v, sh} & getm{sh}, + ., 1 << reverse{iota{5}} + } + {4} => fold{ + {v, sh} => osh{osh{v, sh}, 2*sh} & getm{sh}, + ., tup{12, 3} + } + {8} => { + def mult = base{1<<7, 8**1} + {x} => (x | ((x&~1) * mult)) & 64w0x01 + } + } + @for (xt in *ty_u{64/k}~~x, r over nw) { + def v = expand{promote{u64, xt}} + r = v<{}, flush} +} + +def rep_const_bool_div8{wv, x, r, nw if has_simd} = { + def avx2 = hasarch{'AVX2'} + def vl = if (avx2) 32 else 16 + def V = [vl]u8 + def iV = iota{vl} + def mkV = make{V, .} + # Not the same as a 1-byte shift but extra bits are always masked away + def H = re_el{u16, V} + def __shl{x:(V), a} = V~~(H~~x << a) + def __shr{x:(V), a} = V~~(H~~x >> a) + + def {output, check_done, flush} = get_boolvec_writer{V, r, nw} + def run24{x, get_halves} = { + i:usz = 0; while (1) { check_done{} + xv := load{*V~~(x+i)}; ++i + def getr = zip128{...get_halves{xv}, .} + output{V~~getr{0}} + output{V~~getr{1}} + } + } + def run24{x, pre, exh, sh} = run24{x, + {xv} => { p := pre{xv}; tup{exh{p}, exh{p>>sh}} } + } + def run8{rep_bytes} = { + i:usz = 0; while (1) { check_done{} + xh := load{*[16]u8~~(*ty_u{vl}~~x + i)}; ++i + xv := if (avx2) pair{xh, xh} else xh + xe := rep_bytes{xv} + output{(xe & mkV{1 << (iV % 8)}) > V**0} + } + } + if (not any_sel) { + if (wv == 2) { + run24{*V~~x, {xv} => { + def swap{x, l, m} = { def d = (x ^ x>>l) & V**m; x ^ (d|d<>1 + tup{o0, o1} + }} + } else if (wv == 4) { + def pre{xv} = { + a := xv ^ ((xv & V**0x55) << 1) + e := zip128{a, a>>4, 0} + } + def out{h} = (-(h & V**1)) ^ (-(h<<3 & V**0x10)) + run24{*u64~~x, pre, out, 2} + } else { # wv == 8 + def z{x} = zip{x,x,0} + run8{{x} => z{z{z{x}}}} + } + } else { # any_sel + def selH = sel{[16]u8, ., .} + def makeTab{t} = selH{mkV{if (avx2) merge{t,t} else t}, .} + def id{xv} = xv + if (wv == 2) { + def init = if (avx2) shuf{[4]u64, ., 4b3120} else id + # Expander for half byte + def tabr = makeTab{tr_iota{2*iota{4}} * 2b11} + m4 := V**0xf + run24{*V~~x, init, {x} => tabr{x & m4}, 4} + } else if (wv == 4) { + # Unzip 32-bit elements (result lanes) across AVX2 lanes + def pre = if (avx2) sel{[8]u32, ., make{[8]u32,tr_iota{1,2,0}}} else id + def init{xv} = { u:=pre{xv}; zip128{u,u,0} } + # Expander for two bits in either bottom or next-to-bottom position + def tabr = makeTab{tr_iota{0,4,0,4} * 2b1111} + m2 := mkV{2b11 << (2*(iV%2))} + def exh{x} = re_el{u16, V}~~tabr{x & m2} + run24{*(if (avx2) [2]u64 else u64)~~x, init, exh, 4} + } else { # wv == 8 + run8{selH{., mkV{iV // 8}}} + } + } + flush{} +} + +def sel_imm{V=[4]T, x, {...inds} if hasarch{'SSE2'} and (T==u32 or T==u64)} = { + assert{length{inds} == 4} + shuf{V, x, base{4, inds}} +} +def sel_imm{V=([16]u8), x:X, {...inds}} = { + def I = re_el{u8,X}; def [n]_ = I + def l = length{inds} + assert{l==16 or l==n} + sel{V, x, make{I, cycle{n, inds}}} +} + +# Data for the permutation that sends bit i to k*i % width{T} +def modperm_dat{T, k} = { + def w = width{T} + def i = iota{lb{w}} + def bits = ~(1 & (k*iota{w} >> merge{0, replicate{1< make{T, each{base{2,.}, split{width{E}, bits}}} + {_} => T~~base{2, bits} + } +} +# Permutation step evaluators +# shift takes a top-half mask; others take both-halves +def modperm_shift_step{x, l, m} = { + def d = (x ^ x<>l) +} +def modperm_rot_step{x:T, l=(width{T}/2), m} = { + (x &~ m) | ((x<>l) & m) +} +def modperm_shuf_step{x:V=[_]T, l, m if l%8==0} = { + (x &~ m) | (swap_elts{x, l/8} & m) +} +# Reverse each pair of elements +def 
swap_elts{x:V=[_]_, el_bytes} = { + def selx{n, wd} = { + def l = el_bytes/wd + def i = iota{n} + def rev = i + (l - 2*(i&l)) + sel_imm{[n]ty_u{wd*8}, x, rev} + } + if (hasarch{'SSE2'} and el_bytes >= 4) { + selx{4, max{4, el_bytes/2}} + } else if (any_sel) { + selx{16, 1} + } else { + def sh = el_bytes*8 + xe := re_el{ty_u{2*sh}, V} ~~ x + V~~(xe<>sh) + } +} +# Extract swap functions from modperm_dat +def extract_modperm_mask{full:W, lane, l} = { + def f = l == 128 # Use full width + def w = if (l<32 or not hasarch{'SSE2'}) 8 + else if (f) 64 else 32 # Shuffle element width + def n = (if (f) 256 else 128) / w # and number of elements + def o = l / w # Extract [o,2*o) + def i = o + iota{n}%o + sel_imm{[n]ty_u{w}, if (f) full else lane, i} +} +def proc_mod_dat{swap_data:W} = { + def ww = width{W}; def avx2 = ww==256 + def on_len_range{get_swap, lo, hi} = { + def lens = reverse{1 << slice{iota{lb{hi}}, lb{lo}}} + fold{{x,s}=>s{x}, ., each{get_swap, lens}} + } + swap_lane := if (avx2) shuf{W, swap_data, 4b1010} else swap_data + # Transform width-w units with shifts only + def get_shiftperm{data, w} = { + dat := data + bot:W = W**base{2, cycle{64, replicate{w, tup{1,0}}}} + def gsw{l} = { + bot ^= bot << l # Low l bits out of every 2*l + sm := dat &~ bot + dat &= bot; dat |= dat<>l} + } + on_len_range{gsw, 2, w} + } + # Within-byte transformation + def get_byteperm{} = { + def V = re_el{u8, W} + def sw_bytes = sel{[16]u8, swap_lane, V**0} + m4 := W~~V**0xf + t0 := fold{{v,a}=>modperm_shift_step{v,...a}, W~~make{V,iota{vcount{V}}%16}, tup{ + tup{4, sw_bytes &~ m4}, + tup{2, ({v} => v | v<<4){sw_bytes & W~~V**0xc}} + }} + t4 := (t0&m4)<<4 | (t0&~m4)>>4 + def selI{v,i} = sel{[16]u8, v, V~~i} + {xv} => selI{t0, xv & m4} | selI{t4, (xv>>4) & m4} + } + def {partwidth, partperm} = { + def sh{w, d} = tup{w, get_shiftperm{d, w}} + if (ww==64) sh{64, swap_data} + else if (not any_sel) sh{32, shuf{[4]u32, swap_data, 4b0000}} + else tup{8, get_byteperm{}} + } + # Fill in higher steps + def get_mod_permuter{} = { + def get_swap{l} = { + def mask = extract_modperm_mask{swap_data, swap_lane, l} + modperm_shuf_step{., l, mask} + } + def swap_x = on_len_range{get_swap, partwidth, ww} + {x} => partperm{swap_x{x}} + } + tup{partperm, get_mod_permuter} +} + +def rep_const_bool_odd{k, x, r, nw} = { + def avx2 = hasarch{'AVX2'} + def W = if (has_simd) [if (avx2) 4 else 2]u64 else u64 + + def {output, check_done, flush} = get_boolvec_writer{W, r, nw} + xp := *W~~x + def getter{perm}{} = { check_done{}; xv := load{xp}; ++xp; perm{xv} } + + # Modular permutation: small-k cases may use a limited permutation + # on bytes or 32-bit ints; general case uses the whole thing + swtab:*W = each{modperm_dat{W, .}, 1+2*iota{32}} + swap_data := load{swtab, k>>1} + def {partperm, get_full_permute} = proc_mod_dat{swap_data} + + def sp_max = if (any_sel) 8 else 4 + if (k < sp_max) { + rep_const_bool_small_odd{W, sp_max, k, getter{partperm}, output} + } else { + def get_swap_x = getter{get_full_permute{}} + rep_const_bool_odd_mask4{W, k, get_swap_x, output, cdiv{nw, vcount{W}}} + } + flush{} +} + +# General-case loop for odd replication factors +def rep_const_bool_odd_mask4{ + M, # read/write type + k, # replication factor + get_modperm_x, # permuted input + output, n, # output, number of writes +} = { + def ifvec{g} = match (M) { {[_](u64)} => g; {_} => ({v}=>v) } + def scal = ifvec{{v} => M**v} + + # Fundamental operation: shifts act as order-k cyclic group on masks + def advance{m, sh} = advance_spaced_mask{k, m, sh} + # Double a 
cumulative mask, shift combination + # If s advances l iterations, mc combines iterations iota{l} + def double_gen{comb}{{mc, s}} = { + def mn = comb{mc, advance{mc,s}} + def ss = s+s + tup{mn, ss - (k &- (ss>k))} + } + # Starting word mask and single-word shift + {mask, mask_sh} := unaligned_spaced_mask_mod{k} + # Mask and shift for one iteration + def fillmask{T} = match (T) { + {(u64)} => tup{mask, mask_sh} + {[2]E} => double_gen{make{T,...}}{fillmask{E}} + {[4]E} => double_gen{pair}{fillmask{[2]E}} + } + {sm0, s1} := fillmask{M} + # Combined mask for 4 iterations, and shift to advance 4 + def double = double_gen{|} + {mc4, s4} := double{double{tup{sm0, s1}}} + # Submasks pick one mask out of a combination of 4 + def or_adv{m, s} = { m |= advance{m,s} } + @for (min{k/4 - 1, n/4}) or_adv{sm0,s4} + submasks := scan{advance, tup{sm0, ...3**s1}} + mask_tail := advance{sm0, s4} &~ sm0 + + # Carry: shifting and word-crossing is done on the initial permuted x + # No need to carry across input words since they align with output words + # First bit of each word in xo below is wrong, but it doesn't matter! + mr := scal{u64~~1< { + ca := if (hasarch{'SSE4.2'} or hasarch{'AARCH64'}) { def S = [l]i64; S~~c > S**0 } + else { def S = [2*l]i32; cm := S~~c; cm != shuf{S, cm, 4b2301} } + a + M~~ca + } + {_} => a - promote{u64, c != 0} + } + + while (1) { + x:M = get_modperm_x{} + def vrot1 = ifvec{{x} => { + if (w256{M}) shuf{M, x, 4b2103} + else if (any_sel) vshl{x, x, vcount{type{x}}-1} + else shuf{[4]u32, x, 4b1032} + }} + xo := x<>(64-k)} + # Write result word given starting bits + def step{b, c} = output{sub_carry{c - b, c & mr}} + def step{b, c, m} = step{b&m, c&m} + # Fast unrolled iterations + mask := mc4 + @for (k/4) { + each{step{x & mask, xo & mask, .}, submasks} + mask = advance{mask, s4} + } + # Single-step for tail + mask = mask_tail + @for (k%4) { + step{x, xo, mask} + mask = advance{mask, s1} + } + } +} + +def rep_const_bool_small_odd{(u64), 4, wv, get_swap_x, output} = { + def k = 3 + def step{x}{o, n_over} = { + b := x & base{2, cycle{64, n_over == iota{k}}} + def os = (64-k) + 1 + iota{n_over} + output{fold{|, (b<>{o,.}, os}}} + b + } + while (1) fold{step{get_swap_x{}}, 0, (-iota{k}*64)%k} +} + +# For small odd numbers: +# - permute each byte sending bit i to position k*i % 8 +# - replicate each byte by k, making position k*i contain bit i +# - mask out those bits and spread over [ k*i, k*(i+1) ) +# - ...except where it crosses words; handle this overhang separately +# If 1-byte shuffle isn't available, use 4-byte units instead +def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = { + def ov_bytes{o} = { def V = re_el{u8, W}; v := V~~o; v += v > V**0; W~~v } + def ww = width{W} + def ew = if (any_sel) 8 else 32 # width of shuffle-able elements + def ne = ww/ew; def se = 64/ew + def lanes = ww > 128 + def dup{v} = if (lanes) merge{v,v} else v + def fixed_loop{k} = { + assert{wv == k} + while (1) { + # e.g. 
01234567 to 05316427 on each byte for k==3, ew==8 + xv := get_perm_x{} + # Overhang from previous 64-bit elements + def ix = 64*slice{iota{k},1} // k # bits that overhang within a word + def ib = ix // ew # byte index + def io = ew*ib + k*ix%ew # where they are in xv + def wi = split{wl, dup{tup{255, ...ib, 255, ...se+ib}}} + xo := ov_bytes{(xv & W**fold{|, 1<> (ew-k)} + # Permute and mask bytes + def step{jj, oi, ind, mask} = { + def hk = (k-1) / 2 + def getv = if (not lanes or jj==hk) ({x}=>x) + else sel_imm{[4]u64, ., 2*(jj>hk) + iota{4}%2} + def selx{x, i} = sel_imm{[128/ew]ty_u{ew}, getv{x}, i} + b := selx{xv, ind} & make{W, mask} + r := (b<>d | m<<(l-d), d} } + +def advance_spaced_mask{k, m, sh} = m<<(k-sh) | m>>sh
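For reference, the kernels introduced above are easiest to sanity-check against a scalar oracle. The sketch below is a minimal C reference (the name `constrep_bool_ref` and its signature are invented for illustration, not CBQN's API) of what boolean Replicate-by-constant computes: bit i of 𝕩 is copied to result bits [i·𝕨, (i+1)·𝕨). The same oracle covers the factored path, since replicating by 2^p and then by the odd factor (or vice versa) must match a single replicate by 𝕨.

```c
#include <stdint.h>

// Naive reference for boolean Replicate-by-constant (hypothetical helper,
// not part of CBQN): bit i of x lands in result bits [i*wv, (i+1)*wv).
static void constrep_bool_ref(uint64_t wv, const uint64_t* x, uint64_t* r, uint64_t rlen) {
  for (uint64_t j = 0; j < rlen; j++) {      // rlen: result length in bits
    uint64_t i = j / wv;                     // source bit index
    uint64_t b = (x[i/64] >> (i%64)) & 1;    // source bit value
    if (j%64 == 0) r[j/64] = 0;              // clear each output word once
    r[j/64] |= b << (j%64);
  }
}
```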