Initialized min/max-scan

This commit is contained in:
Marshall Lochbaum 2022-11-17 07:46:26 -05:00
parent 98d066c343
commit 5985e92834
2 changed files with 44 additions and 27 deletions

View File

@ -95,8 +95,9 @@ B scan_add_bool(B x, u64 ia) { // consumes x
return FL_SET(r, fl_asc|fl_squoze);
}
// min/max-scan
#if SINGELI
#define MINMAX_SCAN(T,NAME,C,I) avx2_scan_##NAME##_##T(xp, rp, ia);
#define MINMAX_SCAN(T,NAME,C,I) avx2_scan_##NAME##_init_##T(xp, rp, ia, I);
#else
#define MINMAX_SCAN(T,NAME,C,I) T c=I; for (usz i=0; i<ia; i++) { if (xp[i] C c)c=xp[i]; rp[i]=c; }
#endif
@ -113,8 +114,30 @@ B scan_add_bool(B x, u64 ia) { // consumes x
decG(x); return FL_SET(r, fl_##ORD);
B scan_min_num(B x, u8 xe, usz ia) { MINMAX(min,<,MAX,and,dsc) }
B scan_max_num(B x, u8 xe, usz ia) { MINMAX(max,>,MIN,or ,asc) }
#undef MM_CASE
#undef MINMAX
// Initialized: try to convert 𝕨 to type of 𝕩
// (could do better for out-of-range floats)
B shape_c2(B, B, B);
#define MM2_ICASE(T,N,C,I) \
case el_##T : { \
if (wv!=(T)wv) { if (wv C 0) { r=C2(shape,m_f64(ia),w); break; } else wv=I; } \
T* xp=T##any_ptr(x); T* rp; r=m_##T##arrv(&rp, ia); MINMAX_SCAN(T,N,C,wv); \
break; }
#define MINMAX2(NAME,C,INIT,BIT,BI,ORD) \
i32 wv=0; if (q_i32(w)) { wv=w.f; } else { x=taga(cpyF64Arr(x)); xe=el_f64; } \
B r; switch (xe) { default:UD; \
case el_bit: if (wv C BI) r=C2(shape,m_f64(ia),w); else return scan_##BIT(x, ia); break; \
MM2_ICASE(i8 ,NAME,C,I8_##INIT ) \
MM2_ICASE(i16,NAME,C,I16_##INIT) \
MM_CASE(i32,NAME,C,wv) \
MM_CASE(f64,NAME,C,w.f) \
} \
decG(x); return FL_SET(r, fl_##ORD);
static B scan2_min_num(B w, B x, u8 xe, usz ia) { MINMAX2(min,<,MAX,and,1,dsc) }
static B scan2_max_num(B w, B x, u8 xe, usz ia) { MINMAX2(max,>,MIN,or ,0,asc) }
#undef MINMAX2
#undef MM_CASE
#undef MM_CASE2
#undef MINMAX_SCAN
B scan_c1(Md1D* d, B x) { B f = d->f;
@ -190,8 +213,11 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f;
ur xr = RNK(x); usz* xsh = SH(x); usz ia = IA(x);
B wf = getFillQ(w);
u8 xe = TI(x,elType);
if (xr==1 && q_i32(w) && elInt(xe) && isFun(f) && v(f)->flags) {
if (xr==1 && xe<=el_f64 && isFun(f) && v(f)->flags) {
u8 rtid = v(f)->flags-1;
if (rtid==n_floor) return scan2_min_num(w, x, xe, ia); // ⌊
if (rtid==n_ceil ) return scan2_max_num(w, x, xe, ia); // ⌈
if (!q_i32(w)) goto base;
i32 wv = o2iG(w);
if (xe==el_bit) {
u64* xp=bitarr_ptr(x);
@ -204,16 +230,6 @@ B scan_c2(Md1D* d, B w, B x) { B f = d->f;
if (xe==el_i16) { i16* xp=i16any_ptr(x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=wv; for (usz i=0; i<ia; i++) { if(addOn(c,xp[i]))goto base; rp[i]=c; } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=wv; for (usz i=0; i<ia; i++) { if(addOn(c,xp[i]))goto base; rp[i]=c; } decG(x); return r; }
}
if (rtid==n_floor) { // ⌊
if (xe==el_i8 && wv==(i8 )wv) { i8* xp=i8any_ptr (x); i8* rp; B r=m_i8arrv (&rp, ia); i8 c=wv; for (usz i=0; i<ia; i++) { if (xp[i]<c)c=xp[i]; rp[i]=c; } decG(x); return r; }
if (xe==el_i16 && wv==(i16)wv) { i16* xp=i16any_ptr(x); i16* rp; B r=m_i16arrv(&rp, ia); i16 c=wv; for (usz i=0; i<ia; i++) { if (xp[i]<c)c=xp[i]; rp[i]=c; } decG(x); return r; }
if (xe==el_i32 && wv==(i32)wv) { i32* xp=i32any_ptr(x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=wv; for (usz i=0; i<ia; i++) { if (xp[i]<c)c=xp[i]; rp[i]=c; } decG(x); return r; }
}
if (rtid==n_ceil) { // ⌈
if (xe==el_i8 && wv==(i8 )wv) { i8* xp=i8any_ptr (x); i8* rp; B r=m_i8arrv (&rp, ia); i8 c=wv; for (usz i=0; i<ia; i++) { if (xp[i]>c)c=xp[i]; rp[i]=c; } decG(x); return r; }
if (xe==el_i16 && wv==(i16)wv) { i16* xp=i16any_ptr(x); i16* rp; B r=m_i16arrv(&rp, ia); i16 c=wv; for (usz i=0; i<ia; i++) { if (xp[i]>c)c=xp[i]; rp[i]=c; } decG(x); return r; }
if (xe==el_i32 && wv==(i32)wv) { i32* xp=i32any_ptr(x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=wv; for (usz i=0; i<ia; i++) { if (xp[i]>c)c=xp[i]; rp[i]=c; } decG(x); return r; }
}
if (rtid==n_ne) { // ≠
if (!q_ibit(wv)) { goto base; } bool c=wv;
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); u64* rp; B r=m_bitarrv(&rp, ia); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }

View File

@ -40,7 +40,7 @@ def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
}
# Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈
avx2_scan_idem{T, op, id}(x:*T, r:*T, len:u64) : void = {
avx2_scan_idem{T, op}(x:*T, r:*T, len:u64, init:T) : void = {
# Within each lane, scan using shifts by powers of 2. First k elements
# when shifting by k don't need to change, so leave them alone.
def w = width{T}
@ -54,24 +54,25 @@ avx2_scan_idem{T, op, id}(x:*T, r:*T, len:u64) : void = {
op{b, sel{[8]i32, spread{b}, make{[8]i32, 3*(3<iota{8})}}}
}
scan_post{T, id, x, r, len, op, pre}
scan_post{T, init, x, r, len, op, pre}
}
avx2_scan_idem{T==f64, op, id}(x:*T, r:*T, len:u64) : void = {
avx2_scan_idem{T==f64, op}(x:*T, r:*T, len:u64, init:T) : void = {
def sh{s, a} = op{a, shuf{[4]u64, a, s}}
scan_post{T, id, x, r, len, op, {a}=>sh{4b1110,sh{4b2200,a}}}
scan_post{T, init, x, r, len, op, {a}=>sh{4b1110,sh{4b2200,a}}}
}
def avx2_scan_idem{T, op} = {
'avx2_scan_min_init_i8' = avx2_scan_idem{i8 , min}; 'avx2_scan_max_init_i8' = avx2_scan_idem{i8 , max}
'avx2_scan_min_init_i16' = avx2_scan_idem{i16, min}; 'avx2_scan_max_init_i16' = avx2_scan_idem{i16, max}
'avx2_scan_min_init_i32' = avx2_scan_idem{i32, min}; 'avx2_scan_max_init_i32' = avx2_scan_idem{i32, max}
'avx2_scan_min_init_f64' = avx2_scan_idem{f64, min}; 'avx2_scan_max_init_f64' = avx2_scan_idem{f64, max}
avx2_scan_idem_id{T, op}(x:*T, r:*T, len:u64) : void = {
def m = 1 << (width{T}-1)
avx2_scan_idem{T, op, (if (match{op,min}) m-1; else -m)}
def id = (if (match{op,min}) m-1; else -m)
avx2_scan_idem{T, op}(x, r, len, id)
}
'avx2_scan_min_i8' = avx2_scan_idem{i8 , min}
'avx2_scan_max_i8' = avx2_scan_idem{i8 , max}
'avx2_scan_min_i16' = avx2_scan_idem{i16, min}
'avx2_scan_max_i16' = avx2_scan_idem{i16, max}
'avx2_scan_min_i32' = avx2_scan_idem{i32, min}
'avx2_scan_max_i32' = avx2_scan_idem{i32, max}
'avx2_scan_min_f64' = avx2_scan_idem{f64, min, 'F64_MAX'}
'avx2_scan_max_f64' = avx2_scan_idem{f64, max, 'F64_MIN'}
'avx2_scan_min_i8' = avx2_scan_idem_id{i8 , min}; 'avx2_scan_max_i8' = avx2_scan_idem_id{i8 , max}
'avx2_scan_min_i16' = avx2_scan_idem_id{i16, min}; 'avx2_scan_max_i16' = avx2_scan_idem_id{i16, max}
'avx2_scan_min_i32' = avx2_scan_idem_id{i32, min}; 'avx2_scan_max_i32' = avx2_scan_idem_id{i32, max}
# Associative scan
avx2_scan_assoc_0{T, op}(x:*T, r:*T, len:u64, init:T) : void = {