This commit is contained in:
dzaima 2023-04-08 20:04:23 +03:00
parent 6a0385b44b
commit 52dc05f228
2 changed files with 144 additions and 29 deletions

View File

@ -152,6 +152,32 @@ static B scan_lt(B x, u64 p, usz ia) {
decG(x); return r;
}
static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
assert(xe!=el_bit && elNum(xe));
B r; void* rp = m_tyarrv(&r, xe==el_f64? sizeof(f64) : sizeof(i32), ia, xe==el_f64? t_f64arr : t_i32arr);
#if SINGELI_AVX2
switch(xe) { default:UD;
case el_i8: { if (!q_fi32(r0) || simd_scan_plus_i8_i32 (i8any_ptr(x), r0, rp, ia)!=ia) goto cs_i8_f64; decG(x); return r; }
case el_i16: { if (!q_fi32(r0) || simd_scan_plus_i16_i32(i16any_ptr(x), r0, rp, ia)!=ia) goto cs_i16_f64; decG(x); return r; }
case el_i32: { if (!q_fi32(r0) || simd_scan_plus_i32_i32(i32any_ptr(x), r0, rp, ia)!=ia) goto cs_i32_f64; decG(x); return r; }
case el_f64: { f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
}
cs_i8_f64: { x=taga(cpyI16Arr(x)); goto cs_i16_f64; }
cs_i16_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); simd_scan_plus_i16_f64(i16any_ptr(x), r0, rp, ia); decG(x); return r; }
cs_i32_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); simd_scan_plus_i32_f64(i32any_ptr(x), r0, rp, ia); decG(x); return r; }
#else
if (xe==el_i8 && q_fi32(r0)) { i8* xp=i8any_ptr (x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
if (xe==el_i16 && q_fi32(r0)) { i16* xp=i16any_ptr(x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
if (xe==el_i32 && q_fi32(r0)) { i32* xp=i32any_ptr(x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
if (xe==el_f64) { res_float:; f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
base:;
decG(r);
f64* rp2; r = m_f64arrv(&rp2, ia); rp = rp2;
x = toF64Any(x);
goto res_float;
#endif
}
B scan_c1(Md1D* d, B x) { B f = d->f;
if (isAtm(x) || RNK(x)==0) thrM("`: Argument cannot have rank 0");
ur xr = RNK(x);
@ -169,13 +195,7 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
if (rtid==n_lt) return scan_lt(x, 0, ia); // <
goto base;
}
if (rtid==n_add) { // +
B r; void* rp = m_tyarrv(&r, xe==el_f64? sizeof(f64) : sizeof(i32), ia, xe==el_f64? t_f64arr : t_i32arr);
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); i32 c=0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } ((i32*)rp)[i]=c; } decG(x); return r; }
if (xe==el_i16) { i16* xp=i16any_ptr(x); i32 c=0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } ((i32*)rp)[i]=c; } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); i32 c=0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } ((i32*)rp)[i]=c; } decG(x); return r; }
if (xe==el_f64) { f64* xp=f64any_ptr(x); f64 c=0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
}
if (rtid==n_add) return scan_plus(0, x, xe, ia); // +
if (rtid==n_floor) return scan_min_num(x, xe, ia); // ⌊
if (rtid==n_ceil ) return scan_max_num(x, xe, ia); // ⌈
if (rtid==n_ne) { // ≠
@ -226,8 +246,6 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f;
if (rtid==n_ceil ) return scan2_max_num(w, x, xe, ia); // ⌈
if (rtid==n_add) { // +
if (xe==el_f64) { f64 c=o2fG(w); f64* rp; B r=m_f64arrv(&rp, ia); f64* xp=f64any_ptr(x); for (usz i=0; i<ia; i++) { c+= xp[i]; rp[i]=c; } decG(x); return r; }
if (xe==el_bit) {
if (!q_i64(w)) goto base;
i64 wv = o2i64G(w);
@ -236,13 +254,7 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f;
return wv==0? t : C2(add, w, t);
}
if (!q_i32(w) || !elInt(xe)) goto base;
i32 c = o2iG(w);
i32* rp; B r = m_i32arrv(&rp, ia);
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } rp[i]=c; } decG(x); return r; }
if (xe==el_i16) { i16* xp=i16any_ptr(x); for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } rp[i]=c; } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } rp[i]=c; } decG(x); return r; }
UD;
if (isF64(w) && elInt(xe)) return scan_plus(o2fG(w), x, xe, ia);
}
if (rtid==n_ne) { // ≠

View File

@ -3,6 +3,7 @@ include './sse'
include './avx'
include './avx2'
include './mask'
include './f64'
def sel8{v, t} = sel{[16]u8, v, make{[32]i8, t}}
def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
@ -16,6 +17,12 @@ def spread{a:VT} = {
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
}
# Set all elements with the last element of the input
def toLast{n:VT} = {
if (elwidth{VT}<=32) sel{[8]i32, spread{n}, [8]i32**7}
else shuf{[4]u64, n, 4b3333}
}
def scan_loop{T, init, x:*T, r:*T, len:u64, scan, scan_last} = {
def step = 256/width{T}
def V = [step]T
@ -31,8 +38,7 @@ def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
def last{v, p} = op{pre{v}, p}
def scan{v, p} = {
n:= last{v, p}
p = (if (width{T}<=32) sel{[8]i32, spread{n}, [8]i32**7}
else shuf{[4]u64, n, 4b3333})
p = toLast{n}
n
}
scan_loop{T, init, x, r, len, scan, last}
@ -73,20 +79,23 @@ export{'avx2_scan_min_i8', avx2_scan_idem_id{i8 , min}}; export{'avx2_scan_max_
export{'avx2_scan_min_i16', avx2_scan_idem_id{i16, min}}; export{'avx2_scan_max_i16', avx2_scan_idem_id{i16, max}}
export{'avx2_scan_min_i32', avx2_scan_idem_id{i32, min}}; export{'avx2_scan_max_i32', avx2_scan_idem_id{i32, max}}
# Assumes identity is 0
def scan_assoc{op, a:T} = {
# Within each lane, scan using shifts by powers of 2
def w = elwidth{T}
def c32{k, a} = (if (w<=8*k) op{a, shl{[16]u8, a, k}}; else a)
b:= c32{8, c32{4, c32{2, c32{1, a}}}}
# After lanewise scan, broadcast end of lane 0 to entire lane 1
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
}
def scan_plus = bind{scan_assoc, +}
# Associative scan
fn avx2_scan_assoc_0{T, op}(x:*T, r:*T, len:u64, init:T) : void = {
# Prefix op on entire AVX register
def pre{a} = {
# Within each lane, scan using shifts by powers of 2.
# Assumes identity is 0.
def w = width{T}
def c32{k, a} = (if (w<=8*k) op{a, shl{[16]u8, a, k}}; else a)
b:= c32{8, c32{4, c32{2, c32{1, a}}}}
# After lanewise scan, broadcast end of lane 0 to entire lane 1
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
}
scan_post{T, init, x, r, len, op, pre}
scan_post{T, init, x, r, len, op, scan_plus}
}
export{'avx2_scan_pluswrap_u8', avx2_scan_assoc_0{u8 , +}}
export{'avx2_scan_pluswrap_u16', avx2_scan_assoc_0{u16, +}}
@ -151,3 +160,97 @@ fn avx2_bcs{T}(x:*u64, r:*T, l:u64) : void = {
export{'avx2_bcs8', avx2_bcs{i8}}
export{'avx2_bcs16', avx2_bcs{i16}}
export{'avx2_bcs32', avx2_bcs{i32}}
def addChk{a:T, b:T} = {
mem:*T = tup{a}
def bad = emit{u1, '__builtin_add_overflow', a, b, mem}
tup{bad, load{mem}}
}
def addChk{a:T, b:T & T==f64} = tup{0, a+b}
def widenFull{E, xs} = {
merge{...each{{x:X} => {
def n = vcount{X}
def tb = width{E} * n
if (tb<=arch_defvw) tup{widen{[n]E, x}}
else if (1) {
assert{tb == 2*arch_defvw}
tup{
widen{[n/2]E, half{x,0}},
widen{[n/2]E, half{x,1}}
}
}
}, xs}}
}
def floor{x & knum{x}} = x - (x%1)
def maxabsval{T & issigned{T}} = -minvalue{T}
def maxsafeint{T & issigned{T}} = maxvalue{T}
def maxsafeint{T==f64} = 1<<53
def simd_plus_scan{X, b, R}{x:*X, c:(R), r:*R, len:u64} = {
def bulk = arch_defvw/b
def wd = (X!=R) & (width{X}<32) # whether to widen the working copy one size
def WE = tern{wd, ty_dbl{X}, X} # working copy element type
# maxFastA: max absolute value for accumulator
# maxFastE: max absolute value for vector elements (not used if ~wd)
def maxFastE = if (wd) maxabsval{X} else maxabsval{X}/(bulk*tern{R==f64, 1, 4}) # 4 to give maxFastA some range
def maxFastA = maxsafeint{R} - maxFastE*bulk
if (R!=f64) { def m = maxFastA + maxFastE*bulk; assert{m<=maxvalue{R}}; assert{-m>=minvalue{R}} }
i:u64 = 0
cv:= [arch_defvw/width{R}]R ** c
if (R==f64 and c != floor{c}) goto{'end'}
while (1) {
def ctmp = extract{cv,0}
if (max{ctmp,-ctmp} >= tern{R==f64, cast_i{f64, i64~~maxFastA}, maxFastA}) goto{'end'}
i2:= i+bulk
if (i2>len) goto{'end'}
def cx0 = tup{load{*[bulk]X~~(x+i)}}
def cx = if(wd) widenFull{WE,cx0} else cx0
if (~wd) { # within-vector overflow check; widening gives range space for this to not happen
if (rare{homAny{tree_fold{|, each{{c:T} => absu{c} >= (ty_u{T}**maxFastE), cx}}}}) goto{'end'}
}
def s0 = each{scan_plus, cx}
def s1 = {
if (tuplen{s0}==1) s0
else { def {v0,v1}=s0; tup{v0,v1+toLast{v0}} }
}
def cr = eachx{+, widenFull{R, s1}, cv}
cv = toLast{tupsel{-1, cr}}
each{{c:T} => assert{T==type{cv}}, cr}
assert{vcount{type{cv}} * tuplen{cr} == bulk}
each{{c:T, j} => store{*T~~(r+i), j, c}, cr, iota{tuplen{cr}}}
i = i2
}
setlabel{'end'}
c = extract{cv, vcount{type{cv}}-1}
while (i < len) {
def {b,n} = addChk{c, promote{R,load{x,i}}}
if (rare{b}) return{i}
store{r, i, n}
c = n
++i
}
len
}
fn simd_plus_scanG{X, b, R}(x:*X, c:R, r:*R, len:u64) : void = simd_plus_scan{X,b,R}{x, c, r, len}
fn simd_plus_scanC{X, b, R}(x:*X, c:R, r:*R, len:u64) : u64 = simd_plus_scan{X,b,R}{x, c, r, len}
export{'simd_scan_plus_i8_i32', simd_plus_scanC{i8, 16, i32}}
export{'simd_scan_plus_i16_i32', simd_plus_scanC{i16, 16, i32}}
export{'simd_scan_plus_i32_i32', simd_plus_scanC{i32, 32, i32}}
export{'simd_scan_plus_i16_f64', simd_plus_scanG{i16, 32, f64}}
export{'simd_scan_plus_i32_f64', simd_plus_scanG{i32, 32, f64}}