Extend loose_mask system to SSE and NEON (not always fast)

This commit is contained in:
Marshall Lochbaum 2025-02-25 14:48:45 -05:00
parent fb413b1966
commit b5ec2f5fa8

View File

@ -316,6 +316,54 @@ export{'si_scan_plus_i32_f64', plus_scanG{i32, f64}}
# Row-wise boolean scan
def rotate_right{x:[vl]_} = shuf{x, (iota{vl}-1)%vl}
def blend_first{x:V=[vl]_, y:V} = blend{x, y, 0 < iota{vl}}
# Create masks of the given type with spacing l>=64
def loose_mask_gen{(u64), l} = {
q:usz = 0 # distance to next row boundary
{} => {
b:= q<64 # whether there's a boundary
p:= q%64 # its position
q-= 64 - (l &- b)
promote{u64, b} << p
}
}
def loose_mask_gen{V=[vl]T, l} = { # Slow, for ≠` only
def get = loose_mask_gen{T, l}
{} => make{V, @collect (vl) get{}}
}
def loose_mask_gen{V=[vl](u64), l if hasarch{'SSSE3'}} = {
# Shuffles can substitute for variable shifts, in a pinch
assert{l < 256}
def I = re_el{i8, V}; def [il]_ = I
def U = [il]u8
q := make{I, (-128) ^ (-8*iota{il})} # distance to next row boundary, -128
l8:= cast_i{u8, l}; vl:= I**i8~~l8
def q_mod{} = { q+= vl &~ I~~(q < I**0) }
q_mod{}
o:u8 = width{V}; while (o>l8) { o-=l8; q_mod{} }
oo:= I**i8~~(o - 128)
s := make{U, 1<<(iota{il}%8)}
{} => {
m:= shuf{s, q & I**7} & (q < I**(8-128))
q+= (vl & I~~(q < oo)) - I~~U**o
V~~m
}
}
def has_vecshift = hasarch{'AVX2'} or hasarch{'AARCH64'}
def loose_mask_gen{V=[vl](u64), l if has_vecshift} = {
q := -make{V, 64*iota{vl}} # distance to next row boundary
def q_mod{} = { q+= V**l & -(q>>63) }
def q_mod{if hasarch{'SSE4.1'}} = { q = blend_top{q,q+V**l, q} }
o:u64 = width{V}; while (o>l) { o-=l; q_mod{} }
{} => {
m:= V**1 << q; if (not hasarch{'AVX2'}) m&= q < V**64
q-= V**o; q_mod{}
m
}
}
def loop_with_unaligned_mask{x, r, nw, l, step} = {
{m, d} := unaligned_spaced_mask_mod{l}
c:u64 = 0 # carry (initial value never matters)
@ -338,41 +386,39 @@ def vec_loop_with_unaligned_mask{xp, rp, nw, l, scan_words, apply_carry} = {
s := scan_words{x, m}
# Each result word can be modified based on top bit of previous
t := -(s>>63)
pc:= c
c = shuf{t, (iota{vl}-1)%vl} # Rotate right one
b := (if (vl==2) zip{pc, t, 0} else blend{c, pc, 0==iota{vl}})
pc:= c; c = rotate_right{t}
b := (if (vl==2) zip{pc, t, 0} else blend_first{pc, c})
# Carry applies to bits below any mask bit
r = apply_carry{s, b, (m-V**1)&~m}
m = m>>d4 | m<<(l-d4)
}
}
def avx2_loop_with_loose_mask{xp, rp, nw, l, id, scan_words, propagate, fix_carry, apply_carry} = {
def vec_loop_with_loose_mask{xp, rp, nw, l, id, scan_words, propagate, fix_carry, apply_carry} = {
assert{l >= 64}
def V = [4]u64
def vl = arch_defvw / 64
def V = [vl]u64
def get_m = loose_mask_gen{V, l}
c := V**id # carry, 0 or 1
q := -make{V, 64*iota{4}} # distance to next row boundary
def q_mod{} = { q = blend_top{q,q+V**l, q} }
o:u64 = 256; while (o>l) { o-=l; q_mod{} }
@for_masked{4} (x in tup{V, xp}, r in tup{V, rp} over nw) {
@for_masked{vl} (x in tup{V, xp}, r in tup{V, rp} over nw) {
# Get mask; <=1 bit per word
m:= V**1 << q
q-= V**o; q_mod{}
m:= get_m{}
# Within-word scan and carry info
ml:= m - V**1
{s, k}:= scan_words{x, m, ml}
# Propagate carries and adjust result
p:= propagate{k, c}
t:= blend{shuf{V, p, 3,0,1,2}, c, 1,0,0,0}
t:= (if (vl==2) zip{c, p, 0} else blend_first{c, rotate_right{p}})
r = apply_carry{s, -fix_carry{t}, ml}
c = shuf{V, p, 4**3}
c = shuf{V, p, vl**(vl-1)}
}
}
def avx2_loop_with_loose_mask{...a={xp, rp, nw, l, id, scan_words}, apply_carry} = {
def vec_loop_with_loose_mask{...a={xp, rp, nw, l, id, scan_words}, apply_carry} = {
def passthrough{k, c} = {
def bl{b,a} = blend_top{b,a, b}
def bl{b,a} = b ^ ((b^a) & -(b>>63))
def bl{b,a if hasarch{'SSE4.1'}} = blend_top{b,a, b}
bl{make_scan_idem{f64, bl}{k}, c} # Can't be -1 now
}
avx2_loop_with_loose_mask{...a, passthrough, {k}=>k, apply_carry}
vec_loop_with_loose_mask{...a, passthrough, {k}=>k, apply_carry}
}
fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = {
@ -413,7 +459,7 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = {
}}
}
} else if (l < (if (hasarch{'AVX2'}) 256 else 160)) {
if (hasarch{'AVX2'}) {
if (hasarch{'SSE4.1'} or hasarch{'AARCH64'}) {
def scan_words{x:V, m:V, _} = {
s:= (if (qand) x &~ ((x+V**1) & (x+m))
else x | ((-x) &~ (x-m)))
@ -422,15 +468,12 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = {
k:= s>>63 | p # Carry of 0 or 1, but -1 to propagate previous
tup{s, k}
}
avx2_loop_with_loose_mask{src, dst, nw, l, id, scan_words, apply_carry}
vec_loop_with_loose_mask{src, dst, nw, l, id, scan_words, apply_carry}
} else {
q:usz = 0 # distance to next row boundary
def get_m = loose_mask_gen{u64, l}
c:u64 = id # carry
@for (r in dst, x in src over nw) {
b:= q<64 # whether there's a boundary
p:= q%64 # its position
q-= 64 - (l &- b)
r = res_m1{x, c, promote{u64, b} << p}
r = res_m1{x, c, get_m{}}
c = r >> 63
}
}
@ -495,21 +538,22 @@ fn scan_rows_neq(x:*u64, r:*u64, nl:usz, l:usz) : void = {
s ^ ((c & f) | (b<<l - b))
}}
}
} else if (l < 320 and hasarch{'AVX2'}) {
} else if (has_simd and l < (if (hasarch{'SSSE3'} and not hasarch{'AVX2'}) 256 else 320)) {
def scan_words{x:V, m:V, ml:V} = {
s:= scan_words{x}
s^= -(s<<1 & m)
k:= s>>63 | (V**(1<<63) &~ ml) # Top bit 1 to stop, so 0 is identity
tup{s, k}
}
def propagate{k:V, c:V} = {
def bl{b,a} = blend_top{a^b,b, b}
k = bl{k, vec_shift_right_128{k, 1}}
k = bl{k, shuf{V, blend{V**0, k, 0,1,0,1}, 0,0,1,1}}
bl{k, c}
def propagate{k:V=[vl]_, c:V} = {
def bl{b,a} = b ^ (a &~ -(b>>63))
def bl{b,a if hasarch{'SSE4.1'}} = blend_top{a^b,b, b}
k = bl{k, vec_shift_right_128{k, 1}}
if (vl>2) k = bl{k, shuf{V, blend{V**0, k, 0,1,0,1}, 0,0,1,1}}
bl{k, c}
}
def fix_carry{t:V} = t & V**1
avx2_loop_with_loose_mask{x, r, nw, l, 0, scan_words, propagate, fix_carry, apply_carry}
vec_loop_with_loose_mask{x, r, nw, l, 0, scan_words, propagate, fix_carry, apply_carry}
} else {
i :usz = 0 # row bit index
iw:usz = 0 # starting word
@ -544,12 +588,12 @@ fn scan_rows_left(x:*u64, r:*u64, nl:usz, l:usz) : void = {
(c & f) | apply_mask{x, m}
}}
}
} else if (l < 176 and hasarch{'AVX2'}) {
} else if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and l < 176) {
def scan_words{x:V, m:V, _} = {
s:= -(x & m)
tup{s, s>>63 | (m == V**0)}
}
avx2_loop_with_loose_mask{x, r, nw, l, 0, scan_words, apply_carry}
vec_loop_with_loose_mask{x, r, nw, l, 0, scan_words, apply_carry}
} else {
i :usz = 0 # row bit index
wn:usz = 0 # starting word of next row