Mask out values past the end for strided +` overflow checks

This commit is contained in:
Marshall Lochbaum 2025-03-04 07:38:28 -05:00
parent 9fdeb5379a
commit a11ed3d79a

View File

@ -130,10 +130,10 @@ fn scan_stride_assoc{op, T, Ret, check_over}(xv:*void, rv:*void, ia:usz, l:usz)
else { m:= V~~(iv >= js); {x} => selI{x, v} & m } else { m:= V~~(iv >= js); {x} => selI{x, v} & m }
} }
c:= V**id{T} c:= V**id{T}
@for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) { @for_masked{vl} (x in tup{V, x}, r in tup{V, r}, M in 'm' over ia) {
xs:= fold{{v, i} => op{i{v}, v}, x, inds} xs:= fold{{v, i} => op{i{v}, v}, x, inds}
r = c = op{shuf{IE, c, spr}, xs} r = c = op{shuf{IE, c, spr}, xs}
check_over{x, r} # For +, infers other argument as r-x check_over{M, x, r} # For +, infers other argument as r-x
} }
} }
if (not (same{op,+} and V==[4]f64)) { if (not (same{op,+} and V==[4]f64)) {
@ -157,7 +157,7 @@ fn scan_stride_assoc{op, T, Ret, check_over}(xv:*void, rv:*void, ia:usz, l:usz)
} }
} else { } else {
# Large stride: single shift, with saved register or memory # Large stride: single shift, with saved register or memory
def op_chk{p, x} = { r:= op{p, x}; check_over{p, x, r}; r } def op_chk{M, p, x} = { r:= op{p, x}; check_over{M, p, x, r}; r }
@for (r, x over l) r = x @for (r, x over l) r = x
if (has_shuf and l<256/(wT/8)) { if (has_shuf and l<256/(wT/8)) {
def [il]IE = I def [il]IE = I
@ -169,37 +169,37 @@ fn scan_stride_assoc{op, T, Ret, check_over}(xv:*void, rv:*void, ia:usz, l:usz)
if (l == 2*vl) { o = vl; bv = ~bv } if (l == 2*vl) { o = vl; bv = ~bv }
if (o == vl) { if (o == vl) {
p:= load{*V~~x}; store{*V~~r, 0, p} p:= load{*V~~x}; store{*V~~r, 0, p}
@for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o} over ia-o) { @for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o}, M in 'm' over ia-o) {
p = rot{p} p = rot{p}
r = op_chk{bl{c, p}, x} r = op_chk{M, bl{c, p}, x}
c = p; p = r c = p; p = r
} }
} else { } else {
@for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o}, p in tup{V, r} over ia-o) { @for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o}, p in tup{V, r}, M in 'm' over ia-o) {
q:= rot{p} q:= rot{p}
r = op_chk{bl{c, q}, x} r = op_chk{M, bl{c, q}, x}
c = q c = q
} }
} }
} else if (same{op, +} and T<=i32 and has_simd and (has_shuf or l>=vl)) { } else if (same{op, +} and T<=i32 and has_simd and (has_shuf or l>=vl)) {
def vl = arch_defvw/wT; def V = [vl]T def vl = arch_defvw/wT; def V = [vl]T
@for_masked{vl} (x in tup{V, x+l}, r in tup{V, r+l}, p in tup{V, r} over ia-l) { @for_masked{vl} (x in tup{V, x+l}, r in tup{V, r+l}, p in tup{V, r}, M in 'm' over ia-l) {
r = op_chk{p, x} r = op_chk{M, p, x}
} }
} else { } else {
@for (r, x, p in r-l over _ from l to ia) r = op_chk{p, x} @for (r, x, p in r-l over _ from l to ia) r = op_chk{0, p, x}
} }
} }
1 1
} }
def scan_stride_assoc{op, T} = scan_stride_assoc{op, T, void, {..._}=>{}} def scan_stride_assoc{op, T} = scan_stride_assoc{op, T, void, {..._}=>{}}
def check_add_over{w:T, x:T, r:T} = { if ((w^r) & (x^r) < 0) return{0} } def check_add_over{_, w:T, x:T, r:T} = { if ((w^r) & (x^r) < 0) return{0} }
def check_add_over{w:V=[_]E, x:V, r:V} = { def check_add_over{M, w:V=[_]E, x:V, r:V} = {
o:= (if (not hasarch{'X86_64'} or width{E}<=16) any_hom{subs{r,w} != x} o:= (if (not hasarch{'X86_64'} or width{E}<=16) any_hom{M, subs{r,w} != x}
else any_top{(w^r) & (x^r)}) else any_top{M, (w^r) & (x^r)})
if (o) return{0} if (o) return{0}
} }
def check_add_over{x, r} = check_add_over{r-x, x, r} def check_add_over{M, x, r} = check_add_over{M, r-x, x, r}
export_tab{'si_scan_stride_minmax', export_tab{'si_scan_stride_minmax',
flat_table{scan_stride_assoc, tup{min,max}, tup{i8,i16,i32,f64}} flat_table{scan_stride_assoc, tup{min,max}, tup{i8,i16,i32,f64}}
} }