Macro-ize integer min/max scan and add sorted flag

This commit is contained in:
Marshall Lochbaum 2022-11-16 20:49:57 -05:00
parent ac7ff155d4
commit 136c1afacc
4 changed files with 51 additions and 43 deletions

View File

@ -31,22 +31,22 @@
}
#if SINGELI
extern void (*const avx2_scan_max8)(int8_t* v0,int8_t* v1,uint64_t v2);
extern void (*const avx2_scan_min8)(int8_t* v0,int8_t* v1,uint64_t v2);
extern void (*const avx2_scan_max16)(int16_t* v0,int16_t* v1,uint64_t v2);
extern void (*const avx2_scan_min16)(int16_t* v0,int16_t* v1,uint64_t v2);
extern void (*const avx2_scan_max_i8)(int8_t* v0,int8_t* v1,uint64_t v2);
extern void (*const avx2_scan_min_i8)(int8_t* v0,int8_t* v1,uint64_t v2);
extern void (*const avx2_scan_max_i16)(int16_t* v0,int16_t* v1,uint64_t v2);
extern void (*const avx2_scan_min_i16)(int16_t* v0,int16_t* v1,uint64_t v2);
#define COUNT_THRESHOLD 32
#define WRITE_SPARSE_i8 \
for (usz i=0; i<n; i++) rp[i]=j; \
while (ij<n) { rp[ij]=GRADE_UD(++j,--j); ij+=c0o[j]; } \
GRADE_UD(avx2_scan_max8,avx2_scan_min8)(rp,rp,n);
GRADE_UD(avx2_scan_max_i8,avx2_scan_min_i8)(rp,rp,n);
#define WRITE_SPARSE_i16 \
usz b = 1<<10; \
for (usz k=0; ; ) { \
usz e = b<n-k? k+b : n; \
for (usz i=k; i<e; i++) rp[i]=j; \
while (ij<e) { rp[ij]=GRADE_UD(++j,--j); ij+=c0o[j]; } \
GRADE_UD(avx2_scan_max16,avx2_scan_min16)(rp+k,rp+k,e-k); \
GRADE_UD(avx2_scan_max_i16,avx2_scan_min_i16)(rp+k,rp+k,e-k); \
if (e==n) {break;} k=e; \
}
#define WRITE_SPARSE(T) WRITE_SPARSE_##T

View File

@ -19,7 +19,8 @@ static u64 vg_rand(u64 x) { return x; }
#include "../singeli/gen/neq.c"
#pragma GCC diagnostic pop
#endif
B scan_ne(u64 p, u64* xp, u64 ia) {
B scan_ne(B x, u64 p, u64 ia) { // consumes x
u64* xp = bitarr_ptr(x);
u64* rp; B r=m_bitarrv(&rp,ia);
#if SINGELI && __PCLMUL__
clmul_scan_ne(p, xp, rp, BIT_N(ia));
@ -33,16 +34,23 @@ B scan_ne(u64 p, u64* xp, u64 ia) {
p = -(r>>63); // repeat sign bit
}
#endif
return r;
decG(x); return r;
}
static B scan_or(B x, u64 ia) {
static B scan_or(B x, u64 ia) { // consumes x
u64* xp = bitarr_ptr(x);
u64* rp; B r=m_bitarrv(&rp,ia);
usz n=BIT_N(ia); u64 xi; usz i=0;
while (i<n) if ((xi=vg_rand(xp[i]))!=0) { rp[i] = -(xi&-xi); i++; while(i<n) rp[i++] = ~0LL; break; } else rp[i++]=0;
while (i<n) if ((xi= vg_rand(xp[i]))!=0) { rp[i] = -(xi&-xi) ; i++; while(i<n) rp[i++] = ~0LL; break; } else rp[i++]=0;
decG(x); return FL_SET(r, fl_asc|fl_squoze);
}
static B scan_and(B x, u64 ia) { // consumes x
u64* xp = bitarr_ptr(x);
u64* rp; B r=m_bitarrv(&rp,ia);
usz n=BIT_N(ia); u64 xi; usz i=0;
while (i<n) if ((xi=~vg_rand(xp[i]))!=0) { rp[i] = (xi&-xi)-1; i++; while(i<n) rp[i++] = 0 ; break; } else rp[i++]=~0LL;
decG(x); return FL_SET(r, fl_dsc|fl_squoze);
}
B slash_c1(B f, B x);
B scan_add_bool(B x, u64 ia) { // consumes x
@ -84,6 +92,26 @@ B scan_add_bool(B x, u64 ia) { // consumes x
return FL_SET(r, fl_asc|fl_squoze);
}
#if SINGELI
#define MINMAX_SCAN(T,NAME,C,I) avx2_scan_##NAME##_##T(xp, rp, ia);
#else
#define MINMAX_SCAN(T,NAME,C,I) T c=I; for (usz i=0; i<ia; i++) { if (xp[i] C c)c=xp[i]; rp[i]=c; }
#endif
#define MM_CASE(T,N,C,I) \
case el_##T : { T* xp=T##any_ptr(x); T* rp; r=m_##T##arrv(&rp, ia); MINMAX_SCAN(T,N,C,I); break; }
#define MINMAX(NAME,C,INIT,ORD) \
B r; switch (xe) { default:UD; \
MM_CASE(i8 ,NAME,C,I8_##INIT ) \
MM_CASE(i16,NAME,C,I16_##INIT) \
MM_CASE(i32,NAME,C,I32_##INIT) \
} \
decG(x); return FL_SET(r, fl_##ORD);
static B scan_min_int(B x, u8 xe, usz ia) { MINMAX(min,<,MAX,dsc) }
static B scan_max_int(B x, u8 xe, usz ia) { MINMAX(max,>,MIN,asc) }
#undef MM_CASE
#undef MINMAX
#undef MINMAX_SCAN
B scan_c1(Md1D* d, B x) { B f = d->f;
if (isAtm(x) || RNK(x)==0) thrM("`: Argument cannot have rank 0");
ur xr = RNK(x);
@ -97,8 +125,8 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
u64* xp=bitarr_ptr(x);
if (rtid==n_add ) return scan_add_bool(x, ia);
if (rtid==n_or | rtid==n_ceil ) return scan_or(x, ia);
if (rtid==n_and | rtid==n_mul | rtid==n_floor) { u64* rp; B r=m_bitarrv(&rp,ia); usz n=BIT_N(ia); u64 xi; usz i=0; while(i<n) if ((xi=~vg_rand(xp[i]))!=0) { rp[i] = (xi&-xi)-1; i++; while(i<n) rp[i++] = 0 ; break; } else rp[i++]=~0LL; decG(x); return r; }
if (rtid==n_ne) { B r=scan_ne(0, xp, ia); decG(x); return r; }
if (rtid==n_and | rtid==n_mul | rtid==n_floor) return scan_and(x, ia);
if (rtid==n_ne ) return scan_ne(x, 0, ia);
if (rtid==n_lt) {
u64* rp; B r=m_bitarrv(&rp,ia); usz n=BIT_N(ia);
u64 m10 = 0x5555555555555555;
@ -117,28 +145,8 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
if (xe==el_i16) { i16* xp=i16any_ptr(x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=0; for (usz i=0; i<ia; i++) { if(addOn(c,xp[i]))goto base; rp[i]=c; } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=0; for (usz i=0; i<ia; i++) { if(addOn(c,xp[i]))goto base; rp[i]=c; } decG(x); return r; }
}
if (rtid==n_floor) { // ⌊
#if SINGELI
if (xe==el_i8 ) { i8* rp; B r=m_i8arrv (&rp, ia); avx2_scan_min8 (i8any_ptr (x), rp, ia); decG(x); return r; }
if (xe==el_i16) { i16* rp; B r=m_i16arrv(&rp, ia); avx2_scan_min16(i16any_ptr(x), rp, ia); decG(x); return r; }
if (xe==el_i32) { i32* rp; B r=m_i32arrv(&rp, ia); avx2_scan_min32(i32any_ptr(x), rp, ia); decG(x); return r; }
#else
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); i8* rp; B r=m_i8arrv (&rp, ia); i8 c=I8_MAX ; for (usz i=0; i<ia; i++) { if (xp[i]<c)c=xp[i]; rp[i]=c; } decG(x); return r; }
if (xe==el_i16) { i16* xp=i16any_ptr(x); i16* rp; B r=m_i16arrv(&rp, ia); i16 c=I16_MAX; for (usz i=0; i<ia; i++) { if (xp[i]<c)c=xp[i]; rp[i]=c; } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=I32_MAX; for (usz i=0; i<ia; i++) { if (xp[i]<c)c=xp[i]; rp[i]=c; } decG(x); return r; }
#endif
}
if (rtid==n_ceil) { // ⌈
#if SINGELI
if (xe==el_i8 ) { i8* rp; B r=m_i8arrv (&rp, ia); avx2_scan_max8 (i8any_ptr (x), rp, ia); decG(x); return r; }
if (xe==el_i16) { i16* rp; B r=m_i16arrv(&rp, ia); avx2_scan_max16(i16any_ptr(x), rp, ia); decG(x); return r; }
if (xe==el_i32) { i32* rp; B r=m_i32arrv(&rp, ia); avx2_scan_max32(i32any_ptr(x), rp, ia); decG(x); return r; }
#else
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); i8* rp; B r=m_i8arrv (&rp, ia); i8 c=I8_MIN ; for (usz i=0; i<ia; i++) { if (xp[i]>c)c=xp[i]; rp[i]=c; } decG(x); return r; }
if (xe==el_i16) { i16* xp=i16any_ptr(x); i16* rp; B r=m_i16arrv(&rp, ia); i16 c=I16_MIN; for (usz i=0; i<ia; i++) { if (xp[i]>c)c=xp[i]; rp[i]=c; } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=I32_MIN; for (usz i=0; i<ia; i++) { if (xp[i]>c)c=xp[i]; rp[i]=c; } decG(x); return r; }
#endif
}
if (rtid==n_floor && xe<el_f64) return scan_min_int(x, xe, ia); // ⌊
if (rtid==n_ceil && xe<el_f64) return scan_max_int(x, xe, ia); // ⌈
if (rtid==n_ne) { // ≠
f64 x0 = IGetU(x,0).f; if (x0!=0 && x0!=1) goto base;
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); u64* rp; B r=m_bitarrv(&rp,ia); bool c=x0; rp[0]=c; for (usz i=1; i<ia; i++) { c = c!=xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
@ -183,7 +191,7 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f;
if (xe==el_bit) {
u64* xp=bitarr_ptr(x);
if (rtid==n_add) { i32* rp; B r=m_i32arrv(&rp, ia); i64 c=wv; for (usz i=0; i<ia; i++) { c+= bitp_get(xp,i); rp[i]=c; } decG(x); return r; }
if (rtid==n_ne) { B r=scan_ne(-(u64)(q_ibit(wv)?wv:1&~*xp), xp, ia); decG(x); return r; }
if (rtid==n_ne) return scan_ne(x, -(u64)(q_ibit(wv)?wv:1&~*xp), ia);
goto base;
}
if (rtid==n_add) { // +

View File

@ -168,7 +168,7 @@
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
#define avx2_scan_pluswrap_u64(V0,V1,V2,V3) for (usz i=k; i<e; i++) js=rp[i]+=js;
#define PLUS_SCAN(T) avx2_scan_pluswrap_##T(rp+k,rp+k,e-k,js); js=rp[e-1];
extern void (*const avx2_scan_max32)(int32_t* v0,int32_t* v1,uint64_t v2);
extern void (*const avx2_scan_max_i32)(int32_t* v0,int32_t* v1,uint64_t v2);
#else
#define PLUS_SCAN(T) for (usz i=k; i<e; i++) js=rp[i]+=js;
#endif
@ -607,7 +607,7 @@ static B compress(B w, B x, usz wia, u8 xl, u8 xt) {
#if SINGELI
#define IND_BY_SCAN \
SCAN_CORE(xp[j], rp[ij]=j, rp[k]=j, avx2_scan_max32(rp+k,rp+k,e-k))
SCAN_CORE(xp[j], rp[ij]=j, rp[k]=j, avx2_scan_max_i32(rp+k,rp+k,e-k))
#else
#define IND_BY_SCAN usz js=0; SUM_CORE(i32, xp[j], , 1)
#endif

View File

@ -59,12 +59,12 @@ def avx2_scan_idem{T, op} = {
def m = 1 << (width{T}-1)
avx2_scan_idem{T, op, (if (match{op,min}) m-1; else -m)}
}
'avx2_scan_min8' = avx2_scan_idem{i8 , min}
'avx2_scan_max8' = avx2_scan_idem{i8 , max}
'avx2_scan_min16' = avx2_scan_idem{i16, min}
'avx2_scan_max16' = avx2_scan_idem{i16, max}
'avx2_scan_min32' = avx2_scan_idem{i32, min}
'avx2_scan_max32' = avx2_scan_idem{i32, max}
'avx2_scan_min_i8' = avx2_scan_idem{i8 , min}
'avx2_scan_max_i8' = avx2_scan_idem{i8 , max}
'avx2_scan_min_i16' = avx2_scan_idem{i16, min}
'avx2_scan_max_i16' = avx2_scan_idem{i16, max}
'avx2_scan_min_i32' = avx2_scan_idem{i32, min}
'avx2_scan_max_i32' = avx2_scan_idem{i32, max}
# Associative scan
avx2_scan_assoc_0{T, op}(x:*T, r:*T, len:u64, init:T) : void = {