Merge pull request #91 from mlochbaum/scan

Scan refactoring and architecture extension
This commit is contained in:
dzaima 2023-08-26 17:02:54 +03:00 committed by GitHub
commit 4f5188a51e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 233 additions and 196 deletions

View File

@ -650,8 +650,8 @@ cachedBin‿linkerCache ← {
"xag""src/builtins/search.c""search", "xa.""src/builtins/fold.c""fold",
"xag""src/builtins/sort.c""bins"
"2..""src/builtins/select.c""select", "2..""src/builtins/scan.c""scan",
"2a.""src/builtins/slash.c""constrep", "2..""src/builtins/scan.c""neq",
"2..""src/builtins/select.c""select", "xag""src/builtins/scan.c""scan",
"2a.""src/builtins/slash.c""constrep",
"xag""src/builtins/slash.c""slash", "2..""src/builtins/slash.c""count"
objs

View File

@ -78,22 +78,22 @@
}
#if SINGELI_AVX2
extern void (*const avx2_scan_max_i8)(int8_t* v0,int8_t* v1,uint64_t v2);
extern void (*const avx2_scan_min_i8)(int8_t* v0,int8_t* v1,uint64_t v2);
extern void (*const avx2_scan_max_i16)(int16_t* v0,int16_t* v1,uint64_t v2);
extern void (*const avx2_scan_min_i16)(int16_t* v0,int16_t* v1,uint64_t v2);
extern void (*const si_scan_max_i8)(int8_t* v0,int8_t* v1,uint64_t v2);
extern void (*const si_scan_min_i8)(int8_t* v0,int8_t* v1,uint64_t v2);
extern void (*const si_scan_max_i16)(int16_t* v0,int16_t* v1,uint64_t v2);
extern void (*const si_scan_min_i16)(int16_t* v0,int16_t* v1,uint64_t v2);
#define COUNT_THRESHOLD 32
#define WRITE_SPARSE_i8 \
for (usz i=0; i<n; i++) rp[i]=j; \
while (ij<n) { rp[ij]=GRADE_UD(++j,--j); ij+=c0o[j]; } \
GRADE_UD(avx2_scan_max_i8,avx2_scan_min_i8)(rp,rp,n);
GRADE_UD(si_scan_max_i8,si_scan_min_i8)(rp,rp,n);
#define WRITE_SPARSE_i16 \
usz b = 1<<10; \
for (usz k=0; ; ) { \
usz e = b<n-k? k+b : n; \
for (usz i=k; i<e; i++) rp[i]=j; \
while (ij<e) { rp[ij]=GRADE_UD(++j,--j); ij+=c0o[j]; } \
GRADE_UD(avx2_scan_max_i16,avx2_scan_min_i16)(rp+k,rp+k,e-k); \
GRADE_UD(si_scan_max_i16,si_scan_min_i16)(rp+k,rp+k,e-k); \
if (e==n) {break;} k=e; \
}
#define WRITE_SPARSE(T) WRITE_SPARSE_##T

View File

@ -11,15 +11,15 @@
#define RDX_SUM_2(T) GRADE_UD(c1[0]=0;,) T s0=0, s1=0; for(usz j=0;j<256;j++) { RDX_PRE(0); RDX_PRE(1); }
#define RDX_SUM_4(T) GRADE_UD(c1[0]=c2[0]=c3[0]=0;,) T s0=0, s1=0, s2=0, s3=0; for(usz j=0;j<256;j++) { RDX_PRE(0); RDX_PRE(1); RDX_PRE(2); RDX_PRE(3); }
#if SINGELI_AVX2
extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
#define RADIX_SUM_1_u8 avx2_scan_pluswrap_u8 (c0,c0, 256,0);
#define RADIX_SUM_1_u32 avx2_scan_pluswrap_u32(c0,c0, 256,0);
#define RADIX_SUM_2_u8 avx2_scan_pluswrap_u8 (c0,c0,2*256,0);
#define RADIX_SUM_2_u32 avx2_scan_pluswrap_u32(c0,c0,2*256,0);
#define RADIX_SUM_4_u8 avx2_scan_pluswrap_u8 (c0,c0,4*256,0);
#define RADIX_SUM_4_u32 avx2_scan_pluswrap_u32(c0,c0,4*256,0);
#if SINGELI_X86_64
extern void (*const si_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
extern void (*const si_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
#define RADIX_SUM_1_u8 si_scan_pluswrap_u8 (c0,c0, 256,0);
#define RADIX_SUM_1_u32 si_scan_pluswrap_u32(c0,c0, 256,0);
#define RADIX_SUM_2_u8 si_scan_pluswrap_u8 (c0,c0,2*256,0);
#define RADIX_SUM_2_u32 si_scan_pluswrap_u32(c0,c0,2*256,0);
#define RADIX_SUM_4_u8 si_scan_pluswrap_u8 (c0,c0,4*256,0);
#define RADIX_SUM_4_u32 si_scan_pluswrap_u32(c0,c0,4*256,0);
#else
#define RADIX_SUM_1_u8 RDX_SUM_1(u8)
#define RADIX_SUM_1_u32 RDX_SUM_1(u32)
@ -29,10 +29,10 @@ extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v
#define RADIX_SUM_4_u32 RDX_SUM_4(u32)
#endif
#if SINGELI_AVX2 && !USZ_64
#define RADIX_SUM_1_usz avx2_scan_pluswrap_u32(c0,c0, 256,0);
#define RADIX_SUM_2_usz avx2_scan_pluswrap_u32(c0,c0,2*256,0);
#define RADIX_SUM_4_usz avx2_scan_pluswrap_u32(c0,c0,4*256,0);
#if SINGELI_X86_64 && !USZ_64
#define RADIX_SUM_1_usz si_scan_pluswrap_u32(c0,c0, 256,0);
#define RADIX_SUM_2_usz si_scan_pluswrap_u32(c0,c0,2*256,0);
#define RADIX_SUM_4_usz si_scan_pluswrap_u32(c0,c0,4*256,0);
#else
#define RADIX_SUM_1_usz RDX_SUM_1(usz)
#define RADIX_SUM_2_usz RDX_SUM_2(usz)

View File

@ -9,21 +9,17 @@
static u64 vg_rand(u64 x) { return x; }
#endif
#if SINGELI_AVX2
#if SINGELI
#define SINGELI_FILE scan
#include "../utils/includeSingeli.h"
#if __PCLMUL__
#define SINGELI_FILE neq
#include "../utils/includeSingeli.h"
#endif
#endif
B scan_ne(B x, u64 p, u64 ia) { // consumes x
u64* xp = bitarr_ptr(x);
u64* rp; B r=m_bitarrv(&rp,ia);
#if SINGELI_AVX2 && __PCLMUL__
clmul_scan_ne(p, xp, rp, BIT_N(ia));
#if SINGELI
si_scan_ne(p, xp, rp, BIT_N(ia));
#if USE_VALGRIND
if (ia&63) rp[ia>>6] = vg_def_u64(rp[ia>>6]);
#endif
@ -78,8 +74,8 @@ B scan_add_bool(B x, u64 ia) { // consumes x
} else {
void* rp = m_tyarrv(&r, elWidth(re), ia, el2t(re));
#define SUM_BITWISE(T) { T c=0; for (usz i=0; i<ia; i++) { c+= bitp_get(xp,i); ((T*)rp)[i]=c; } }
#if SINGELI_AVX2
#define SUM(W,T) avx2_bcs##W(xp, rp, ia);
#if SINGELI
#define SUM(W,T) si_bcs##W(xp, rp, ia);
#else
#define SUM(W,T) SUM_BITWISE(T)
#endif
@ -96,8 +92,8 @@ B scan_add_bool(B x, u64 ia) { // consumes x
}
// min/max-scan
#if SINGELI_AVX2
#define MINMAX_SCAN(T,NAME,C,I) avx2_scan_##NAME##_init_##T(xp, rp, ia, I);
#if SINGELI
#define MINMAX_SCAN(T,NAME,C,I) si_scan_##NAME##_init_##T(xp, rp, ia, I);
#else
#define MINMAX_SCAN(T,NAME,C,I) T c=I; for (usz i=0; i<ia; i++) { if (xp[i] C c)c=xp[i]; rp[i]=c; }
#endif
@ -155,16 +151,16 @@ static B scan_lt(B x, u64 p, usz ia) {
static B scan_plus(f64 r0, B x, u8 xe, usz ia) {
assert(xe!=el_bit && elNum(xe));
B r; void* rp = m_tyarrv(&r, xe==el_f64? sizeof(f64) : sizeof(i32), ia, xe==el_f64? t_f64arr : t_i32arr);
#if SINGELI_AVX2
#if SINGELI
switch(xe) { default:UD;
case el_i8: { if (!q_fi32(r0) || simd_scan_plus_i8_i32 (i8any_ptr(x), r0, rp, ia)!=ia) goto cs_i8_f64; decG(x); return r; }
case el_i16: { if (!q_fi32(r0) || simd_scan_plus_i16_i32(i16any_ptr(x), r0, rp, ia)!=ia) goto cs_i16_f64; decG(x); return r; }
case el_i32: { if (!q_fi32(r0) || simd_scan_plus_i32_i32(i32any_ptr(x), r0, rp, ia)!=ia) goto cs_i32_f64; decG(x); return r; }
case el_i8: { if (!q_fi32(r0) || si_scan_plus_i8_i32 (i8any_ptr(x), r0, rp, ia)!=ia) goto cs_i8_f64; decG(x); return r; }
case el_i16: { if (!q_fi32(r0) || si_scan_plus_i16_i32(i16any_ptr(x), r0, rp, ia)!=ia) goto cs_i16_f64; decG(x); return r; }
case el_i32: { if (!q_fi32(r0) || si_scan_plus_i32_i32(i32any_ptr(x), r0, rp, ia)!=ia) goto cs_i32_f64; decG(x); return r; }
case el_f64: { f64* xp=f64any_ptr(x); f64 c=r0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
}
cs_i8_f64: { x=taga(cpyI16Arr(x)); goto cs_i16_f64; }
cs_i16_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); simd_scan_plus_i16_f64(i16any_ptr(x), r0, rp, ia); decG(x); return r; }
cs_i32_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); simd_scan_plus_i32_f64(i32any_ptr(x), r0, rp, ia); decG(x); return r; }
cs_i16_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); si_scan_plus_i16_f64(i16any_ptr(x), r0, rp, ia); decG(x); return r; }
cs_i32_f64: { decG(r); f64* rp; r = m_f64arrv(&rp, ia); si_scan_plus_i32_f64(i32any_ptr(x), r0, rp, ia); decG(x); return r; }
#else
if (xe==el_i8 && q_fi32(r0)) { i8* xp=i8any_ptr (x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }
if (xe==el_i16 && q_fi32(r0)) { i16* xp=i16any_ptr(x); i32 c=r0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) goto base; ((i32*)rp)[i]=c; } decG(x); return r; }

View File

@ -98,13 +98,18 @@
#if SINGELI_AVX2
#define SINGELI_FILE count
#include "../utils/includeSingeli.h"
extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
extern void (*const avx2_scan_pluswrap_u16)(uint16_t* v0,uint16_t* v1,uint64_t v2,uint16_t v3);
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
#define avx2_scan_pluswrap_u64(V0,V1,V2,V3) for (usz i=k; i<e; i++) js=rp[i]+=js;
#define PLUS_SCAN(T) avx2_scan_pluswrap_##T(rp+k,rp+k,e-k,js); js=rp[e-1];
extern void (*const avx2_scan_max_i32)(int32_t* v0,int32_t* v1,uint64_t v2);
#endif
#if SINGELI
extern void (*const si_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
extern void (*const si_scan_pluswrap_u16)(uint16_t* v0,uint16_t* v1,uint64_t v2,uint16_t v3);
extern void (*const si_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
#define ALIAS(I,U) static void si_scan_pluswrap_##I(I* a, I* b, u64 c, I d) { si_scan_pluswrap_##U((U*)a, (U*)b, c, d); }
ALIAS(i8,u8) ALIAS(i16,u16) ALIAS(i32,u32)
#undef ALIAS
#define si_scan_pluswrap_u64(V0,V1,V2,V3) for (usz i=k; i<e; i++) js=rp[i]+=js;
#define PLUS_SCAN(T) si_scan_pluswrap_##T(rp+k,rp+k,e-k,js); js=rp[e-1];
extern void (*const si_scan_max_i32)(int32_t* v0,int32_t* v1,uint64_t v2);
#else
#define PLUS_SCAN(T) for (usz i=k; i<e; i++) js=rp[i]+=js;
#endif
@ -548,7 +553,7 @@ static B compress(B w, B x, usz wia, u8 xl, u8 xt) {
#if SINGELI_AVX2
#define IND_BY_SCAN \
SCAN_CORE(xp[j], rp[ij]=j, rp[k]=j, avx2_scan_max_i32(rp+k,rp+k,e-k))
SCAN_CORE(xp[j], rp[ij]=j, rp[k]=j, si_scan_max_i32(rp+k,rp+k,e-k))
#else
#define IND_BY_SCAN usz js=0; SUM_CORE(i32, xp[j], , 1)
#endif

View File

@ -4,6 +4,8 @@ if (hasarch{'AVX2'}) {
include './sse'
include './avx'
include './avx2'
} else if (hasarch{'X86_64'}) {
include './sse'
}
include './mask'
include 'util/tup'
@ -27,36 +29,18 @@ def shr16{v:V, n} = V~~(re_el{u16, v} >> n)
# Forward or backwards in-place max-scan
# Assumes a whole number of vectors and minimum 0
include './scan_common'
fn max_scan{T, up}(x:*T, len:u64) : void = {
def w = width{T}
if (hasarch{'AVX2'} and T!=u64) {
if (hasarch{'X86_64'}) {
def op = max
# TODO unify with scan.singeli avx2_scan_idem
def rev{a} = if (up) a else (tuplen{a}-1)-reverse{a}
def maker{T, l} = make{T, rev{l}}
def sel8{v, t} = sel{[16]u8, v, maker{[32]i8, t}}
def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,rev{n}}}
def spread{a:VT} = {
def w = elwidth{VT}
def b = w/8
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
}
def shift{k,l} = merge{iota{k},iota{l-k}}
def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}}
def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a)
def pre{a} = {
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
op{b, sel{[8]i32, spread{b}, maker{[8]i32, 3*(3<iota{8})}}}
}
def toLast{n:VT} = {
if (elwidth{VT}<=32) sel{[8]i32, spread{n}, [8]i32**(up*7)}
else shuf{[4]u64, n, up*4b3333}
}
def vl = 256/w
def pre = make_scan_idem{T, op, up}
def vl = (if (hasarch{'AVX2'}) 256 else 128)/width{T}
def V = [vl]T
p := V**0
@for_dir{up} (v in *V~~x over len/vl) { v = op{pre{v}, p}; p = toLast{v} }
@for_dir{up} (v in *V~~x over len/vl) {
v = op{pre{v}, p}
p = toLast{v, up}
}
} else {
m:T=0; @for_dir{up} (x over len) { if (x > m) m = x; x = m }
}

View File

@ -1,29 +0,0 @@
include './base'
include './sse'
include './clmul'
fn clmul_scan_ne_any(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = {
def V = [2]u64
m := V**mark
def xor64{a, i, carry} = { # carry is 64-bit broadcasted current total
p := clmul{a, m, i}
t := shr{[16]u8, p, 8}
s := p ^ carry
carry = s ^ t
s
}
xv := *V ~~ x
rv := *V ~~ r
e := words/2
c := V**init
@for (rv, xv over e) {
rv = apply{zipLo, (@collect (j to 2) xor64{xv, j, c})}
}
if (words & 1) {
storeLow{rv+e, 64, clmul{loadLow{xv+e, 64}, m, 0} ^ c}
}
}
fn clmul_scan_ne_bit(init:u64, x:*u64, r:*u64, ia:u64) : void = {
clmul_scan_ne_any(*void~~x, *void~~r, init, ia, -(u64~~1))
}
export{'clmul_scan_ne', clmul_scan_ne_bit}

View File

@ -1,30 +1,19 @@
include './base'
include './sse'
include './clmul'
include './avx'
include './avx2'
include './mask'
include './f64'
include './scan_common'
def sel8{v, t} = sel{[16]u8, v, make{[32]i8, t}}
def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,n}}
# Fill last 4 bytes with last element, in each lane
def spread{a:VT} = {
def w = elwidth{VT}
def b = w/8
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a
}
# Set all elements with the last element of the input
def toLast{n:VT} = {
if (elwidth{VT}<=32) sel{[8]i32, spread{n}, [8]i32**7}
else shuf{[4]u64, n, 4b3333}
# Initialized scan, generic implementation
fn scan_scal{T, op}(x:*T, r:*T, len:u64, m:T) : void = {
@for (x, r over len) r = m = op{m, x}
}
def scan_loop{T, init, x:*T, r:*T, len:u64, scan, scan_last} = {
def step = 256/width{T}
def step = arch_defvw/width{T}
def V = [step]T
p:= V**init
xv:= *V ~~ x
@ -45,64 +34,85 @@ def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
}
# Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈
fn avx2_scan_idem{T, op}(x:*T, r:*T, len:u64, init:T) : void = {
# Within each lane, scan using shifts by powers of 2. First k elements
# when shifting by k don't need to change, so leave them alone.
def w = width{T}
def shift{k,l} = merge{iota{k},iota{l-k}}
def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}}
def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a)
# Prefix op on entire AVX register
def pre{a} = {
b:= c8{2, c8{1, c32{2, c32{1, a}}}}
# After lanewise scan, broadcast end of lane 0 to entire lane 1
op{b, sel{[8]i32, spread{b}, make{[8]i32, 3*(3<iota{8})}}}
}
def scan_idem = scan_scal
fn scan_idem{T, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = {
scan_post{T, init, x, r, len, op, make_scan_idem{T, op}}
}
scan_post{T, init, x, r, len, op, pre}
}
fn avx2_scan_idem{T==f64, op}(x:*T, r:*T, len:u64, init:T) : void = {
def sh{s, a} = op{a, shuf{[4]u64, a, s}}
scan_post{T, init, x, r, len, op, {a}=>sh{4b1110,sh{4b2200,a}}}
}
export{'avx2_scan_min_init_i8', avx2_scan_idem{i8 , min}}; export{'avx2_scan_max_init_i8', avx2_scan_idem{i8 , max}}
export{'avx2_scan_min_init_i16', avx2_scan_idem{i16, min}}; export{'avx2_scan_max_init_i16', avx2_scan_idem{i16, max}}
export{'avx2_scan_min_init_i32', avx2_scan_idem{i32, min}}; export{'avx2_scan_max_init_i32', avx2_scan_idem{i32, max}}
export{'avx2_scan_min_init_f64', avx2_scan_idem{f64, min}}; export{'avx2_scan_max_init_f64', avx2_scan_idem{f64, max}}
export{'si_scan_min_init_i8', scan_idem{i8 , min}}; export{'si_scan_max_init_i8', scan_idem{i8 , max}}
export{'si_scan_min_init_i16', scan_idem{i16, min}}; export{'si_scan_max_init_i16', scan_idem{i16, max}}
export{'si_scan_min_init_i32', scan_idem{i32, min}}; export{'si_scan_max_init_i32', scan_idem{i32, max}}
export{'si_scan_min_init_f64', scan_idem{f64, min}}; export{'si_scan_max_init_f64', scan_idem{f64, max}}
fn avx2_scan_idem_id{T, op}(x:*T, r:*T, len:u64) : void = {
def m = 1 << (width{T}-1)
def id = (if (same{op,min}) m-1; else -m)
avx2_scan_idem{T, op}(x, r, len, id)
fn scan_idem_id{T, op}(x:*T, r:*T, len:u64) : void = {
scan_idem{T, op}(x, r, len, get_id{op, T})
}
export{'avx2_scan_min_i8', avx2_scan_idem_id{i8 , min}}; export{'avx2_scan_max_i8', avx2_scan_idem_id{i8 , max}}
export{'avx2_scan_min_i16', avx2_scan_idem_id{i16, min}}; export{'avx2_scan_max_i16', avx2_scan_idem_id{i16, max}}
export{'avx2_scan_min_i32', avx2_scan_idem_id{i32, min}}; export{'avx2_scan_max_i32', avx2_scan_idem_id{i32, max}}
export{'si_scan_min_i8', scan_idem_id{i8 , min}}; export{'si_scan_max_i8', scan_idem_id{i8 , max}}
export{'si_scan_min_i16', scan_idem_id{i16, min}}; export{'si_scan_max_i16', scan_idem_id{i16, max}}
export{'si_scan_min_i32', scan_idem_id{i32, min}}; export{'si_scan_max_i32', scan_idem_id{i32, max}}
# Assumes identity is 0
def scan_assoc{op, a:T} = {
# Within each lane, scan using shifts by powers of 2
def w = elwidth{T}
def c32{k, a} = (if (w<=8*k) op{a, shl{[16]u8, a, k}}; else a)
b:= c32{8, c32{4, c32{2, c32{1, a}}}}
# After lanewise scan, broadcast end of lane 0 to entire lane 1
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
def scan_assoc{op} = {
def shl0{v, k} = shl{[16]u8, v, k/8} # Lanewise
def shl0{v:V, k==128 & hasarch{'AVX2'}} = {
# Broadcast end of lane 0 to entire lane 1
l:= V~~make{[8]i32,0,0,0,-1,0,0,0,0} & spread{v}
sel{[8]i32, l, make{[8]i32, 3*(3<iota{8})}}
}
prefix_byshift{op, shl0}
}
def scan_plus = scan_assoc{+, .}
def scan_plus = scan_assoc{+}
# Associative scan
fn avx2_scan_assoc_0{T, op}(x:*T, r:*T, len:u64, init:T) : void = {
def scan_assoc_0 = scan_scal
fn scan_assoc_0{T, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = {
# Prefix op on entire AVX register
scan_post{T, init, x, r, len, op, scan_plus}
}
export{'avx2_scan_pluswrap_u8', avx2_scan_assoc_0{u8 , +}}
export{'avx2_scan_pluswrap_u16', avx2_scan_assoc_0{u16, +}}
export{'avx2_scan_pluswrap_u32', avx2_scan_assoc_0{u32, +}}
export{'si_scan_pluswrap_u8', scan_assoc_0{u8 , +}}
export{'si_scan_pluswrap_u16', scan_assoc_0{u16, +}}
export{'si_scan_pluswrap_u32', scan_assoc_0{u32, +}}
# xor scan
fn scan_neq{}(p:u64, x:*u64, r:*u64, nw:u64) : void = {
@for (x, r over nw) {
r = p ^ prefix_byshift{^, <<}{x}
p = -(r>>63) # repeat sign bit
}
}
fn clmul_scan_ne_any{..._ & hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = {
def V = [2]u64
m := V**mark
def xor64{a, i, carry} = { # carry is 64-bit broadcasted current total
p := clmul{a, m, i}
t := shr{[16]u8, p, 8}
s := p ^ carry
carry = s ^ t
s
}
xv := *V ~~ x
rv := *V ~~ r
e := words/2
c := V**init
@for (rv, xv over e) {
rv = apply{zipLo, (@collect (j to 2) xor64{xv, j, c})}
}
if (words & 1) {
storeLow{rv+e, 64, clmul{loadLow{xv+e, 64}, m, 0} ^ c}
}
}
fn scan_neq{..._ & hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = {
clmul_scan_ne_any{}(*void~~x, *void~~r, init, nw, -(u64~~1))
}
export{'si_scan_ne', scan_neq{}}
# Boolean cumulative sum
fn avx2_bcs{T}(x:*u64, r:*T, l:u64) : void = {
fn bcs{T}(x:*u64, r:*T, l:u64) : void = {
def bitp_get{arr, n} = (load{arr, n>>6} >> (n&63)) & 1
c:T = 0
@for (r over i to l) { c+= cast_i{T, bitp_get{x,i}}; r = c }
}
fn bcs{T & hasarch{'AVX2'}}(x:*u64, r:*T, l:u64) : void = {
def U = ty_u{T}
def w = width{T}
def vl= 256 / w
@ -157,9 +167,9 @@ fn avx2_bcs{T}(x:*u64, r:*T, l:u64) : void = {
step{load{xv, e}, e, st}
}
}
export{'avx2_bcs8', avx2_bcs{i8}}
export{'avx2_bcs16', avx2_bcs{i16}}
export{'avx2_bcs32', avx2_bcs{i32}}
export{'si_bcs8', bcs{i8}}
export{'si_bcs16', bcs{i16}}
export{'si_bcs32', bcs{i32}}
@ -190,7 +200,23 @@ def maxabsval{T & issigned{T}} = -minvalue{T}
def maxsafeint{T & issigned{T}} = maxvalue{T}
def maxsafeint{T==f64} = 1<<53
def simd_plus_scan{X, b, R}{x:*X, c:(R), r:*R, len:u64} = {
fn plus_scan{X, R, O}(x:*X, c:R, r:*R, len:u64) : O = {
i:u64 = 0
if (hasarch{'AVX2'}) simd_plus_scan_part{X,R}{x, c, r, len, i}
@forUnroll{1,1} (js from i to len) {
def vs = eachx{load, x, js}
each{{j, v} => {
def {b,n} = addChk{c, promote{R, v}}
if (rare{b}) return{j}
store{r, j, n}
c = n
}, js, vs}
}
len
}
# Sum as many vector registers as possible; modifies c and i
def simd_plus_scan_part{X, R}{x:*X, c:(R), r:*R, len:u64, i:u64} = {
def b = max{width{R}/2, width{X}}
def bulk = arch_defvw/b
def wd = (X!=R) & (width{X}<32) # whether to widen the working copy one size
@ -203,7 +229,6 @@ def simd_plus_scan{X, b, R}{x:*X, c:(R), r:*R, len:u64} = {
if (R!=f64) { def m = maxFastA + maxFastE*bulk; assert{m<=maxvalue{R}}; assert{-m>=minvalue{R}} }
i:u64 = 0
cv:= [arch_defvw/width{R}]R ** c
if (R==f64 and c != floor{c}) goto{'end'}
@ -237,24 +262,13 @@ def simd_plus_scan{X, b, R}{x:*X, c:(R), r:*R, len:u64} = {
setlabel{'end'}
c = extract{cv, 0}
@forUnroll{1,1} (js from i to len) {
def vs = eachx{load, x, js}
each{{j, v} => {
def {b,n} = addChk{c, promote{R, v}}
if (rare{b}) return{j}
store{r, j, n}
c = n
}, js, vs}
}
len
}
fn simd_plus_scanG{X, b, R}(x:*X, c:R, r:*R, len:u64) : void = simd_plus_scan{X,b,R}{x, c, r, len}
fn simd_plus_scanC{X, b, R}(x:*X, c:R, r:*R, len:u64) : u64 = simd_plus_scan{X,b,R}{x, c, r, len}
def plus_scanG{X, R} = plus_scan{X, R, void}
def plus_scanC{X, R} = plus_scan{X, R, u64}
export{'simd_scan_plus_i8_i32', simd_plus_scanC{i8, 16, i32}}
export{'simd_scan_plus_i16_i32', simd_plus_scanC{i16, 16, i32}}
export{'simd_scan_plus_i32_i32', simd_plus_scanC{i32, 32, i32}}
export{'si_scan_plus_i8_i32', plus_scanC{i8, i32}}
export{'si_scan_plus_i16_i32', plus_scanC{i16, i32}}
export{'si_scan_plus_i32_i32', plus_scanC{i32, i32}}
export{'simd_scan_plus_i16_f64', simd_plus_scanG{i16, 32, f64}}
export{'simd_scan_plus_i32_f64', simd_plus_scanG{i32, 32, f64}}
export{'si_scan_plus_i16_f64', plus_scanG{i16, f64}}
export{'si_scan_plus_i32_f64', plus_scanG{i32, f64}}

View File

@ -0,0 +1,69 @@
# Used by scan.singeli and bins.singeli
def sel8{v:V, t} = sel{[16]u8, v, make{re_el{i8,V}, t}}
def sel8{v:V, t & w256{V} & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}}
def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,n}}
local def rev{t} = { def l=tuplen{t}; def j=l-1; tupsel{j-range{l}, j-t} }
local def rev{up,t} = if (up) t else rev{t}
def sel8{v, t, up} = sel8{v, rev{up,t}}
def zip{up, x} = (if (up) zipHi else zipLo){x,x}
# Fill last 4 bytes with last element, in each lane
def spread{a:VT, ...up} = {
def w = elwidth{VT}
def b = w/8
if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}, ...up}; else a
}
# Set all elements with the last element of the input
def toLast{n:VT, up & hasarch{'X86_64'} & w128{VT}} = {
def l{v, w} = l{zip{up,v}, 2*w}
def l{v, w & hasarch{'SSSE3'}} = sel8{v, up*(16-w/8)+iota{16}%(w/8)}
def l{v, w & w>=32} = shuf{[4]i32, v, 4**(up*3)}
l{n, elwidth{VT}}
}
def toLast{n:VT, up & hasarch{'AVX2'} & w256{VT}} = {
if (elwidth{VT}<=32) sel{[8]i32, spread{n,up}, [8]i32**(up*7)}
else shuf{[4]u64, n, 4**(up*3)}
}
def toLast{n:VT} = toLast{n, 1}
# Make prefix scan from op and shifter by applying the operation
# at increasing power-of-two shifts
def prefix_byshift{op, sh} = {
def pre{v:V, k} = if (k < width{V}) pre{op{v, sh{v,k}}, 2*k} else v
{v:T} => pre{v, if (isvec{T}) elwidth{T} else 1}
}
def get_id{op,T} = (match (op) { {_==min}=>maxvalue; {_==max}=>minvalue }){T}
def make_scan_idem{T, op, up} = {
# Within each lane, scan using shifts by powers of 2. First k elements
# when shifting by k don't need to change, so leave them alone.
def shift{k,l} = rev{up, merge{iota{k},iota{l-k}}}
def shb{v:V, k} = {
def w=width{T}; def c = k/w
def merger{a,b} = if (up) merge{a,b} else merge{b,a}
def id = make{V, merger{c**get_id{op,T}, (width{V}/w-c)**0}}
(if (up) shl else shr){[16]u8, v, k/8} | id
}
def shb{v, k & hasarch{'SSSE3'}} = sel8{v, shift{k/8,16}}
def shb{v, k & k>=32} = shuf{[4]u32, v, shift{k/32,4}}
def shb{v, k & k==128 & hasarch{'AVX2'}} = {
# After lanewise scan, broadcast end of lane 0 to entire lane 1
sel{[8]i32, spread{v,up}, make{[8]i32, rev{up,3*(3<iota{8})}}}
}
prefix_byshift{op, shb}
}
def make_scan_idem{T==f64, op, up} = {
def sc{a} = op{a, zip{up,a}}
def sc{a & hasarch{'AVX2'}} = {
def sh{s, a} = op{a, shuf{[4]u64, a, rev{up,s}}}
sh{tup{0,1,1,1},sh{tup{0,0,2,2},a}}
}
sc
}
def make_scan_idem{T, op} = make_scan_idem{T, op, 1}

View File

@ -11,18 +11,14 @@ def extract{x:T, i & w128i{T, 8} & knum{i}} = emit{eltype{T}, '_mm_extract_epi8'
def extract{x:T, i & w128i{T,32} & knum{i}} = emit{eltype{T}, '_mm_extract_epi32', x, i}
def extract{x:T, i & w128i{T,64} & knum{i}} = emit{eltype{T}, '_mm_extract_epi64', x, i}
def andAllZero{x:T, y:T & w128i{T}} = emit{u1, '_mm_testz_si128', x, y}
def min{a:T,b:T & T==[ 8]u16} = emit{T, '_mm_min_epu16', a, b}; def max{a:T,b:T & T==[ 8]u16} = emit{T, '_mm_max_epu16', a, b}
def min{a:T,b:T & T==[16]i8 } = emit{T, '_mm_min_epi8', a, b}; def max{a:T,b:T & T==[16]i8 } = emit{T, '_mm_max_epi8', a, b}
def min{a:T,b:T & T==[ 4]u32} = emit{T, '_mm_min_epu32', a, b}; def max{a:T,b:T & T==[ 4]u32} = emit{T, '_mm_max_epu32', a, b}
def min{a:T,b:T & T==[ 4]i32} = emit{T, '_mm_min_epi32', a, b}; def max{a:T,b:T & T==[ 4]i32} = emit{T, '_mm_max_epi32', a, b}
def __le{a:T,b:T & w128u{T}} = a==min{a,b}
def __ge{a:T,b:T & w128u{T}} = a==max{a,b}
# arith
def min{a:T,b:T & T==[16]i8 } = emit{T, '_mm_min_epi8', a, b}; def max{a:T,b:T & T==[16]i8 } = emit{T, '_mm_max_epi8', a, b}
def min{a:T,b:T & T==[ 4]i32} = emit{T, '_mm_min_epi32', a, b}; def max{a:T,b:T & T==[ 4]i32} = emit{T, '_mm_max_epi32', a, b}
def min{a:T,b:T & T==[ 8]u16} = emit{T, '_mm_min_epu16', a, b}; def max{a:T,b:T & T==[ 8]u16} = emit{T, '_mm_max_epu16', a, b}
def min{a:T,b:T & T==[ 4]u32} = emit{T, '_mm_min_epu32', a, b}; def max{a:T,b:T & T==[ 4]u32} = emit{T, '_mm_max_epu32', a, b}
def min{a:T,b:T & T==[16]i8 & hasarch{'SSE4.1'}} = emit{T, '_mm_min_epi8', a, b}; def max{a:T,b:T & T==[16]i8 & hasarch{'SSE4.1'}} = emit{T, '_mm_max_epi8', a, b}
def min{a:T,b:T & T==[ 4]i32 & hasarch{'SSE4.1'}} = emit{T, '_mm_min_epi32', a, b}; def max{a:T,b:T & T==[ 4]i32 & hasarch{'SSE4.1'}} = emit{T, '_mm_max_epi32', a, b}
def min{a:T,b:T & T==[ 8]u16 & hasarch{'SSE4.1'}} = emit{T, '_mm_min_epu16', a, b}; def max{a:T,b:T & T==[ 8]u16 & hasarch{'SSE4.1'}} = emit{T, '_mm_max_epu16', a, b}
def min{a:T,b:T & T==[ 4]u32 & hasarch{'SSE4.1'}} = emit{T, '_mm_min_epu32', a, b}; def max{a:T,b:T & T==[ 4]u32 & hasarch{'SSE4.1'}} = emit{T, '_mm_max_epu32', a, b}
def __le{a:T,b:T & w128u{T}} = a==min{a,b}
def __ge{a:T,b:T & w128u{T}} = a==max{a,b}
def __eq{a:T,b:T & w128i{T,64}} = emit{[2]u64, '_mm_cmpeq_epi64', a, b}

View File

@ -189,6 +189,8 @@ def packQ{a:T,b:T & w128i{T}} = packs{a,b}
def zipLo{a:T, b:T & w128i{T}} = emit{T, merge{'_mm_unpacklo_epi',fmtnat{elwidth{T}}}, a, b}
def zipHi{a:T, b:T & w128i{T}} = emit{T, merge{'_mm_unpackhi_epi',fmtnat{elwidth{T}}}, a, b}
def zipLo{a:T, b:T & w128f{T}} = emit{T, merge{'_mm_unpacklo_p',if (elwidth{T}==32) 's' else 'd'}, a, b}
def zipHi{a:T, b:T & w128f{T}} = emit{T, merge{'_mm_unpackhi_p',if (elwidth{T}==32) 's' else 'd'}, a, b}
def zip{a:T, b:T & w128i{T}} = tup{zipLo{a,b}, zipHi{a,b}}