tiny improvements to plus-scan

This commit is contained in:
dzaima 2023-04-09 12:09:52 +03:00
parent 9951f20751
commit b66f628cff
3 changed files with 41 additions and 10 deletions

View File

@ -227,6 +227,30 @@ def forNZ{vars,begin,end,block} = {
++i
}
}
def forUnroll{exp,unr}{vars,begin,end,block} = {
i:u64 = begin
while ((i+unr) <= end) {
exec{each{{j}=>i+j, iota{unr}}, vars, block}
i+= unr
}
if (unr==2) { if (i!=end) exec{tup{i}, vars, block} }
else if (unr>1) {
if (exp) {
def stop = makelabel{}
each{{j} => {
if (i+j >= end) goto{stop}
exec{tup{i+j}, vars, block}
}, iota{unr}}
setlabel{stop}
} else {
@for(j from i to end) exec{tup{j}, vars, block}
}
}
}
def forXUnroll{unr}{vars,begin,end,block} = {
@forUnroll{unr}(is from begin to end) each{{i} => exec{i, vars, block}, is}
}
def tree_fold{F, x} = {
def h = tuplen{x}>>1

View File

@ -1,5 +1,6 @@
def ceil{x:f64} = emit{f64, 'ceil', x}
def floor{x:f64} = emit{f64, 'floor', x}
def abs{x:f64} = emit{f64, 'fabs', x}
def NaN = 0.0/0.0
def isNaN{x:f64} = x!=x

View File

@ -205,12 +205,15 @@ def simd_plus_scan{X, b, R}{x:*X, c:(R), r:*R, len:u64} = {
i:u64 = 0
cv:= [arch_defvw/width{R}]R ** c
if (R==f64 and c != floor{c}) goto{'end'}
while (1) {
def ctmp = extract{cv,0}
if (max{ctmp,-ctmp} >= tern{R==f64, cast_i{f64, i64~~maxFastA}, maxFastA}) goto{'end'}
if (R==f64) { if (rare{abs{extract{cv,0}} >= cast_i{f64, i64~~maxFastA}}) goto{'end'} }
else { if (rare{extract{absu{half{cv,0}},0} > maxFastA}) goto{'end'} }
i2:= i+bulk
if (i2>len) goto{'end'}
def cx0 = tup{load{*[bulk]X~~(x+i)}}
def cx = if(wd) widenFull{WE,cx0} else cx0
if (~wd) { # within-vector overflow check; widening gives range space for this to not happen
@ -233,15 +236,18 @@ def simd_plus_scan{X, b, R}{x:*X, c:(R), r:*R, len:u64} = {
each{{c:T, j} => store{*T~~(r+i), j, c}, cr, iota{tuplen{cr}}}
i = i2
}
setlabel{'end'}
c = extract{cv, vcount{type{cv}}-1}
while (i < len) {
def {b,n} = addChk{c, promote{R,load{x,i}}}
if (rare{b}) return{i}
store{r, i, n}
c = n
++i
setlabel{'end'}
c = extract{cv, 0}
@forUnroll{1,1} (js from i to len) {
def vs = eachx{load, x, js}
each{{j, v} => {
def {b,n} = addChk{c, promote{R, v}}
if (rare{b}) return{j}
store{r, j, n}
c = n
}, js, vs}
}
len
}