Get odd k/bool carry by shifting in the unified register, not after masking

This commit is contained in:
Marshall Lochbaum 2024-08-09 08:29:43 -04:00
parent 1621c7d07c
commit 13ec029d9f

View File

@ -412,25 +412,24 @@ def rep_const_bool_generic_odd{k, xp, rp, nw, m, d} = {
@for (min{k/4 - 1, nw/4}) or_adv{sm0,s4}
submasks := scan{advance, tup{sm0, ...3**s1}}
mask_tail := advance{sm0, s4} &~ sm0
# Mask out carry bit
mr := u64~~1<<k - 1
while (1) {
x := get_swap_x{}
os:=o; xo:=x<<k|x>>(64-k); o=xo&1; xo=(xo&~1)|os
mask = mc4
# Write result word given starting bits
def step{b} = {
r := (b<<k) - b
output{r | (o - promote{u64, o > 0})}
o = b>>(64-k)
}
def step{b, c} = output{c - b - promote{u64, c&mr != 0}}
def step{b, c, m} = step{b&m, c&m}
# Fast unrolled iterations
@for (k/4) {
xm := x & mask
each{{mm} => step{xm & mm}, submasks}
each{step{x & mask, xo & mask, .}, submasks}
mask = advance{mask, s4}
}
# Single-step for tail
mask = mask_tail
@for (k%4) {
step{x & mask}
step{x, xo, mask}
mask = advance{mask, s1}
}
}
@ -637,8 +636,10 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
swap_masks := each{{l} => selV{swap_data, mkV{l+iV%l}}, swap_lens}
# Every-k-bits mask, same as before
{m, d} := unaligned_spaced_mask_mod{wv}
mask := V~~make{W, m, m>>d|m<<(wv-d)}
mask := make{W, m, m>>d|m<<(wv-d)}
mask_sh := d+d; if (mask_sh >= wv) mask_sh-= wv
# Mask out carry bit
mr := [4]u32~~W**(u64~~1<<wv - 1)
# State
xv:V = V**0; o:=W**0
i:usz = 0
@ -650,15 +651,17 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
}
each{swap_step, swap_lens, swap_masks}
xv = perm_x{xv}
xw := W~~xv
def vrot1{x} = vshl{x, x, vcount{type{x}}-1}
w1 := W**1
os:=o; xo:=xw<<wv|vrot1{xw>>(64-wv)}; o=xo&w1; xo=(xo&~w1)|os
# Write wv vectors based on that
@for (wv) {
b := W~~(xv & mask)
mask = V~~((W~~mask << (wv - mask_sh)) | (W~~mask >> mask_sh))
rv:= V~~((b<<wv) - b)
b := xw & mask
# Handle overhang here; won't fit in a single vector
po:= o; o = b>>(64-wv)
ro:= [4]u32~~vshl{po, o, 1}
output{rv | V~~(ro + (ro > [4]u32**0))}
c := xo & mask; cu := [4]u32~~c
output{V~~(W~~(cu + (mr&cu > [4]u32**0)) - b)}
mask = (mask << (wv - mask_sh)) | (mask >> mask_sh)
}
}
}