Move +⌈⌊ folds to function tables, unifying monadic and dyadic cases

This commit is contained in:
Marshall Lochbaum 2022-11-20 21:51:30 -05:00
parent ebfd002793
commit c977065c20

View File

@ -16,6 +16,51 @@ static i64 bit_diff(u64* x, u64 am) {
return r - (i64)(am/2);
}
// It's safe to sum a block of integers as long as the current total
// is far enough from +-1ull<<53 (and integer, in dyadic fold).
static const usz sum_small_max = 1<<16;
#define DEF_INT_SUM(T,W,M,A) \
static i64 sum_small_##T(void* xv, usz ia) { \
i##A s=0; for (usz i=0; i<ia; i++) s+=((T*)xv)[i]; return s; \
} \
static f64 sum_##T(void* xv, usz ia, f64 init) { \
usz b=1<<(M-W); i64 lim = (1ull<<53) - (1ull<<M); \
T* xp = xv; \
f64 r=init; i64 c=init; usz i0=ia; \
if (c == init) { \
while (i0>0 && -lim<=c && c<=lim) { \
usz e=i0; i0=(i0-1)&~(b-1); \
c+=sum_small_##T(xp+i0, e-i0); \
} \
r = c; \
} \
while (i0--) r+=xp[i0]; \
return r; \
}
DEF_INT_SUM(i8 ,8 ,32,32)
DEF_INT_SUM(i16,16,32,32)
DEF_INT_SUM(i32,32,52,64)
#undef DEF_SUM
static f64 sum_f64(void* xv, usz i, f64 r) {
while (i--) r += ((f64*)xv)[i];
return r;
}
static i64 (*const sum_small_fns[])(void*, usz) = { sum_small_i8, sum_small_i16, sum_small_i32 };
static f64 (*const sum_fns[])(void*, usz, f64) = { sum_i8, sum_i16, sum_i32, sum_f64 };
#define MIN_MAX(T,C) \
T* xp = xv; T r = xp[0]; \
for (usz i=1; i<ia; i++) if (xp[i] C r) r=xp[i]; \
return r;
#define DEF_MIN_MAX(T) \
static f64 min_##T(void* xv, usz ia) { MIN_MAX(T,<) } \
static f64 max_##T(void* xv, usz ia) { MIN_MAX(T,>) }
DEF_MIN_MAX(i8) DEF_MIN_MAX(i16) DEF_MIN_MAX(i32) DEF_MIN_MAX(f64)
#undef DEF_MIN_MAX
#undef MIN_MAX
static f64 (*const min_fns[])(void*, usz) = { min_i8, min_i16, min_i32, min_f64 };
static f64 (*const max_fns[])(void*, usz) = { max_i8, max_i16, max_i32, max_f64 };
B fold_c1(Md1D* d, B x) { B f = d->f;
if (isAtm(x) || RNK(x)!=1) thrF("´: Argument must be a list (%H ≡ ≢𝕩)", x);
usz ia = IA(x);
@ -45,23 +90,10 @@ B fold_c1(Md1D* d, B x) { B f = d->f;
}
if (rtid==n_add) { // +
void *xv = tyany_ptr(x);
f64 r;
#define CASE_INT(T,M,A) case el_i##T: { \
usz b=1<<(M-T); i64 lim = (1ull<<53) - (1ull<<M); \
i##T* xp = xv; i64 c=0; \
usz i0=ia; while (i0>0 && -lim<=c && c<=lim) { \
usz e=i0; i0=(i0-1)&~(b-1); \
i##A s=0; for (usz i=i0; i<e; i++) s+=xp[i]; c+=s; \
} \
r = c; while (i0--) r+= xp[i0]; \
break; }
switch (xe) { default: UD;
CASE_INT(8 ,32,32)
CASE_INT(16,32,32)
CASE_INT(32,52,64)
case el_f64: { r=0; for (usz i=ia; i--; ) c+=((f64*)xp)[i]; break; }
}
#undef CASE_INT
bool small = xe<=el_i32 & ia<=sum_small_max;
u8 sel = xe - el_i8;
f64 r = small ? sum_small_fns[sel](xv, ia)
: sum_fns[sel](xv, ia, 0);
decG(x); return m_f64(r);
}
if (rtid==n_mul | rtid==n_and) { // ×/∧
@ -70,18 +102,8 @@ B fold_c1(Md1D* d, B x) { B f = d->f;
if (xe==el_i32) { i32* xp = i32any_ptr(x); i32 c=1; for (usz i=ia; i--; ) if (mulOn(c,xp[i]))goto base; decG(x); return m_i32(c); }
if (xe==el_f64) { f64* xp = f64any_ptr(x); f64 c=1; for (usz i=ia; i--; ) c*= xp[i]; decG(x); return m_f64(c); }
}
#define CASE(T,C) case el_##T: { \
T* xp = xv; T c = xp[0]; \
for (usz i=0; i<ia; i++) if (xp[i] C c) c=xp[i]; \
r = c; break; }
#define FC(C) \
void *xv = tyany_ptr(x); f64 r; \
switch(xe) { default:UD; CASE(i8,C) CASE(i16,C) CASE(i32,C) CASE(f64,C) } \
decG(x); return m_f64(r);
if (rtid==n_floor) { FC(<) } // ⌊
if (rtid==n_ceil ) { FC(>) } // ⌈
#undef FC
#undef CASE
if (rtid==n_floor) { f64 r=min_fns[xe-el_i8](tyany_ptr(x), ia); decG(x); return m_f64(r); } // ⌊
if (rtid==n_ceil ) { f64 r=max_fns[xe-el_i8](tyany_ptr(x), ia); decG(x); return m_f64(r); } // ⌈
if (rtid==n_or) { //
if (xe==el_i8 ) { i8* xp = i8any_ptr (x); bool r=0; for (usz i=0; i<ia; i++) { i8 c=xp[i]; if (c!=0&&c!=1)goto base; r|=c; } decG(x); return m_i32(r); }
if (xe==el_i16) { i16* xp = i16any_ptr(x); bool r=0; for (usz i=0; i<ia; i++) { i16 c=xp[i]; if (c!=0&&c!=1)goto base; r|=c; } decG(x); return m_i32(r); }
@ -125,25 +147,17 @@ B fold_c2(Md1D* d, B w, B x) { B f = d->f;
goto base;
}
if (rtid==n_add) { // +
if (xe==el_i8 ) { i8* xp = i8any_ptr (x); i64 c=wi; for (usz i=0; i<ia; i++) c+=xp[i]; decG(x); return m_f64(c); }
if (xe==el_i16) { i16* xp = i16any_ptr(x); i32 c=wi; for (usz i=0; i<ia; i++) if (addOn(c,xp[i]))goto base; decG(x); return m_i32(c); }
if (xe==el_i32) { i32* xp = i32any_ptr(x); i32 c=wi; for (usz i=0; i<ia; i++) if (addOn(c,xp[i]))goto base; decG(x); return m_i32(c); }
u8 sel = xe - el_i8;
f64 r = sum_fns[sel](tyany_ptr(x), ia, wi);
decG(x); return m_f64(r);
}
if (rtid==n_mul | rtid==n_and) { // ×/∧
if (xe==el_i8 ) { i8* xp = i8any_ptr (x); i32 c=wi; for (usz i=ia; i--; ) if (mulOn(c,xp[i]))goto base; decG(x); return m_i32(c); }
if (xe==el_i16) { i16* xp = i16any_ptr(x); i32 c=wi; for (usz i=ia; i--; ) if (mulOn(c,xp[i]))goto base; decG(x); return m_i32(c); }
if (xe==el_i32) { i32* xp = i32any_ptr(x); i32 c=wi; for (usz i=ia; i--; ) if (mulOn(c,xp[i]))goto base; decG(x); return m_i32(c); }
}
if (rtid==n_floor) { // ⌊
if (xe==el_i8 ) { i8* xp = i8any_ptr (x); i32 c=wi; for (usz i=0; i<ia; i++) if (xp[i]<c) c=xp[i]; decG(x); return m_i32(c); }
if (xe==el_i16) { i16* xp = i16any_ptr(x); i32 c=wi; for (usz i=0; i<ia; i++) if (xp[i]<c) c=xp[i]; decG(x); return m_i32(c); }
if (xe==el_i32) { i32* xp = i32any_ptr(x); i32 c=wi; for (usz i=0; i<ia; i++) if (xp[i]<c) c=xp[i]; decG(x); return m_i32(c); }
}
if (rtid==n_ceil) { // ⌈
if (xe==el_i8 ) { i8* xp = i8any_ptr (x); i32 c=wi; for (usz i=0; i<ia; i++) if (xp[i]>c) c=xp[i]; decG(x); return m_i32(c); }
if (xe==el_i16) { i16* xp = i16any_ptr(x); i32 c=wi; for (usz i=0; i<ia; i++) if (xp[i]>c) c=xp[i]; decG(x); return m_i32(c); }
if (xe==el_i32) { i32* xp = i32any_ptr(x); i32 c=wi; for (usz i=0; i<ia; i++) if (xp[i]>c) c=xp[i]; decG(x); return m_i32(c); }
}
if (rtid==n_floor) { f64 r=wi; if (ia>0) { f64 m=min_fns[xe-el_i8](tyany_ptr(x), ia); if (m<r) r=m; } decG(x); return m_f64(r); } // ⌊
if (rtid==n_ceil ) { f64 r=wi; if (ia>0) { f64 m=max_fns[xe-el_i8](tyany_ptr(x), ia); if (m>r) r=m; } decG(x); return m_f64(r); } // ⌈
if (rtid==n_or && (wi&1)==wi) { //
if (xe==el_i8 ) { i8* xp = i8any_ptr (x); bool q=wi; for (usz i=0; i<ia; i++) { i8 c=xp[i]; if (c!=0&&c!=1)goto base; q|=c; } decG(x); return m_i32(q); }
if (xe==el_i16) { i16* xp = i16any_ptr(x); bool q=wi; for (usz i=0; i<ia; i++) { i16 c=xp[i]; if (c!=0&&c!=1)goto base; q|=c; } decG(x); return m_i32(q); }