Implement SIMD wrapping plus-scan for Replicate
This commit is contained in:
parent
c42f0fd699
commit
6ed3c18389
@ -104,6 +104,16 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if SINGELI
|
||||
extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
|
||||
extern void (*const avx2_scan_pluswrap_u16)(uint16_t* v0,uint16_t* v1,uint64_t v2,uint16_t v3);
|
||||
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
|
||||
#define avx2_scan_pluswrap_u64(V0,V1,V2,V3) for (usz i=k; i<e; i++) js=rp[i]+=js;
|
||||
#define PLUS_SCAN(T) avx2_scan_pluswrap_##T(rp+k,rp+k,e-k,js); js=rp[e-1];
|
||||
#else
|
||||
#define PLUS_SCAN(T) for (usz i=k; i<e; i++) js=rp[i]+=js;
|
||||
#endif
|
||||
|
||||
// Dense Where, still significantly worse than SIMD
|
||||
// Assumes modifiable DST
|
||||
#define WHERE_DENSE(SRC, DST, LEN, OFF) do { \
|
||||
@ -473,10 +483,10 @@ B slash_c1(B t, B x) {
|
||||
usz e = b<s-k? k+b : s; \
|
||||
for (usz i=k; i<e; i++) rp[i]=0; \
|
||||
while (ij<e) { rp[ij]++; ij+=xp[++j]; } \
|
||||
for (usz i=k; i<e; i++) js=rp[i]+=js; \
|
||||
PLUS_SCAN(u32) \
|
||||
if (e==s) {break;} k=e; \
|
||||
}
|
||||
i32* rp; r = m_i32arrv(&rp, s);
|
||||
u32* rp; r = m_i32arrv((i32**)&rp, s);
|
||||
if (xe == el_i8 ) { SPARSE_IND(i8 ); }
|
||||
else if (xe == el_i16) { SPARSE_IND(i16); }
|
||||
else { SPARSE_IND(i32); }
|
||||
@ -605,7 +615,7 @@ B slash_c2(B t, B w, B x) {
|
||||
void* rv = m_tyarrv(&r, 1<<xk, s, xt);
|
||||
if (rsh) { Arr* ra=a(r); SPRNK(ra,xr); PSH(ra) = rsh; PIA(ra) = s*arr_csz(x); }
|
||||
void* xv = tyany_ptr(x);
|
||||
if (s/32 <= wia) { // Sparse case: use both types
|
||||
if (s/64 <= wia) { // Sparse case: use both types
|
||||
#define CASE(L,XT) case L: { \
|
||||
XT* xp = xv; XT* rp = rv; \
|
||||
usz b = 1<<10; \
|
||||
@ -613,8 +623,8 @@ B slash_c2(B t, B w, B x) {
|
||||
for (usz k=0, j=0, ij=wp[0]; ; ) { \
|
||||
usz e = b<s-k? k+b : s; \
|
||||
for (usz i=k; i<e; i++) rp[i]=0; \
|
||||
while (ij<e) { j++; XT sx=px; rp[ij]^=sx^(px=xp[j]); ij+=wp[j]; } \
|
||||
for (usz i=k; i<e; i++) js=rp[i]^=js; \
|
||||
while (ij<e) { j++; XT sx=px; rp[ij]+=(px=xp[j])-sx; ij+=wp[j]; } \
|
||||
PLUS_SCAN(XT) \
|
||||
if (e==s) {break;} k=e; \
|
||||
} break; }
|
||||
#define SPARSE_REP(WT) \
|
||||
|
||||
@ -10,20 +10,42 @@ def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
|
||||
def base{b,l} = { if (0==tuplen{l}) 0; else tupsel{0,l}+b*base{b,slice{l,1}} }
|
||||
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,n}}
|
||||
|
||||
# Fill last 4 bytes with last element, in each lane
|
||||
def spread{a:VT} = {
|
||||
def w = width{eltype{VT}}
|
||||
def b = w/8
|
||||
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
|
||||
}
|
||||
|
||||
def scan_loop{T, init, x:*T, r:*T, len:u64, scan, scan_last} = {
|
||||
def step = 256/width{T}
|
||||
def V = [step]T
|
||||
p:= broadcast{V, init}
|
||||
xv:= *V ~~ x
|
||||
rv:= *V ~~ r
|
||||
e:= len/step
|
||||
@for (xv, rv over e) rv = scan{xv,p}
|
||||
q:= len & (step-1)
|
||||
if (q) maskstoreF{rv, maskOf{V, q}, e, scan_last{load{xv,e}, p}}
|
||||
}
|
||||
def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
|
||||
def last{v, p} = op{pre{v}, p}
|
||||
def scan{v, p} = {
|
||||
n:= last{v, p}
|
||||
p = sel{[8]i32, spread{n}, broadcast{[8]i32, 7}}
|
||||
n
|
||||
}
|
||||
scan_loop{T, init, x, r, len, scan, last}
|
||||
}
|
||||
|
||||
# Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈
|
||||
avx2_scan_idem{T, op, id}(x:*T, r:*T, len:u64) : void = {
|
||||
def w = width{T}
|
||||
|
||||
# Within each lane, scan using shifts by powers of 2. First k elements
|
||||
# when shifting by k don't need to change, so leave them alone.
|
||||
def w = width{T}
|
||||
def shift{k,l} = merge{iota{k},iota{l-k}}
|
||||
def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}}
|
||||
def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a)
|
||||
# Fill last 4 bytes with last element, in each lane
|
||||
def spread{a} = {
|
||||
def b = w/8
|
||||
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
|
||||
}
|
||||
# Prefix op on entire AVX register
|
||||
def pre{a} = {
|
||||
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
|
||||
@ -31,19 +53,7 @@ avx2_scan_idem{T, op, id}(x:*T, r:*T, len:u64) : void = {
|
||||
op{b, sel{[8]i32, spread{b}, make{[8]i32, 3*(3<iota{8})}}}
|
||||
}
|
||||
|
||||
def step = 256/w
|
||||
def V = [step]T
|
||||
p:= broadcast{V, id}
|
||||
xv:= *V ~~ x
|
||||
rv:= *V ~~ r
|
||||
e:= len/step
|
||||
@for (xv, rv over e) {
|
||||
n:= op{pre{xv}, p}
|
||||
p = sel{[8]i32, spread{n}, broadcast{[8]i32, 7}}
|
||||
rv = n
|
||||
}
|
||||
q:= len & (step-1)
|
||||
if (q) maskstoreF{rv, maskOf{V, q}, e, op{pre{load{xv,e}}, p}}
|
||||
scan_post{T, id, x, r, len, op, pre}
|
||||
}
|
||||
def avx2_scan_idem{T, op} = {
|
||||
def m = 1 << (width{T}-1)
|
||||
@ -56,6 +66,25 @@ def avx2_scan_idem{T, op} = {
|
||||
'avx2_scan_min32' = avx2_scan_idem{i32, min}
|
||||
'avx2_scan_max32' = avx2_scan_idem{i32, max}
|
||||
|
||||
# Associative scan
|
||||
avx2_scan_assoc_0{T, op}(x:*T, r:*T, len:u64, init:T) : void = {
|
||||
# Prefix op on entire AVX register
|
||||
def pre{a} = {
|
||||
# Within each lane, scan using shifts by powers of 2.
|
||||
# Assumes identity is 0.
|
||||
def w = width{T}
|
||||
def c32{k, a} = (if (w<=8*k) op{a, shl{[16]u8, a, k}}; else a)
|
||||
b:= c32{8, c32{4, c32{2, c32{1, a}}}}
|
||||
# After lanewise scan, broadcast end of lane 0 to entire lane 1
|
||||
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
|
||||
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
|
||||
}
|
||||
scan_post{T, init, x, r, len, op, pre}
|
||||
}
|
||||
'avx2_scan_pluswrap_u8' = avx2_scan_assoc_0{u8 , +}
|
||||
'avx2_scan_pluswrap_u16' = avx2_scan_assoc_0{u16, +}
|
||||
'avx2_scan_pluswrap_u32' = avx2_scan_assoc_0{u32, +}
|
||||
|
||||
# Boolean cumulative sum
|
||||
avx2_bcs32(x:*u64, r:*i32, l:u64) : void = {
|
||||
rv:= *[8]u32~~r
|
||||
|
||||
Loading…
Reference in New Issue
Block a user