diff --git a/src/singeli/src/scan.singeli b/src/singeli/src/scan.singeli index d7fead2f..2e525c3c 100644 --- a/src/singeli/src/scan.singeli +++ b/src/singeli/src/scan.singeli @@ -337,6 +337,34 @@ def avx2_loop_with_unaligned_mask{xp, rp, nw, l, scan_words, apply_carry} = { m = m>>d4 | m<<(l-d4) } } +def avx2_loop_with_loose_mask{xp, rp, nw, l, id, scan_words, propagate, fix_carry, apply_carry} = { + assert{l >= 64} + def V = [4]u64 + c := V**id # carry, 0 or 1 + q := -make{V, 64*iota{4}} # distance to next row boundary + def q_mod{} = { q = blend_top{q,q+V**l, q} } + o:u64 = 256; while (o>l) { o-=l; q_mod{} } + @for_masked{4} (x in tup{V, xp}, r in tup{V, rp} over nw) { + # Get mask; <=1 bit per word + m:= V**1 << q + q-= V**o; q_mod{} + # Within-word scan and carry info + ml:= m - V**1 + {s, k}:= scan_words{x, m, ml} + # Propagate carries and adjust result + p:= propagate{k, c} + t:= blend{shuf{V, p, 3,0,1,2}, c, 1,0,0,0} + r = apply_carry{s, -fix_carry{t}, ml} + c = shuf{V, p, 4**3} + } +} +def avx2_loop_with_loose_mask{...a={xp, rp, nw, l, id, scan_words}, apply_carry} = { + def passthrough{k, c} = { + def bl{b,a} = blend_top{b,a, b} + bl{make_scan_idem{f64, bl}{k}, c} # Can't be -1 now + } + avx2_loop_with_loose_mask{...a, passthrough, {k}=>k, apply_carry} +} fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = { def qand = not id @@ -375,31 +403,17 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = { a >> 63} # new c }} } - } else if (l < 160) { + } else if (l < (if (hasarch{'AVX2'}) 256 else 160)) { if (hasarch{'AVX2'}) { - assert{l >= 64} - def V = [4]u64 - c := V**id # carry, 0 or 1 - v1:= V**1; xk:= c - v1 - q := -make{V, 64*iota{4}} # distance to next row boundary - def q_mod{} = { q = blend_top{q,q+V**l, q} } - o:u64 = 256; while (o>l) { o-=l; q_mod{} } - @for_masked{4} (x in tup{V, src}, r in tup{V, dst} over nw) { - # Get mask; <=1 bit per word - m:= v1 << q - q-= V**o; q_mod{} - # Within-word scan and carry info - r = (if (qand) x &~ ((x+v1) & (x+m)) - else x | ((-x) &~ (x-m))) - p:= (if (qand) x&~m else x|m) == xk - k:= r>>63 | p # Carry of 0 or 1, but -1 to propagate previous - # Propagate carries and adjust result - def bl{b,a} = blend_top{b,a, b} - k = bl{make_scan_idem{f64, bl}{k}, c} # Can't be -1 now - t:= blend{shuf{V, k, 3,0,1,2}, c, 1,0,0,0} - r = apply_carry{r, -t, m-v1} - c = shuf{V, k, 4**3} + def scan_words{x:V, m:V, _} = { + s:= (if (qand) x &~ ((x+V**1) & (x+m)) + else x | ((-x) &~ (x-m))) + p:= (if (qand) x&~m == ~V**0 + else x| m == V**0) + k:= s>>63 | p # Carry of 0 or 1, but -1 to propagate previous + tup{s, k} } + avx2_loop_with_loose_mask{src, dst, nw, l, id, scan_words, apply_carry} } else { q:usz = 0 # distance to next row boundary c:u64 = id # carry @@ -443,6 +457,13 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = { fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = { def scan_word = prefix_byshift{^, <<} + def scan_words = { + def vec_prefix_byshift{op, sh} = { + def pre{v:V, k} = if (k < elwidth{V}) pre{op{v, sh{v,k}}, 2*k} else v + {v:T} => pre{v, 1} + } + vec_prefix_byshift{^, <<} + } assert{l > 0} nw := cdiv{nl, 64} if (l < 64) { @@ -455,11 +476,7 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = { } } else if (hasarch{'AVX2'}) { def scan_words{x, m} = { - def vec_prefix_byshift{op, sh} = { - def pre{v:V, k} = if (k < elwidth{V}) pre{op{v, sh{v,k}}, 2*k} else v - {v:T} => pre{v, 1} - } - s:= vec_prefix_byshift{^, <<}{x} + s:= scan_words{x} b:= s<<1 & m # last bit of previous row s ^ (b<>63 | (V**(1<<63) &~ ml) # Top bit 1 to stop, so 0 is identity + tup{s, k} + } + def propagate{k:V, c:V} = { + def bl{b,a} = blend_top{a^b,b, b} + k = bl{k, vec_shift_right_128{k, 1}} + k = bl{k, shuf{V, blend{V**0, k, 0,1,0,1}, 0,0,1,1}} + bl{k, c} + } + def fix_carry{t:V} = t & V**1 + def apply_carry{r, a, ml} = r ^ (a & ml) + avx2_loop_with_loose_mask{x, r, nw, l, 0, scan_words, propagate, fix_carry, apply_carry} } else { i :usz = 0 # row bit index iw:usz = 0 # starting word @@ -491,7 +524,6 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = { } fn scan_rows_left(x:*u64, r:*u64, nl:usz, l:usz) : void = { - def scan_word = prefix_byshift{^, <<} assert{l > 0} nw := cdiv{nl, 64} if (l < 64) { @@ -509,6 +541,13 @@ fn scan_rows_left(x:*u64, r:*u64, nl:usz, l:usz) : void = { (c & f) | (b<>63 | (m == V**0)} + } + def apply_carry{r, a, ml} = r | (a & ml) + avx2_loop_with_loose_mask{x, r, nw, l, 0, scan_words, apply_carry} } else { i :usz = 0 # row bit index wn:usz = 0 # starting word of next row