Get odd k/bool carry by shifting in the unified register, not after masking

This commit is contained in:
Marshall Lochbaum 2024-08-09 08:29:43 -04:00
parent 1621c7d07c
commit 13ec029d9f

View File

@ -412,25 +412,24 @@ def rep_const_bool_generic_odd{k, xp, rp, nw, m, d} = {
@for (min{k/4 - 1, nw/4}) or_adv{sm0,s4} @for (min{k/4 - 1, nw/4}) or_adv{sm0,s4}
submasks := scan{advance, tup{sm0, ...3**s1}} submasks := scan{advance, tup{sm0, ...3**s1}}
mask_tail := advance{sm0, s4} &~ sm0 mask_tail := advance{sm0, s4} &~ sm0
# Mask out carry bit
mr := u64~~1<<k - 1
while (1) { while (1) {
x := get_swap_x{} x := get_swap_x{}
os:=o; xo:=x<<k|x>>(64-k); o=xo&1; xo=(xo&~1)|os
mask = mc4 mask = mc4
# Write result word given starting bits # Write result word given starting bits
def step{b} = { def step{b, c} = output{c - b - promote{u64, c&mr != 0}}
r := (b<<k) - b def step{b, c, m} = step{b&m, c&m}
output{r | (o - promote{u64, o > 0})}
o = b>>(64-k)
}
# Fast unrolled iterations # Fast unrolled iterations
@for (k/4) { @for (k/4) {
xm := x & mask each{step{x & mask, xo & mask, .}, submasks}
each{{mm} => step{xm & mm}, submasks}
mask = advance{mask, s4} mask = advance{mask, s4}
} }
# Single-step for tail # Single-step for tail
mask = mask_tail mask = mask_tail
@for (k%4) { @for (k%4) {
step{x & mask} step{x, xo, mask}
mask = advance{mask, s1} mask = advance{mask, s1}
} }
} }
@ -637,8 +636,10 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
swap_masks := each{{l} => selV{swap_data, mkV{l+iV%l}}, swap_lens} swap_masks := each{{l} => selV{swap_data, mkV{l+iV%l}}, swap_lens}
# Every-k-bits mask, same as before # Every-k-bits mask, same as before
{m, d} := unaligned_spaced_mask_mod{wv} {m, d} := unaligned_spaced_mask_mod{wv}
mask := V~~make{W, m, m>>d|m<<(wv-d)} mask := make{W, m, m>>d|m<<(wv-d)}
mask_sh := d+d; if (mask_sh >= wv) mask_sh-= wv mask_sh := d+d; if (mask_sh >= wv) mask_sh-= wv
# Mask out carry bit
mr := [4]u32~~W**(u64~~1<<wv - 1)
# State # State
xv:V = V**0; o:=W**0 xv:V = V**0; o:=W**0
i:usz = 0 i:usz = 0
@ -650,15 +651,17 @@ def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
} }
each{swap_step, swap_lens, swap_masks} each{swap_step, swap_lens, swap_masks}
xv = perm_x{xv} xv = perm_x{xv}
xw := W~~xv
def vrot1{x} = vshl{x, x, vcount{type{x}}-1}
w1 := W**1
os:=o; xo:=xw<<wv|vrot1{xw>>(64-wv)}; o=xo&w1; xo=(xo&~w1)|os
# Write wv vectors based on that # Write wv vectors based on that
@for (wv) { @for (wv) {
b := W~~(xv & mask) b := xw & mask
mask = V~~((W~~mask << (wv - mask_sh)) | (W~~mask >> mask_sh))
rv:= V~~((b<<wv) - b)
# Handle overhang here; won't fit in a single vector # Handle overhang here; won't fit in a single vector
po:= o; o = b>>(64-wv) c := xo & mask; cu := [4]u32~~c
ro:= [4]u32~~vshl{po, o, 1} output{V~~(W~~(cu + (mr&cu > [4]u32**0)) - b)}
output{rv | V~~(ro + (ro > [4]u32**0))} mask = (mask << (wv - mask_sh)) | (mask >> mask_sh)
} }
} }
} }