SIMD +`
This commit is contained in:
parent
6a0385b44b
commit
52dc05f228
@ -152,6 +152,32 @@ static B scan_lt(B x, u64 p, usz ia) {
|
||||
decG(x); return r;
|
||||
}
|
||||
|
||||
static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
|
||||
assert(xe!=el_bit && elNum(xe));
|
||||
B r; void* rp = m_tyarrv(&r, xe==el_f64? sizeof(f64) : sizeof(i32), ia, xe==el_f64? t_f64arr : t_i32arr);
|
||||
#if SINGELI_AVX2
|
||||
switch(xe) { default:UD;
|
||||
case el_i8: { if (!q_fi32(r0) || simd_scan_plus_i8_i32 (i8any_ptr(x), r0, rp, ia)!=ia) goto cs_i8_f64; decG(x); return r; }
|
||||
case el_i16: { if (!q_fi32(r0) || simd_scan_plus_i16_i32(i16any_ptr(x), r0, rp, ia)!=ia) goto cs_i16_f64; decG(x); return r; }
|
||||
case el_i32: { if (!q_fi32(r0) || simd_scan_plus_i32_i32(i32any_ptr(x), r0, rp, ia)!=ia) goto cs_i32_f64; decG(x); return r; }
|
||||
case el_f64: { f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
|
||||
}
|
||||
cs_i8_f64: { x=taga(cpyI16Arr(x)); goto cs_i16_f64; }
|
||||
cs_i16_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); simd_scan_plus_i16_f64(i16any_ptr(x), r0, rp, ia); decG(x); return r; }
|
||||
cs_i32_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); simd_scan_plus_i32_f64(i32any_ptr(x), r0, rp, ia); decG(x); return r; }
|
||||
#else
|
||||
if (xe==el_i8 && q_fi32(r0)) { i8* xp=i8any_ptr (x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
|
||||
if (xe==el_i16 && q_fi32(r0)) { i16* xp=i16any_ptr(x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
|
||||
if (xe==el_i32 && q_fi32(r0)) { i32* xp=i32any_ptr(x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
|
||||
if (xe==el_f64) { res_float:; f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
|
||||
base:;
|
||||
decG(r);
|
||||
f64* rp2; r = m_f64arrv(&rp2, ia); rp = rp2;
|
||||
x = toF64Any(x);
|
||||
goto res_float;
|
||||
#endif
|
||||
}
|
||||
|
||||
B scan_c1(Md1D* d, B x) { B f = d->f;
|
||||
if (isAtm(x) || RNK(x)==0) thrM("`: Argument cannot have rank 0");
|
||||
ur xr = RNK(x);
|
||||
@ -169,13 +195,7 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
|
||||
if (rtid==n_lt) return scan_lt(x, 0, ia); // <
|
||||
goto base;
|
||||
}
|
||||
if (rtid==n_add) { // +
|
||||
B r; void* rp = m_tyarrv(&r, xe==el_f64? sizeof(f64) : sizeof(i32), ia, xe==el_f64? t_f64arr : t_i32arr);
|
||||
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); i32 c=0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } ((i32*)rp)[i]=c; } decG(x); return r; }
|
||||
if (xe==el_i16) { i16* xp=i16any_ptr(x); i32 c=0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } ((i32*)rp)[i]=c; } decG(x); return r; }
|
||||
if (xe==el_i32) { i32* xp=i32any_ptr(x); i32 c=0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } ((i32*)rp)[i]=c; } decG(x); return r; }
|
||||
if (xe==el_f64) { f64* xp=f64any_ptr(x); f64 c=0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
|
||||
}
|
||||
if (rtid==n_add) return scan_plus(0, x, xe, ia); // +
|
||||
if (rtid==n_floor) return scan_min_num(x, xe, ia); // ⌊
|
||||
if (rtid==n_ceil ) return scan_max_num(x, xe, ia); // ⌈
|
||||
if (rtid==n_ne) { // ≠
|
||||
@ -226,8 +246,6 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f;
|
||||
if (rtid==n_ceil ) return scan2_max_num(w, x, xe, ia); // ⌈
|
||||
|
||||
if (rtid==n_add) { // +
|
||||
if (xe==el_f64) { f64 c=o2fG(w); f64* rp; B r=m_f64arrv(&rp, ia); f64* xp=f64any_ptr(x); for (usz i=0; i<ia; i++) { c+= xp[i]; rp[i]=c; } decG(x); return r; }
|
||||
|
||||
if (xe==el_bit) {
|
||||
if (!q_i64(w)) goto base;
|
||||
i64 wv = o2i64G(w);
|
||||
@ -236,13 +254,7 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f;
|
||||
return wv==0? t : C2(add, w, t);
|
||||
}
|
||||
|
||||
if (!q_i32(w) || !elInt(xe)) goto base;
|
||||
i32 c = o2iG(w);
|
||||
i32* rp; B r = m_i32arrv(&rp, ia);
|
||||
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } rp[i]=c; } decG(x); return r; }
|
||||
if (xe==el_i16) { i16* xp=i16any_ptr(x); for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } rp[i]=c; } decG(x); return r; }
|
||||
if (xe==el_i32) { i32* xp=i32any_ptr(x); for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } rp[i]=c; } decG(x); return r; }
|
||||
UD;
|
||||
if (isF64(w) && elInt(xe)) return scan_plus(o2fG(w), x, xe, ia);
|
||||
}
|
||||
|
||||
if (rtid==n_ne) { // ≠
|
||||
|
||||
@ -3,6 +3,7 @@ include './sse'
|
||||
include './avx'
|
||||
include './avx2'
|
||||
include './mask'
|
||||
include './f64'
|
||||
|
||||
def sel8{v, t} = sel{[16]u8, v, make{[32]i8, t}}
|
||||
def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
|
||||
@ -16,6 +17,12 @@ def spread{a:VT} = {
|
||||
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
|
||||
}
|
||||
|
||||
# Set all elements with the last element of the input
|
||||
def toLast{n:VT} = {
|
||||
if (elwidth{VT}<=32) sel{[8]i32, spread{n}, [8]i32**7}
|
||||
else shuf{[4]u64, n, 4b3333}
|
||||
}
|
||||
|
||||
def scan_loop{T, init, x:*T, r:*T, len:u64, scan, scan_last} = {
|
||||
def step = 256/width{T}
|
||||
def V = [step]T
|
||||
@ -31,8 +38,7 @@ def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
|
||||
def last{v, p} = op{pre{v}, p}
|
||||
def scan{v, p} = {
|
||||
n:= last{v, p}
|
||||
p = (if (width{T}<=32) sel{[8]i32, spread{n}, [8]i32**7}
|
||||
else shuf{[4]u64, n, 4b3333})
|
||||
p = toLast{n}
|
||||
n
|
||||
}
|
||||
scan_loop{T, init, x, r, len, scan, last}
|
||||
@ -73,20 +79,23 @@ export{'avx2_scan_min_i8', avx2_scan_idem_id{i8 , min}}; export{'avx2_scan_max_
|
||||
export{'avx2_scan_min_i16', avx2_scan_idem_id{i16, min}}; export{'avx2_scan_max_i16', avx2_scan_idem_id{i16, max}}
|
||||
export{'avx2_scan_min_i32', avx2_scan_idem_id{i32, min}}; export{'avx2_scan_max_i32', avx2_scan_idem_id{i32, max}}
|
||||
|
||||
# Assumes identity is 0
|
||||
def scan_assoc{op, a:T} = {
|
||||
# Within each lane, scan using shifts by powers of 2
|
||||
def w = elwidth{T}
|
||||
def c32{k, a} = (if (w<=8*k) op{a, shl{[16]u8, a, k}}; else a)
|
||||
b:= c32{8, c32{4, c32{2, c32{1, a}}}}
|
||||
# After lanewise scan, broadcast end of lane 0 to entire lane 1
|
||||
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
|
||||
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
|
||||
}
|
||||
def scan_plus = bind{scan_assoc, +}
|
||||
|
||||
# Associative scan
|
||||
fn avx2_scan_assoc_0{T, op}(x:*T, r:*T, len:u64, init:T) : void = {
|
||||
# Prefix op on entire AVX register
|
||||
def pre{a} = {
|
||||
# Within each lane, scan using shifts by powers of 2.
|
||||
# Assumes identity is 0.
|
||||
def w = width{T}
|
||||
def c32{k, a} = (if (w<=8*k) op{a, shl{[16]u8, a, k}}; else a)
|
||||
b:= c32{8, c32{4, c32{2, c32{1, a}}}}
|
||||
# After lanewise scan, broadcast end of lane 0 to entire lane 1
|
||||
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
|
||||
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
|
||||
}
|
||||
scan_post{T, init, x, r, len, op, pre}
|
||||
|
||||
scan_post{T, init, x, r, len, op, scan_plus}
|
||||
}
|
||||
export{'avx2_scan_pluswrap_u8', avx2_scan_assoc_0{u8 , +}}
|
||||
export{'avx2_scan_pluswrap_u16', avx2_scan_assoc_0{u16, +}}
|
||||
@ -151,3 +160,97 @@ fn avx2_bcs{T}(x:*u64, r:*T, l:u64) : void = {
|
||||
export{'avx2_bcs8', avx2_bcs{i8}}
|
||||
export{'avx2_bcs16', avx2_bcs{i16}}
|
||||
export{'avx2_bcs32', avx2_bcs{i32}}
|
||||
|
||||
|
||||
|
||||
def addChk{a:T, b:T} = {
|
||||
mem:*T = tup{a}
|
||||
def bad = emit{u1, '__builtin_add_overflow', a, b, mem}
|
||||
tup{bad, load{mem}}
|
||||
}
|
||||
def addChk{a:T, b:T & T==f64} = tup{0, a+b}
|
||||
|
||||
def widenFull{E, xs} = {
|
||||
merge{...each{{x:X} => {
|
||||
def n = vcount{X}
|
||||
def tb = width{E} * n
|
||||
if (tb<=arch_defvw) tup{widen{[n]E, x}}
|
||||
else if (1) {
|
||||
assert{tb == 2*arch_defvw}
|
||||
tup{
|
||||
widen{[n/2]E, half{x,0}},
|
||||
widen{[n/2]E, half{x,1}}
|
||||
}
|
||||
}
|
||||
}, xs}}
|
||||
}
|
||||
|
||||
def floor{x & knum{x}} = x - (x%1)
|
||||
def maxabsval{T & issigned{T}} = -minvalue{T}
|
||||
def maxsafeint{T & issigned{T}} = maxvalue{T}
|
||||
def maxsafeint{T==f64} = 1<<53
|
||||
|
||||
def simd_plus_scan{X, b, R}{x:*X, c:(R), r:*R, len:u64} = {
|
||||
def bulk = arch_defvw/b
|
||||
|
||||
def wd = (X!=R) & (width{X}<32) # whether to widen the working copy one size
|
||||
def WE = tern{wd, ty_dbl{X}, X} # working copy element type
|
||||
|
||||
# maxFastA: max absolute value for accumulator
|
||||
# maxFastE: max absolute value for vector elements (not used if ~wd)
|
||||
def maxFastE = if (wd) maxabsval{X} else maxabsval{X}/(bulk*tern{R==f64, 1, 4}) # 4 to give maxFastA some range
|
||||
def maxFastA = maxsafeint{R} - maxFastE*bulk
|
||||
|
||||
if (R!=f64) { def m = maxFastA + maxFastE*bulk; assert{m<=maxvalue{R}}; assert{-m>=minvalue{R}} }
|
||||
|
||||
i:u64 = 0
|
||||
cv:= [arch_defvw/width{R}]R ** c
|
||||
if (R==f64 and c != floor{c}) goto{'end'}
|
||||
while (1) {
|
||||
def ctmp = extract{cv,0}
|
||||
if (max{ctmp,-ctmp} >= tern{R==f64, cast_i{f64, i64~~maxFastA}, maxFastA}) goto{'end'}
|
||||
i2:= i+bulk
|
||||
if (i2>len) goto{'end'}
|
||||
def cx0 = tup{load{*[bulk]X~~(x+i)}}
|
||||
def cx = if(wd) widenFull{WE,cx0} else cx0
|
||||
if (~wd) { # within-vector overflow check; widening gives range space for this to not happen
|
||||
if (rare{homAny{tree_fold{|, each{{c:T} => absu{c} >= (ty_u{T}**maxFastE), cx}}}}) goto{'end'}
|
||||
}
|
||||
|
||||
def s0 = each{scan_plus, cx}
|
||||
|
||||
def s1 = {
|
||||
if (tuplen{s0}==1) s0
|
||||
else { def {v0,v1}=s0; tup{v0,v1+toLast{v0}} }
|
||||
}
|
||||
|
||||
def cr = eachx{+, widenFull{R, s1}, cv}
|
||||
cv = toLast{tupsel{-1, cr}}
|
||||
|
||||
each{{c:T} => assert{T==type{cv}}, cr}
|
||||
assert{vcount{type{cv}} * tuplen{cr} == bulk}
|
||||
|
||||
each{{c:T, j} => store{*T~~(r+i), j, c}, cr, iota{tuplen{cr}}}
|
||||
i = i2
|
||||
}
|
||||
setlabel{'end'}
|
||||
|
||||
c = extract{cv, vcount{type{cv}}-1}
|
||||
while (i < len) {
|
||||
def {b,n} = addChk{c, promote{R,load{x,i}}}
|
||||
if (rare{b}) return{i}
|
||||
store{r, i, n}
|
||||
c = n
|
||||
++i
|
||||
}
|
||||
len
|
||||
}
|
||||
fn simd_plus_scanG{X, b, R}(x:*X, c:R, r:*R, len:u64) : void = simd_plus_scan{X,b,R}{x, c, r, len}
|
||||
fn simd_plus_scanC{X, b, R}(x:*X, c:R, r:*R, len:u64) : u64 = simd_plus_scan{X,b,R}{x, c, r, len}
|
||||
|
||||
export{'simd_scan_plus_i8_i32', simd_plus_scanC{i8, 16, i32}}
|
||||
export{'simd_scan_plus_i16_i32', simd_plus_scanC{i16, 16, i32}}
|
||||
export{'simd_scan_plus_i32_i32', simd_plus_scanC{i32, 32, i32}}
|
||||
|
||||
export{'simd_scan_plus_i16_f64', simd_plus_scanG{i16, 32, f64}}
|
||||
export{'simd_scan_plus_i32_f64', simd_plus_scanG{i32, 32, f64}}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user