diff --git a/src/builtins/sfns.c b/src/builtins/sfns.c index fa69a445..2038e10c 100644 --- a/src/builtins/sfns.c +++ b/src/builtins/sfns.c @@ -4,6 +4,19 @@ #include "../utils/builtins.h" #include "../utils/talloc.h" +static Arr* take_impl(usz ria, B x) { // consumes x; returns vā†‘ā„Šš•© without set shape; v is non-negative + usz xia = a(x)->ia; + if (ria>xia) { + B xf = getFillE(x); + MAKE_MUT(r, ria); mut_init(r, TI(x,elType)); + mut_copyG(r, 0, x, 0, xia); + mut_fill(r, xia, xf, ria-xia); + dec(x); dec(xf); + return mut_fp(r); + } else { + return TI(x,slice)(x,0,ria); + } +} B shape_c1(B t, B x) { if (isAtm(x)) { @@ -48,7 +61,6 @@ B shape_c1(B t, B x) { B shape_c2(B t, B w, B x) { usz xia = isArr(x)? a(x)->ia : 1; usz nia; - bool fill = false; ur nr; ShArr* sh; if (isF64(w)) { @@ -96,6 +108,7 @@ B shape_c2(B t, B w, B x) { i64 div = xia/tot; i64 mod = xia%tot; usz item; + bool fill = false; if (unkInd == 52) { if (mod!=0) thrM("ℊ: Shape must be exact when reshaping with ∘"); item = div; @@ -111,6 +124,12 @@ B shape_c2(B t, B w, B x) { tot*= item; if (tot > USZ_MAX) thrM("ℊ: Result too large"); nia = tot; + if (fill) { + Arr* a = take_impl(nia, x); + arr_shVec(a); + x = taga(a); + xia = nia; + } } else nia = tot; } } @@ -147,25 +166,15 @@ B shape_c2(B t, B w, B x) { i64 div = nia/xia; i64 mod = nia%xia; for (i64 i = 0; i < div; i++) mut_copyG(m, i*xia, x, 0, xia); - if (fill && mod && noFill(xf)) thrM("ℊ: š•© had no fill element"); - if (fill) mut_fill(m, div*xia, xf, mod); - else mut_copyG(m, div*xia, x, 0, mod); + mut_copyG(m, div*xia, x, 0, mod); dec(x); Arr* ra = mut_fp(m); arr_shSetU(ra, nr, sh); return withFill(taga(ra), xf); } } - unit: - if (fill && nia>1) { - MAKE_MUT(m, nia); mut_init(m, selfElType(x)); - mut_setG(m, 0, x); - mut_fillG(m, 1, xf, nia-1); - Arr* ra = mut_fp(m); - arr_shSetU(ra, nr, sh); - return withFill(taga(ra), xf); - } + unit: if (isF64(x)) { decA(xf); i32 n = (i32)x.f; if (n == x.f) { @@ -495,7 +504,7 @@ B slash_c2(B t, B w, B x) { return c2(rt_slash, w, x); } -B slicev(B x, usz s, usz ia) { +static B slicev(B x, usz s, usz ia) { usz xia = a(x)->ia; assert(s+ia <= xia); Arr* r = TI(x,slice)(x, s, ia); arr_shVec(r); return taga(r); @@ -504,21 +513,60 @@ extern B rt_take, rt_drop; B take_c1(B t, B x) { return c1(rt_take, x); } B drop_c1(B t, B x) { return c1(rt_drop, x); } B take_c2(B t, B w, B x) { - if (isNum(w) && isArr(x) && rnk(x)==1) { - i64 v = o2i64(w); - usz ia = a(x)->ia; - u64 va = v<0? -v : v; - if (va>ia) { - B xf = getFillE(x); - MAKE_MUT(r, va); mut_init(r, TI(x,elType)); - mut_copyG(r, v<0? va-ia : 0, x, 0, ia); - mut_fill(r, v<0? 0 : ia, xf, va-ia); - dec(x); dec(xf); - return mut_fv(r); + if (isNum(w)) { + if (!isArr(x)) x = m_atomUnit(x); + i64 wv = o2i64(w); + ur xr = rnk(x); + usz csz = 1; + usz* xsh; + if (xr>1) { + csz = arr_csz(x); + xsh = a(x)->sh; + ptr_inc(shObjS(xsh)); // we'll look at it at the end and dec there } - if (v<0) return slicev(x, ia+v, -v); - else return slicev(x, 0, v); + i64 t = wv*csz; // TODO error on overflow somehow + Arr* a; + if (t>=0) { + a = take_impl(t, x); + } else { + t = -t; + usz xia = a(x)->ia; + if (t>xia) { + B xf = getFillE(x); + MAKE_MUT(r, t); mut_init(r, TI(x,elType)); + mut_fill(r, 0, xf, t-xia); + mut_copyG(r, t-xia, x, 0, xia); + dec(x); dec(xf); + a = mut_fp(r); + } else { + a = TI(x,slice)(x,xia-t,t); + } + } + if (xr<=1) { + arr_shVec(a); + } else { + usz* rsh = arr_shAlloc(a, xr); // xr>1, don't have to worry about 0 + rsh[0] = wv<0?-wv:wv; + for (i32 i = 1; i < xr; i++) rsh[i] = xsh[i]; + ptr_dec(shObjS(xsh)); + } + return taga(a); } + // if (isNum(w) && isArr(x) && rnk(x)==1) { + // i64 v = o2i64(w); + // usz ia = a(x)->ia; + // u64 va = v<0? -v : v; + // if (va>ia) { + // B xf = getFillE(x); + // MAKE_MUT(r, va); mut_init(r, TI(x,elType)); + // mut_copyG(r, v<0? va-ia : 0, x, 0, ia); + // mut_fill(r, v<0? 0 : ia, xf, va-ia); + // dec(x); dec(xf); + // return mut_fv(r); + // } + // if (v<0) return slicev(x, ia+v, -v); + // else return slicev(x, 0, v); + // } return c2(rt_take, w, x); } B drop_c2(B t, B w, B x) { diff --git a/src/core/stuff.h b/src/core/stuff.h index 8c9bbdc9..cdc83438 100644 --- a/src/core/stuff.h +++ b/src/core/stuff.h @@ -16,6 +16,7 @@ typedef struct ShArr { struct Value; usz a[]; } ShArr; +static ShArr* shObjS(usz* x) { return RFLD(x, ShArr, a); } static ShArr* shObj (B x) { return RFLD(a(x)->sh, ShArr, a); } static ShArr* shObjP(Value* x) { return RFLD(((Arr*)x)->sh, ShArr, a); } static void decSh(Value* x) { if (RARE(prnk(x)>1)) ptr_dec(shObjP(x));}