a bunch of Scan fixes & improvements

This commit is contained in:
dzaima 2023-04-07 20:41:34 +03:00
parent b0d8bcb428
commit 3a7bce3aab

View File

@ -133,13 +133,25 @@ B shape_c2(B, B, B);
MM_CASE(f64,NAME,C,w.f) \
} \
decG(x); return FL_SET(r, fl_##ORD);
static B scan2_min_num(B w, B x, u8 xe, usz ia) { MINMAX2(min,<,MAX,and,1,dsc) }
static B scan2_max_num(B w, B x, u8 xe, usz ia) { MINMAX2(max,>,MIN,or ,0,asc) }
SHOULD_INLINE B scan2_min_num(B w, B x, u8 xe, usz ia) { MINMAX2(min,<,MAX,and,1,dsc) }
SHOULD_INLINE B scan2_max_num(B w, B x, u8 xe, usz ia) { MINMAX2(max,>,MIN,or ,0,asc) }
#undef MINMAX2
#undef MM_CASE
#undef MM_CASE2
#undef MINMAX_SCAN
static B scan_lt(B x, u64 p, usz ia) {
u64* xp = bitarr_ptr(x);
u64* rp; B r=m_bitarrv(&rp,ia); usz n=BIT_N(ia);
u64 m10 = 0x5555555555555555;
for (usz i=0; i<n; i++) {
u64 x = xp[i];
u64 c = (m10 & ~(x<<1)) & ~(p>>63);
rp[i] = p = x & (m10 ^ (x + c));
}
decG(x); return r;
}
B scan_c1(Md1D* d, B x) { B f = d->f;
if (isAtm(x) || RNK(x)==0) thrM("`: Argument cannot have rank 0");
ur xr = RNK(x);
@ -150,36 +162,33 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
if (xr==1 && xe<=el_f64 && isFun(f) && v(f)->flags) {
u8 rtid = v(f)->flags-1;
if (xe==el_bit) {
u64* xp=bitarr_ptr(x);
if (rtid==n_add ) return scan_add_bool(x, ia);
if (rtid==n_or | rtid==n_ceil ) return scan_or(x, ia);
if (rtid==n_and | rtid==n_mul | rtid==n_floor) return scan_and(x, ia);
if (rtid==n_ne ) return scan_ne(x, 0, ia);
if (rtid==n_lt) {
u64* rp; B r=m_bitarrv(&rp,ia); usz n=BIT_N(ia);
u64 m10 = 0x5555555555555555;
u64 p = 0;
for (usz i=0; i<n; i++) {
u64 x = xp[i];
u64 c = (m10 & ~(x<<1)) & ~(p>>63);
rp[i] = p = x & (m10 ^ (x + c));
}
decG(x); return r;
}
if (rtid==n_add ) return scan_add_bool(x, ia); // +
if (rtid==n_or | rtid==n_ceil ) return scan_or(x, ia); // ∨⌈
if (rtid==n_and | rtid==n_mul | rtid==n_floor) return scan_and(x, ia); // ∧×⌊
if (rtid==n_ne ) return scan_ne(x, 0, ia); // ≠
if (rtid==n_lt) return scan_lt(x, 0, ia); // <
goto base;
}
if (rtid==n_add) { // +
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=0; for (usz i=0; i<ia; i++) { if(addOn(c,xp[i]))goto base; rp[i]=c; } decG(x); return r; }
if (xe==el_i16) { i16* xp=i16any_ptr(x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=0; for (usz i=0; i<ia; i++) { if(addOn(c,xp[i]))goto base; rp[i]=c; } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=0; for (usz i=0; i<ia; i++) { if(addOn(c,xp[i]))goto base; rp[i]=c; } decG(x); return r; }
B r; void* rp = m_tyarrv(&r, xe==el_f64? sizeof(f64) : sizeof(i32), ia, xe==el_f64? t_f64arr : t_i32arr);
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); i32 c=0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } ((i32*)rp)[i]=c; } decG(x); return r; }
if (xe==el_i16) { i16* xp=i16any_ptr(x); i32 c=0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } ((i32*)rp)[i]=c; } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); i32 c=0; for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } ((i32*)rp)[i]=c; } decG(x); return r; }
if (xe==el_f64) { f64* xp=f64any_ptr(x); f64 c=0; for (usz i=0; i<ia; i++) { c+= xp[i]; ((f64*)rp)[i]=c; } decG(x); return r; }
}
if (rtid==n_floor) return scan_min_num(x, xe, ia); // ⌊
if (rtid==n_ceil ) return scan_max_num(x, xe, ia); // ⌈
if (rtid==n_ne) { // ≠
f64 x0 = IGetU(x,0).f; if (x0!=0 && x0!=1) goto base;
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); u64* rp; B r=m_bitarrv(&rp,ia); bool c=x0; rp[0]=c; for (usz i=1; i<ia; i++) { c = c!=xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
if (xe==el_i16) { i16* xp=i16any_ptr(x); u64* rp; B r=m_bitarrv(&rp,ia); bool c=x0; rp[0]=c; for (usz i=1; i<ia; i++) { c = c!=xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); u64* rp; B r=m_bitarrv(&rp,ia); bool c=x0; rp[0]=c; for (usz i=1; i<ia; i++) { c = c!=xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
if (!elInt(xe)) goto base;
f64 x0 = IGetU(x,0).f;
if (!q_fbit(x0)) goto base;
u64* rp; B r = m_bitarrv(&rp, ia);
bool c = x0;
rp[0] = c;
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=1; i<ia; i++) { c = c!=xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
if (xe==el_i16) { i16* xp=i16any_ptr(x); for (usz i=1; i<ia; i++) { c = c!=xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); for (usz i=1; i<ia; i++) { c = c!=xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
UD;
}
if (rtid==n_or) { x=num_squeezeChk(x); xe=TI(x,elType); if (xe==el_bit) return scan_or(x, ia); }
}
@ -204,33 +213,55 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
return withFill(r.b, xf);
}
B add_c2(B, B, B);
B scan_c2(Md1D* d, B w, B x) { B f = d->f;
if (isAtm(x) || RNK(x)==0) thrM("`: 𝕩 cannot have rank 0");
ur xr = RNK(x); usz* xsh = SH(x); usz ia = IA(x);
B wf = getFillQ(w);
u8 xe = TI(x,elType);
if (xr==1 && xe<=el_f64 && isFun(f) && v(f)->flags) {
if (xr==1 && elNum(xe) && isFun(f) && v(f)->flags && isF64(w)) {
u8 rtid = v(f)->flags-1;
if (rtid==n_floor) return scan2_min_num(w, x, xe, ia); // ⌊
if (rtid==n_ceil ) return scan2_max_num(w, x, xe, ia); // ⌈
if (!q_i32(w)) goto base;
i32 wv = o2iG(w);
if (xe==el_bit) {
u64* xp=bitarr_ptr(x);
if (rtid==n_add) { i32* rp; B r=m_i32arrv(&rp, ia); i64 c=wv; for (usz i=0; i<ia; i++) { c+= bitp_get(xp,i); rp[i]=c; } decG(x); return r; }
if (rtid==n_ne) return scan_ne(x, -(u64)(q_ibit(wv)?wv:1&~*xp), ia);
goto base;
}
if (rtid==n_add) { // +
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=wv; for (usz i=0; i<ia; i++) { if(addOn(c,xp[i]))goto base; rp[i]=c; } decG(x); return r; }
if (xe==el_i16) { i16* xp=i16any_ptr(x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=wv; for (usz i=0; i<ia; i++) { if(addOn(c,xp[i]))goto base; rp[i]=c; } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); i32* rp; B r=m_i32arrv(&rp, ia); i32 c=wv; for (usz i=0; i<ia; i++) { if(addOn(c,xp[i]))goto base; rp[i]=c; } decG(x); return r; }
if (xe==el_f64) { f64 c=o2fG(w); f64* rp; B r=m_f64arrv(&rp, ia); f64* xp=f64any_ptr(x); for (usz i=0; i<ia; i++) { c+= xp[i]; rp[i]=c; } decG(x); return r; }
if (xe==el_bit) {
if (!q_i64(w)) goto base;
i64 wv = o2i64(w);
if (wv>=(1ULL<<53) || wv+(i64)ia >= (1ULL<<53)) goto base;
B t = scan_add_bool(x, ia);
return wv==0? t : C2(add, w, t);
}
if (!q_i32(w) || !elInt(xe)) goto base;
i32 c = o2iG(w);
i32* rp; B r = m_i32arrv(&rp, ia);
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } rp[i]=c; } decG(x); return r; }
if (xe==el_i16) { i16* xp=i16any_ptr(x); for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } rp[i]=c; } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); for (usz i=0; i<ia; i++) { if (addOn(c,xp[i])) { decG(r); goto base; } rp[i]=c; } decG(x); return r; }
UD;
}
if (rtid==n_ne) { // ≠
if (!q_ibit(wv)) { goto base; } bool c=wv;
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); u64* rp; B r=m_bitarrv(&rp, ia); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
if (xe==el_i16) { i16* xp=i16any_ptr(x); u64* rp; B r=m_bitarrv(&rp, ia); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); u64* rp; B r=m_bitarrv(&rp, ia); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
bool wBit = q_bit(w);
if (xe==el_bit) return scan_ne(x, -(u64)(wBit? o2bG(w) : 1&~*bitarr_ptr(x)), ia);
if (!wBit || !elInt(xe)) goto base;
bool c = o2bG(w);
u64* rp; B r = m_bitarrv(&rp, ia);
if (xe==el_i8 ) { i8* xp=i8any_ptr (x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
if (xe==el_i16) { i16* xp=i16any_ptr(x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
if (xe==el_i32) { i32* xp=i32any_ptr(x); for (usz i=0; i<ia; i++) { c^= xp[i]; bitp_set(rp,i,c); } decG(x); return r; }
UD;
}
if (xe==el_bit && q_bit(w)) {
// ⌊ & ⌈ handled above
if (rtid==n_or ) { if (!o2bG(w)) return scan_or(x, ia); B r = taga(arr_shVec(allOnes (ia))); decG(x); return r; } //
if (rtid==n_and | rtid==n_mul) { if ( o2bG(w)) return scan_and(x, ia); B r = taga(arr_shVec(allZeroes(ia))); decG(x); return r; } // ∧×
if (rtid==n_lt) return scan_lt(x, bitx(w), ia); // <
}
}
base:;