Implement SIMD wrapping plus-scan for Replicate

This commit is contained in:
Marshall Lochbaum 2022-09-20 20:24:13 -04:00
parent c42f0fd699
commit 6ed3c18389
2 changed files with 64 additions and 25 deletions

View File

@ -104,6 +104,16 @@
#endif
#endif
#if SINGELI
extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
extern void (*const avx2_scan_pluswrap_u16)(uint16_t* v0,uint16_t* v1,uint64_t v2,uint16_t v3);
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
#define avx2_scan_pluswrap_u64(V0,V1,V2,V3) for (usz i=k; i<e; i++) js=rp[i]+=js;
#define PLUS_SCAN(T) avx2_scan_pluswrap_##T(rp+k,rp+k,e-k,js); js=rp[e-1];
#else
#define PLUS_SCAN(T) for (usz i=k; i<e; i++) js=rp[i]+=js;
#endif
// Dense Where, still significantly worse than SIMD
// Assumes modifiable DST
#define WHERE_DENSE(SRC, DST, LEN, OFF) do { \
@ -473,10 +483,10 @@ B slash_c1(B t, B x) {
usz e = b<s-k? k+b : s; \
for (usz i=k; i<e; i++) rp[i]=0; \
while (ij<e) { rp[ij]++; ij+=xp[++j]; } \
for (usz i=k; i<e; i++) js=rp[i]+=js; \
PLUS_SCAN(u32) \
if (e==s) {break;} k=e; \
}
i32* rp; r = m_i32arrv(&rp, s);
u32* rp; r = m_i32arrv((i32**)&rp, s);
if (xe == el_i8 ) { SPARSE_IND(i8 ); }
else if (xe == el_i16) { SPARSE_IND(i16); }
else { SPARSE_IND(i32); }
@ -605,7 +615,7 @@ B slash_c2(B t, B w, B x) {
void* rv = m_tyarrv(&r, 1<<xk, s, xt);
if (rsh) { Arr* ra=a(r); SPRNK(ra,xr); PSH(ra) = rsh; PIA(ra) = s*arr_csz(x); }
void* xv = tyany_ptr(x);
if (s/32 <= wia) { // Sparse case: use both types
if (s/64 <= wia) { // Sparse case: use both types
#define CASE(L,XT) case L: { \
XT* xp = xv; XT* rp = rv; \
usz b = 1<<10; \
@ -613,8 +623,8 @@ B slash_c2(B t, B w, B x) {
for (usz k=0, j=0, ij=wp[0]; ; ) { \
usz e = b<s-k? k+b : s; \
for (usz i=k; i<e; i++) rp[i]=0; \
while (ij<e) { j++; XT sx=px; rp[ij]^=sx^(px=xp[j]); ij+=wp[j]; } \
for (usz i=k; i<e; i++) js=rp[i]^=js; \
while (ij<e) { j++; XT sx=px; rp[ij]+=(px=xp[j])-sx; ij+=wp[j]; } \
PLUS_SCAN(XT) \
if (e==s) {break;} k=e; \
} break; }
#define SPARSE_REP(WT) \

View File

@ -10,20 +10,42 @@ def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
def base{b,l} = { if (0==tuplen{l}) 0; else tupsel{0,l}+b*base{b,slice{l,1}} }
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,n}}
# Fill last 4 bytes with last element, in each lane
def spread{a:VT} = {
def w = width{eltype{VT}}
def b = w/8
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
}
def scan_loop{T, init, x:*T, r:*T, len:u64, scan, scan_last} = {
def step = 256/width{T}
def V = [step]T
p:= broadcast{V, init}
xv:= *V ~~ x
rv:= *V ~~ r
e:= len/step
@for (xv, rv over e) rv = scan{xv,p}
q:= len & (step-1)
if (q) maskstoreF{rv, maskOf{V, q}, e, scan_last{load{xv,e}, p}}
}
def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
def last{v, p} = op{pre{v}, p}
def scan{v, p} = {
n:= last{v, p}
p = sel{[8]i32, spread{n}, broadcast{[8]i32, 7}}
n
}
scan_loop{T, init, x, r, len, scan, last}
}
# Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈
avx2_scan_idem{T, op, id}(x:*T, r:*T, len:u64) : void = {
def w = width{T}
# Within each lane, scan using shifts by powers of 2. First k elements
# when shifting by k don't need to change, so leave them alone.
def w = width{T}
def shift{k,l} = merge{iota{k},iota{l-k}}
def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}}
def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a)
# Fill last 4 bytes with last element, in each lane
def spread{a} = {
def b = w/8
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
}
# Prefix op on entire AVX register
def pre{a} = {
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
@ -31,19 +53,7 @@ avx2_scan_idem{T, op, id}(x:*T, r:*T, len:u64) : void = {
op{b, sel{[8]i32, spread{b}, make{[8]i32, 3*(3<iota{8})}}}
}
def step = 256/w
def V = [step]T
p:= broadcast{V, id}
xv:= *V ~~ x
rv:= *V ~~ r
e:= len/step
@for (xv, rv over e) {
n:= op{pre{xv}, p}
p = sel{[8]i32, spread{n}, broadcast{[8]i32, 7}}
rv = n
}
q:= len & (step-1)
if (q) maskstoreF{rv, maskOf{V, q}, e, op{pre{load{xv,e}}, p}}
scan_post{T, id, x, r, len, op, pre}
}
def avx2_scan_idem{T, op} = {
def m = 1 << (width{T}-1)
@ -56,6 +66,25 @@ def avx2_scan_idem{T, op} = {
'avx2_scan_min32' = avx2_scan_idem{i32, min}
'avx2_scan_max32' = avx2_scan_idem{i32, max}
# Associative scan
avx2_scan_assoc_0{T, op}(x:*T, r:*T, len:u64, init:T) : void = {
# Prefix op on entire AVX register
def pre{a} = {
# Within each lane, scan using shifts by powers of 2.
# Assumes identity is 0.
def w = width{T}
def c32{k, a} = (if (w<=8*k) op{a, shl{[16]u8, a, k}}; else a)
b:= c32{8, c32{4, c32{2, c32{1, a}}}}
# After lanewise scan, broadcast end of lane 0 to entire lane 1
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
}
scan_post{T, init, x, r, len, op, pre}
}
'avx2_scan_pluswrap_u8' = avx2_scan_assoc_0{u8 , +}
'avx2_scan_pluswrap_u16' = avx2_scan_assoc_0{u16, +}
'avx2_scan_pluswrap_u32' = avx2_scan_assoc_0{u32, +}
# Boolean cumulative sum
avx2_bcs32(x:*u64, r:*i32, l:u64) : void = {
rv:= *[8]u32~~r