Merge pull request #62 from mlochbaum/math

Various •math functions
This commit is contained in:
dzaima 2022-11-24 01:18:32 +02:00 committed by GitHub
commit ae7ac647a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 193 additions and 57 deletions

View File

@ -17,7 +17,10 @@
/*internal.c*/M(itype,"•internal.Type") M(elType,"•internal.ElType") M(refc,"•internal.Refc") M(isPure,"•internal.IsPure") A(info,"•internal.Info") M(heapDump,"•internal.HeapDump") \
/*internal.c*/M(squeeze,"•internal.Squeeze") M(deepSqueeze,"•internal.DeepSqueeze") D(eequal,"•internal.EEqual") A(internalTemp,"•internal.Temp") \
/*internal.c*/D(variation,"•internal.Variation") A(listVariations,"•internal.ListVariations") M(clearRefs,"•internal.ClearRefs") M(unshare,"•internal.Unshare") \
/* arithm.c*/M(sin,"•math.Sin") M(cos,"•math.Cos") M(tan,"•math.Tan") M(asin,"•math.Asin") M(acos,"•math.Acos") M(atan,"•math.Atan") D(atan2,"•math.Atan2")
/* arithm.c*/M(sin,"•math.Sin") M(cos,"•math.Cos") M(tan,"•math.Tan") M(asin,"•math.Asin") M(acos,"•math.Acos") M(atan,"•math.Atan") D(atan2,"•math.Atan2") D(hypot,"•math.Hypot") \
/* arithm.c*/M(sinh,"•math.Sinh") M(cosh,"•math.Cosh") M(tanh,"•math.Tanh") M(asinh,"•math.Asinh") M(acosh,"•math.Acosh") M(atanh,"•math.Atanh") \
/* arithm.c*/M(cbrt,"•math.Cbrt") M(log2,"•math.Log2") M(log10,"•math.Log10") M(log1p,"•math.Log1p") M(expm1,"•math.Expm1") M(fact,"•math.Fact") D(comb,"•math.Comb") M(logfact,"•math.LogFact") \
/* arithm.c*/M(erf,"•math.Erf") M(erfc,"•math.ErfC") D(gcd,"•math.GCD") D(lcm,"•math.LCM") M(sum,"•math.Sum")
#define FOR_PM1(A,M,D) \
/*md1.c*/A(tbl,"") A(each,"¨") A(fold,"´") A(scan,"`") A(const,"˙") A(swap,"˜") A(cell,"˘") A(insert,"˝") \

View File

@ -360,10 +360,83 @@ AR_F_SCALAR("|", stile, pfmod(x.f, w.f))
AR_F_SCALAR("⋆⁼",log , log(x.f)/log(w.f))
#undef AR_F_SCALAR
B atan2_c2(B t, B w, B x) {
if (isNum(w) && isNum(x)) return m_f64(atan2(x.f, w.f));
P2(atan2)
thrM("•math.Atan2: Unexpected argument types");
static f64 comb_nat(f64 k, f64 n) {
assert(k>=0 && n>=2*k);
if (k > 514) return INFINITY;
f64 p = 1;
for (usz i=0; i<(usz)k; i++) {
p*= (n-i) / (k-i);
if (p == INFINITY) return p;
}
return round(p);
}
static f64 comb(f64 k, f64 n) { // n choose k
f64 j = n - k; // j+k == n
bool jint = j == round(j);
if (k == round(k)) {
if (jint) {
if (k<j) { f64 t=k; k=j; j=t; } // Now j<k
if (n >= 0) {
return j<0? 0 : comb_nat(j, n);
} else {
if (k<0) return 0;
f64 l = -1-n; // l+k == -1-j
f64 r = comb_nat(k<l? k : l, -1-j);
return k<(1ull<<53) && ((i64)k&1)? -r : r;
}
}
if (k < 0) return 0;
} else if (jint) {
if (j < 0) return 0;
}
return exp(lgamma(n+1) - lgamma(k+1) - lgamma(j+1));
}
#define MATH(n,N) \
B n##_c2(B t, B w, B x) { \
if (isNum(w) && isNum(x)) return m_f64(n(x.f, w.f)); \
P2(n) \
thrM("•math." #N ": Unexpected argument types"); \
}
MATH(atan2,Atan2) MATH(hypot,Hypot) MATH(comb,Comb)
#undef MATH
static u64 gcd_u64(u64 a, u64 b) {
if (a == 0) return b;
if (b == 0) return a;
u8 az = CTZ(a);
u8 bz = CTZ(b);
u8 sh = az<bz? az : bz;
b >>= bz;
while (a > 0) {
a >>= az;
u64 d = b - a;
az = CTZ(d);
b = b<a? b : a;
a = b<a? -d : d;
}
return b << sh;
}
static u64 lcm_u64(u64 a, u64 b) {
if (a==0 | b==0) return 0;
return (a / gcd_u64(a, b)) * b;
}
B gcd_c2(B t, B w, B x) {
if (isNum(w) && isNum(x)) {
if (!q_u64(w) || !q_u64(x)) thrM("•math.GCD: Inputs other than natural numbers not yet supported");
return m_f64(gcd_u64(o2u64G(w), o2u64G(x)));
}
P2(gcd)
thrM("•math.GCD: Unexpected argument types");
}
B lcm_c2(B t, B w, B x) {
if (isNum(w) && isNum(x)) {
if (!q_u64(w) || !q_u64(x)) thrM("•math.LCM: Inputs other than natural numbers not yet supported");
return m_f64(lcm_u64(o2u64G(w), o2u64G(x)));
}
P2(gcd)
thrM("•math.GCD: Unexpected argument types");
}
#undef P2

View File

@ -22,11 +22,12 @@ B bit_negate(B x) { // consumes
return r;
}
#define GC1i(SYMB,NAME,FEXPR,IBAD,IEXPR,BX,SQF) B NAME##_c1(B t, B x) { \
#define GC1i(SYMB,NAME,FEXPR,IBAD,IEXPR,SQF,TMIN,RMIN) B NAME##_c1(B t, B x) { \
if (isF64(x)) { f64 v = x.f; return m_f64(FEXPR); } \
if (RARE(!isArr(x))) thrM(SYMB ": Expected argument to be a number"); \
u8 xe = TI(x,elType); \
i64 sz = IA(x); BX \
if (xe<=TMIN) return RMIN; \
i64 sz = IA(x); \
if (xe==el_i8) { i8 MAX=I8_MAX; i8 MIN=I8_MIN; i8* xp=i8any_ptr(x); i8* rp; B r=m_i8arrc(&rp,x); \
for (i64 i = 0; i < sz; i++) { i8 v = xp[i]; if (RARE(IBAD)) { decG(r); goto base; } rp[i] = IEXPR; } \
decG(x); (void)MIN;(void)MAX; return r; \
@ -56,14 +57,12 @@ B add_c1(B t, B x) {
return x;
}
GC1i("-", sub, -v, v== MIN, -v, {}, 0) // change icond to v==-v to support ¯0 (TODO that won't work for i8/i16)
GC1i("|", stile, fabs(v), v== MIN, v<0?-v:v,{}, 0)
GC1i("", floor, floor(v), 0, v, {}, 1)
GC1i("", ceil, ceil(v), 0, v, {}, 1)
GC1i("×", mul, v==0?0:v>0?1:-1, 0, v==0?0:v>0?1:-1,{}, 1)
GC1i("¬", not, 1-v, v<=-MAX, 1-v, {
if(xe==el_bit) return bit_negate(x);
}, 0)
GC1i("-", sub, -v, v== MIN, -v, 0, el_bit, bit_sel(x,m_f64(0),m_f64(-1))) // change icond to v==-v to support ¯0 (TODO that won't work for i8/i16)
GC1i("|", stile, fabs(v), v== MIN, v<0?-v:v,0, el_bit, x)
GC1i("", floor, floor(v), 0, v, 1, el_i32, x)
GC1i("", ceil, ceil(v), 0, v, 1, el_i32, x)
GC1i("×", mul, v==0?0:v>0?1:-1, 0,v==0?0:v>0?1:-1,1, el_bit, x)
GC1i("¬", not, 1-v, v<=-MAX, 1-v, 0, el_bit, bit_negate(x))
#define GC1f(N, F, MSG) B N##_c1(B t, B x) { \
if (isF64(x)) { f64 xv=o2fG(x); return m_f64(F); } \
@ -89,15 +88,20 @@ GC1f( div, 1/xv, "÷: Getting reciprocal of non-number")
GC1f(root, sqrt(xv), "√: Getting square root of non-number")
#undef GC1f
f64 fact(f64 x) { return tgamma(x+1); }
f64 logfact(f64 x) { return lgamma(x+1); }
#define P1(N) { if(isArr(x)) { SLOW1("arithm " #N, x); return arith_recm(N##_c1, x); } }
B pow_c1(B t, B x) { if (isF64(x)) return m_f64( exp(x.f)); P1( pow); thrM("⋆: Getting exp of non-number"); }
B log_c1(B t, B x) { if (isF64(x)) return m_f64( log(x.f)); P1( log); thrM("⋆⁼: Getting log of non-number"); }
B sin_c1(B t, B x) { if (isF64(x)) return m_f64( sin(x.f)); P1( sin); thrM("•math.Sin: Argument contained non-number"); }
B cos_c1(B t, B x) { if (isF64(x)) return m_f64( cos(x.f)); P1( cos); thrM("•math.Cos: Argument contained non-number"); }
B tan_c1(B t, B x) { if (isF64(x)) return m_f64( tan(x.f)); P1( tan); thrM("•math.Tan: Argument contained non-number"); }
B asin_c1(B t, B x) { if (isF64(x)) return m_f64( asin(x.f)); P1( asin); thrM("•math.Asin: Argument contained non-number"); }
B acos_c1(B t, B x) { if (isF64(x)) return m_f64( acos(x.f)); P1( acos); thrM("•math.Acos: Argument contained non-number"); }
B atan_c1(B t, B x) { if (isF64(x)) return m_f64( atan(x.f)); P1( atan); thrM("•math.Atan: Argument contained non-number"); }
#define MATH(n,N) \
B n##_c1(B t, B x) { if (isF64(x)) return m_f64(n(x.f)); P1(n); thrM("•math." #N ": Argument contained non-number"); }
MATH(cbrt,Cbrt) MATH(log2,Log2) MATH(log10,Log10) MATH(log1p,Log1p) MATH(expm1,Expm1)
MATH(fact,Fact) MATH(logfact,LogFact) MATH(erf,Erf) MATH(erfc,ErfC)
#define TRIG(n,N) MATH(n,N) MATH(a##n,A##n) MATH(n##h,N##h) MATH(a##n##h,A##n##h)
TRIG(sin,Sin) TRIG(cos,Cos) TRIG(tan,Tan)
#undef TRIG
#undef MATH
#undef P1
B lt_c1(B t, B x) { return m_atomUnit(x); }
@ -109,8 +113,8 @@ static B mathNS;
B getMathNS() {
if (mathNS.u == 0) {
#define F(X) inc(bi_##X),
Body* d = m_nnsDesc("sin","cos","tan","asin","acos","atan","atan2");
mathNS = m_nns(d, F(sin)F(cos)F(tan)F(asin)F(acos)F(atan)F(atan2));
Body* d = m_nnsDesc("sin","cos","tan","asin","acos","atan","atan2","sinh","cosh","tanh","asinh","acosh","atanh","cbrt","log2","log10","log1p","expm1","hypot","fact","logfact","erf","erfc","comb","gcd","lcm","sum");
mathNS = m_nns(d, F(sin)F(cos)F(tan)F(asin)F(acos)F(atan)F(atan2)F(sinh)F(cosh)F(tanh)F(asinh)F(acosh)F(atanh)F(cbrt)F(log2)F(log10)F(log1p)F(expm1)F(hypot)F(fact)F(logfact)F(erf)F(erfc)F(comb)F(gcd)F(lcm)F(sum));
#undef F
gc_add(mathNS);
}
@ -122,12 +126,17 @@ void arith_init() {
c(BFn,bi_mul)->ident = c(BFn,bi_div)->ident = c(BFn,bi_and)->ident = c(BFn,bi_eq)->ident = c(BFn,bi_ge)->ident = c(BFn,bi_pow)->ident = c(BFn,bi_not)->ident = m_i32(1);
c(BFn,bi_floor)->ident = m_f64(1.0/0.0);
c(BFn,bi_ceil )->ident = m_f64(-1.0/0.0);
#define INVERSE_PAIR(F,G) \
c(BFn,bi_##F)->im = G##_c1; \
c(BFn,bi_##G)->im = F##_c1;
c(BFn,bi_sub)->im = sub_c1;
c(BFn,bi_sin)->im = asin_c1;
c(BFn,bi_cos)->im = acos_c1;
c(BFn,bi_tan)->im = atan_c1;
c(BFn,bi_asin)->im = sin_c1;
c(BFn,bi_acos)->im = cos_c1;
c(BFn,bi_atan)->im = tan_c1;
INVERSE_PAIR(sin, asin)
INVERSE_PAIR(cos, acos)
INVERSE_PAIR(tan, atan)
INVERSE_PAIR(sinh, asinh)
INVERSE_PAIR(cosh, acosh)
INVERSE_PAIR(tanh, atanh)
INVERSE_PAIR(expm1, log1p)
#undef INVERSE_PAIR
}

View File

@ -9,6 +9,8 @@
// COULD implement fast numeric -´
// on boolean-valued integers, stopping at 1
// •math.Sum: +´ with faster and more precise SIMD code for i32, f64
#include "../core.h"
#include "../builtins.h"
@ -66,6 +68,40 @@ static f64 sum_f64(void* xv, usz i, f64 r) {
static i64 (*const sum_small_fns[])(void*, usz) = { sum_small_i8, sum_small_i16, sum_small_i32 };
static f64 (*const sum_fns[])(void*, usz, f64) = { sum_i8, sum_i16, sum_i32, sum_f64 };
B sum_c1(B t, B x) {
if (isAtm(x) || RNK(x)!=1) thrF("•math.Sum: Argument must be a list (%H ≡ ≢𝕩)", x);
usz ia = IA(x);
if (ia==0) return m_f64(0);
u8 xe = TI(x,elType);
if (!elNum(xe)) {
x = any_squeeze(x); xe = TI(x,elType);
if (!elNum(xe)) thrF("•math.Sum: Argument elements must be numbers", x);
}
f64 r;
void* xv = tyany_ptr(x);
if (xe == el_bit) {
r = bit_sum(xv, ia);
} else if (xe <= el_i32) {
u8 sel = xe - el_i8;
i64 s = 0; r = 0;
i64 m = 1ull<<48;
usz b = sum_small_max;
for (usz i=0; i<ia; i+=b) {
s += sum_small_fns[sel]((u8*)xv + (i<<sel), ia-i<b? ia-i : b);
if (s >= m) { r+=m; s-=m; }
if (s <= -m) { r-=m; s+=m; }
}
r += s;
} else {
#if SINGELI
r = avx2_sum_f64(xv, ia);
#else
r=0; for (usz i=0; i<ia; i++) r+=((f64*)xv)[i];
#endif
}
decG(x); return m_f64(r);
}
// Try to keep to i32 product, go to f64 on overflow or non-i32 initial
#define DEF_INT_PROD(T) \
static NOINLINE f64 prod_##T(void* xv, usz i, f64 init) { \

View File

@ -1551,7 +1551,7 @@ u32* dsv_text[] = {
U"•file.MapBytes",U"•file.Modified",U"•file.Name",U"•file.Parent",U"•file.Remove",U"•file.Rename",U"•file.Size",U"•file.Type",
U"•internal.ClearRefs",U"•internal.DeepSqueeze",U"•internal.EEqual",U"•internal.ElType",U"•internal.HeapDump",U"•internal.Info",U"•internal.IsPure",U"•internal.ListVariations",U"•internal.Refc",U"•internal.Squeeze",U"•internal.Temp",U"•internal.Type",U"•internal.Unshare",U"•internal.Variation",
U"•math.Acos",U"•math.Asin",U"•math.Atan",U"•math.Atan2",U"•math.Cos",U"•math.Sin",U"•math.Tan",
U"•math.Acos",U"•math.Acosh",U"•math.Asin",U"•math.Asinh",U"•math.Atan",U"•math.Atan2",U"•math.Atanh",U"•math.Cbrt",U"•math.Comb",U"•math.Cos",U"•math.Cosh",U"•math.Erf",U"•math.ErfC",U"•math.Expm1",U"•math.Fact",U"•math.GCD",U"•math.Hypot",U"•math.LCM",U"•math.Log10",U"•math.Log1p",U"•math.Log2",U"•math.LogFact",U"•math.Sin",U"•math.Sinh",U"•math.Sum",U"•math.Tan",U"•math.Tanh",
U"•rand.Deal",U"•rand.Range",U"•rand.Subset",
U"•term.CharB",U"•term.CharN",U"•term.ErrRaw",U"•term.Flush",U"•term.OutRaw",U"•term.RawMode",
NULL

View File

@ -14,11 +14,11 @@ void gc_addFn(vfn f) {
gc_roots[gc_rootSz++] = f;
}
Value* gc_rootObjs[256];
Value* gc_rootObjs[512];
u32 gc_rootObjSz;
void gc_add(B x) {
assert(isVal(x));
if (gc_rootObjSz>=256) err("Too many GC root objects");
if (gc_rootObjSz>=512) err("Too many GC root objects");
gc_rootObjs[gc_rootObjSz++] = v(x);
}

View File

@ -61,9 +61,9 @@ def iota{T & w256{T,16}} = make{T,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
def iota{T & w256{T,8}} = make{T,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
# bit arith
def __xor{a:T, b:T & w256i{T}} = T ~~ emit{[8]f32, '_mm256_xor_ps', v2f{a}, v2f{b}}
def __and{a:T, b:T & w256i{T}} = T ~~ emit{[8]f32, '_mm256_and_ps', v2f{a}, v2f{b}}
def __or {a:T, b:T & w256i{T}} = T ~~ emit{[8]f32, '_mm256_or_ps', v2f{a}, v2f{b}}
def __xor{a:T, b:T & w256{T}} = T ~~ emit{[8]f32, '_mm256_xor_ps', v2f{a}, v2f{b}}
def __and{a:T, b:T & w256{T}} = T ~~ emit{[8]f32, '_mm256_and_ps', v2f{a}, v2f{b}}
def __or {a:T, b:T & w256{T}} = T ~~ emit{[8]f32, '_mm256_or_ps', v2f{a}, v2f{b}}
def __not{a:T & w256u{T}} = a ^ broadcast{T, ~cast{eltype{T},0}}
# float comparison
@ -132,4 +132,4 @@ def getmask{x:T & w256{T, 64}} = emit{u8, '_mm256_movemask_pd', v2d{x}}
def andIsZero{x:T, y:T & w256i{T}} = emit{u1, '_mm256_testz_si256', x, y}
def any{x:T & w256i{T}} = getmask{x} != 0 # assumes elements of x all have equal bits (avx2 utilizes this for 16 bits)
def all{x:T & w256i{T}} = getmask{x} == (1<<vcount{T})-1 # same assumption
def anyneg{x:T & w256s{T}} = getmask{x}!=0
def anyneg{x:T & w256s{T}} = getmask{x}!=0

View File

@ -4,6 +4,23 @@ include './avx'
include './avx2'
include './mask'
def opsh{op}{v:[4]f64, perm} = op{v, shuf{[4]u64, v, perm}}
def mix{op, v:[4]f64} = { def sh=opsh{op}; sh{sh{v, 4b2301}, 4b1032} }
def reduce_pairwise{op, plog, x:*T, len, init:T} = {
# Pairwise combination to shorten dependency chains
def pairwise{p, i, k} = (if (k==0) { load{p,i} } else {
def l = k-1
op{pairwise{p, i , l},
pairwise{p, i+(1<<l), l}}
})
f:= len >> plog
r:= init
@for (i to f) r = op{r, pairwise{x+(i<<plog), 0, plog}}
@for (x over i from f<<plog to len) r = op{r, x}
r
}
fold_idem{T==f64, op}(x:*T, len:u64) : T = {
def step = 256/width{T}
def V = [step]T
@ -14,28 +31,26 @@ fold_idem{T==f64, op}(x:*T, len:u64) : T = {
assert{len > 0}
r = load{xv}
if (len > 1) {
if (len > 2) r = op{r, shuf{[4]u64, r, 4b2222}}
r = op{r, shuf{[4]u64, r, 4b1111}}
if (len > 2) r = opsh{op}{r, 4b2222}
r = opsh{op}{r, 4b1111}
}
} else {
# Pairwise combination to shorten dependency chains
def pairwise{p, i, k} = {
def l = k-1
op{pairwise{p, i , l},
pairwise{p, i+(1<<l), l}}
}
def pairwise{p, i, k==0} = load{p, i}
def pk = 2 # Combine 1<<pk values in a step
r = load{*V ~~ (x+len-step)}
e:= (len-1)/step
f:= e >> pk
@for (i to f) r = op{r, pairwise{xv+(i<<pk), 0, pk}}
@for (xv over i from f<<pk to e) r = op{r, xv}
r = op{r, shuf{[4]u64, r, 4b2301}}
r = op{r, shuf{[4]u64, r, 4b1032}}
i:= load{*V ~~ (x+len-step)}
r = mix{op, reduce_pairwise{op, 2, xv, (len-1)/step, i}}
}
extract{r, 0}
}
'avx2_fold_min_f64' = fold_idem{f64,min}
'avx2_fold_max_f64' = fold_idem{f64,max}
fold_assoc_0{T==f64, op}(x:*T, len:u64) : T = {
def step = 256/width{T}
def V = [step]T
xv:= *V ~~ x
e:= len / step
i:= load{xv, e} & (V~~maskOf{V, len % step})
r:= reduce_pairwise{op, 2, xv, e, i}
extract{mix{op, r}, 0}
}
'avx2_sum_f64' = fold_assoc_0{f64,+}

View File

@ -59,9 +59,9 @@ def iota{T & w128{T,16}} = make{T,0,1,2,3,4,5,6,7}
def iota{T & w128{T,8}} = make{T,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
# bit arith
def __xor{a:T, b:T & w128i{T}} = T ~~ emit{[4]f32, '_mm_xor_ps', v2f{a}, v2f{b}}
def __and{a:T, b:T & w128i{T}} = T ~~ emit{[4]f32, '_mm_and_ps', v2f{a}, v2f{b}}
def __or {a:T, b:T & w128i{T}} = T ~~ emit{[4]f32, '_mm_or_ps', v2f{a}, v2f{b}}
def __xor{a:T, b:T & w128{T}} = T ~~ emit{[4]f32, '_mm_xor_ps', v2f{a}, v2f{b}}
def __and{a:T, b:T & w128{T}} = T ~~ emit{[4]f32, '_mm_and_ps', v2f{a}, v2f{b}}
def __or {a:T, b:T & w128{T}} = T ~~ emit{[4]f32, '_mm_or_ps', v2f{a}, v2f{b}}
def __not{a:T & w128u{T}} = a ^ broadcast{T, ~cast{eltype{T},0}}
# signed comparison