diff --git a/src/builtins/sfns.c b/src/builtins/sfns.c index 1bfa337c..347a3f98 100644 --- a/src/builtins/sfns.c +++ b/src/builtins/sfns.c @@ -71,6 +71,13 @@ B shape_c1(B t, B x) { Arr* r = TI(x,slice)(x, 0, ia); arr_shVec(r); return taga(r); } +static B truncReshape(B x, usz xia, usz nia, ur nr, ShArr* sh) { // consumes all + B r; Arr* ra; + if (reusable(x) && xia==nia) { r = x; decSh(v(x)); ra = (Arr*)v(r); } + else { ra = TI(x,slice)(x, 0, nia); r = taga(ra); } + arr_shSetU(ra, nr, sh); + return r; +} B shape_c2(B t, B w, B x) { usz xia = isArr(x)? a(x)->ia : 1; usz nia = 1; @@ -155,11 +162,7 @@ B shape_c2(B t, B w, B x) { // goes to unit } else { if (nia <= xia) { - B r; Arr* ra; - if (reusable(x) && xia==nia) { r = REUSE(x); decSh(v(x)); ra = (Arr*)v(r); } - else { ra = TI(x,slice)(x, 0, nia); r = taga(ra); } - arr_shSetU(ra, nr, sh); - return r; + return truncReshape(x, xia, nia, nr, sh); } else { xf = getFillQ(x); if (xia<=1) { @@ -1008,8 +1011,8 @@ B select_ucw(B t, B o, B w, B x) { u8 we = TI(w,elType); u8 xe = TI(x,elType); u8 re = TI(rep,elType); - - if (we==el_i32) { + if (we<=el_i32) { + w = toI32Any(w); i32* wp = i32any_ptr(w); if (rere?xe:re; @@ -1090,6 +1093,24 @@ B select_ucw(B t, B o, B w, B x) { #undef FREE_CHECK } +static B shape_uc1_t(B r, usz ia) { + if (!isArr(r) || rnk(r)!=1 || a(r)->ia!=ia) thrM("𝔽⌾⥊: 𝔽 changed the shape of the argument"); + return r; +} +B shape_uc1(B t, B o, B x) { + if (!isArr(x) || rnk(x)==0) { + usz xia = isArr(x)? a(x)->ia : 1; + B r = c1(o, shape_c1(t, x)); + if (isArr(r)) shape_uc1_t(r, xia); + return shape_c2(t, emptyIVec(), r); + } + usz xia = a(x)->ia; + if (rnk(x)==1) return shape_uc1_t(c1(o, x), xia); + ur xr = rnk(x); + ShArr* sh = shObj(x); + ptr_inc(sh); + return truncReshape(shape_uc1_t(c1(o, shape_c1(t, x)), xia), xia, xia, xr, sh); +} void sfns_init() { c(BFn,bi_pick)->uc1 = pick_uc1; @@ -1097,4 +1118,5 @@ void sfns_init() { c(BFn,bi_pick)->ucw = pick_ucw; c(BFn,bi_slash)->ucw = slash_ucw; c(BFn,bi_select)->ucw = select_ucw; + c(BFn,bi_shape)->uc1 = shape_uc1; } diff --git a/src/core/derv.c b/src/core/derv.c index 75b91ba6..4fa0b334 100644 --- a/src/core/derv.c +++ b/src/core/derv.c @@ -1,4 +1,5 @@ #include "../core.h" +#include "../nfns.h" DEF_FREE(md1D) { dec(((Md1D*)x)->m1); dec(((Md1D*)x)->f); } DEF_FREE(md2D) { dec(((Md2D*)x)->m2); dec(((Md2D*)x)->f); dec(((Md2D*)x)->g); } @@ -61,13 +62,48 @@ static B md2D_uc1(B t, B o, B x) { return TI(m,m2_uc1)(m, o, f, g, x); } +static B toConstant(B x) { // doesn't consume x + if (!isCallable(x)) return inc(x); + if (v(x)->type == t_md1D) { + Md1D* d = c(Md1D,x); + B m1 = d->m1; + if (v(m1)->type==t_md1BI && v(m1)->flags==44) return inc(d->f); + } + return bi_N; +} +static NFnDesc* ucwWrapDesc; + +static B fork_uc1(B t, B o, B x) { + B f = toConstant(c(Fork, t)->f); + B g = c(Fork, t)->g; + B h = c(Fork, t)->h; + if (RARE(q_N(f) | !isFun(g) | !isFun(h))) { dec(f); return def_fn_uc1(t, o, x); } // flags check to not deconstruct builtins + B args[] = {g, o, f}; + B tmp = m_nfn(ucwWrapDesc, tag(args, RAW_TAG)); + B r = TI(h,fn_uc1)(h,tmp,x); + // f is consumed by the eventual ucwWrap call. this hopes that everything is nice and calls o only once, and within the under call, so any under interface must make they can't + dec(tmp); + return r; +} + +static B ucwWrap_c1(B t, B x) { + B* args = c(B, nfn_objU(t)); + B g = args[0]; + return TI(g,fn_ucw)(g, args[1], args[2], x); +} + void derv_init() { TIi(t_md1D,freeO) = md1D_freeO; TIi(t_md1D,freeF) = md1D_freeF; TIi(t_md1D,visit) = md1D_visit; TIi(t_md1D,print) = md1D_print; TIi(t_md1D,decompose) = md1D_decompose; - TIi(t_md2D,freeO) = md2D_freeO; TIi(t_md2D,freeF) = md2D_freeF; TIi(t_md2D,visit) = md2D_visit; TIi(t_md2D,print) = md2D_print; TIi(t_md2D,decompose) = md2D_decompose; TIi(t_md2D,fn_uc1) = md2D_uc1; + TIi(t_md2D,freeO) = md2D_freeO; TIi(t_md2D,freeF) = md2D_freeF; TIi(t_md2D,visit) = md2D_visit; TIi(t_md2D,print) = md2D_print; TIi(t_md2D,decompose) = md2D_decompose; TIi(t_md2H,freeO) = md2H_freeO; TIi(t_md2H,freeF) = md2H_freeF; TIi(t_md2H,visit) = md2H_visit; TIi(t_md2H,print) = md2H_print; TIi(t_md2H,decompose) = md2H_decompose; TIi(t_fork,freeO) = fork_freeO; TIi(t_fork,freeF) = fork_freeF; TIi(t_fork,visit) = fork_visit; TIi(t_fork,print) = fork_print; TIi(t_fork,decompose) = fork_decompose; TIi(t_atop,freeO) = atop_freeO; TIi(t_atop,freeF) = atop_freeF; TIi(t_atop,visit) = atop_visit; TIi(t_atop,print) = atop_print; TIi(t_atop,decompose) = atop_decompose; TIi(t_md1BI,m1_d) = m_md1D; TIi(t_md2BI,m2_d) = m_md2D; + TIi(t_md2D,fn_uc1) = md2D_uc1; // not in post so later init can utilize it } +void dervPost_init() { + ucwWrapDesc = registerNFn(m_str8l("(temporary function for ⌾)"), ucwWrap_c1, c2_bad); + TIi(t_fork,fn_uc1) = fork_uc1; +} \ No newline at end of file diff --git a/src/h.h b/src/h.h index 9d15ef9a..7424ecfd 100644 --- a/src/h.h +++ b/src/h.h @@ -147,6 +147,7 @@ static const u16 C32_TAG = 0b0111111111110001; // 7FF1 0111111111110001......... static const u16 TAG_TAG = 0b0111111111110010; // 7FF2 0111111111110010................nnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn special value (0=nothing, 1=undefined var, 2=bad header; 3=optimized out; 4=error?; 5=no fill) static const u16 VAR_TAG = 0b0111111111110011; // 7FF3 0111111111110011ddddddddddddddddnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn variable reference static const u16 EXT_TAG = 0b0111111111110100; // 7FF4 0111111111110100ddddddddddddddddnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn extended variable reference +static const u16 RAW_TAG = 0b0111111111110101; // 7FF5 0111111111110101nnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn raw 48 bits of data static const u16 MD1_TAG = 0b1111111111110010; // FFF2 1111111111110010ppppppppppppppppppppppppppppppppppppppppppppp000 1-modifier static const u16 MD2_TAG = 0b1111111111110011; // FFF3 1111111111110011ppppppppppppppppppppppppppppppppppppppppppppp000 2-modifier static const u16 FUN_TAG = 0b1111111111110100; // FFF4 1111111111110100ppppppppppppppppppppppppppppppppppppppppppppp000 function diff --git a/src/load.c b/src/load.c index 2b344d84..54698889 100644 --- a/src/load.c +++ b/src/load.c @@ -5,7 +5,7 @@ #include "ns.h" #include "builtins.h" -#define FOR_INIT(F) F(base) F(harr) F(mutF) F(fillarr) F(tyarr) F(hash) F(sfns) F(fns) F(arith) F(md1) F(md2) F(derv) F(comp) F(rtWrap) F(ns) F(nfn) F(sysfn) F(load) F(sysfnPost) +#define FOR_INIT(F) F(base) F(harr) F(mutF) F(fillarr) F(tyarr) F(hash) F(sfns) F(fns) F(arith) F(md1) F(md2) F(derv) F(comp) F(rtWrap) F(ns) F(nfn) F(sysfn) F(load) F(sysfnPost) F(dervPost) #define F(X) void X##_init(void); FOR_INIT(F) #undef F