unified filling for ⥊ and ↑

well except for ¯N↑
This commit is contained in:
dzaima 2021-08-23 02:28:51 +03:00
parent 68214dda53
commit 63cafe8e7b
2 changed files with 76 additions and 27 deletions

View File

@ -4,6 +4,19 @@
#include "../utils/builtins.h"
#include "../utils/talloc.h"
static Arr* take_impl(usz ria, B x) { // consumes x; returns v↑⥊𝕩 without set shape; v is non-negative
usz xia = a(x)->ia;
if (ria>xia) {
B xf = getFillE(x);
MAKE_MUT(r, ria); mut_init(r, TI(x,elType));
mut_copyG(r, 0, x, 0, xia);
mut_fill(r, xia, xf, ria-xia);
dec(x); dec(xf);
return mut_fp(r);
} else {
return TI(x,slice)(x,0,ria);
}
}
B shape_c1(B t, B x) {
if (isAtm(x)) {
@ -48,7 +61,6 @@ B shape_c1(B t, B x) {
B shape_c2(B t, B w, B x) {
usz xia = isArr(x)? a(x)->ia : 1;
usz nia;
bool fill = false;
ur nr;
ShArr* sh;
if (isF64(w)) {
@ -96,6 +108,7 @@ B shape_c2(B t, B w, B x) {
i64 div = xia/tot;
i64 mod = xia%tot;
usz item;
bool fill = false;
if (unkInd == 52) {
if (mod!=0) thrM("⥊: Shape must be exact when reshaping with ∘");
item = div;
@ -111,6 +124,12 @@ B shape_c2(B t, B w, B x) {
tot*= item;
if (tot > USZ_MAX) thrM("⥊: Result too large");
nia = tot;
if (fill) {
Arr* a = take_impl(nia, x);
arr_shVec(a);
x = taga(a);
xia = nia;
}
} else nia = tot;
}
}
@ -147,25 +166,15 @@ B shape_c2(B t, B w, B x) {
i64 div = nia/xia;
i64 mod = nia%xia;
for (i64 i = 0; i < div; i++) mut_copyG(m, i*xia, x, 0, xia);
if (fill && mod && noFill(xf)) thrM("⥊: 𝕩 had no fill element");
if (fill) mut_fill(m, div*xia, xf, mod);
else mut_copyG(m, div*xia, x, 0, mod);
mut_copyG(m, div*xia, x, 0, mod);
dec(x);
Arr* ra = mut_fp(m);
arr_shSetU(ra, nr, sh);
return withFill(taga(ra), xf);
}
}
unit:
if (fill && nia>1) {
MAKE_MUT(m, nia); mut_init(m, selfElType(x));
mut_setG(m, 0, x);
mut_fillG(m, 1, xf, nia-1);
Arr* ra = mut_fp(m);
arr_shSetU(ra, nr, sh);
return withFill(taga(ra), xf);
}
unit:
if (isF64(x)) { decA(xf);
i32 n = (i32)x.f;
if (n == x.f) {
@ -495,7 +504,7 @@ B slash_c2(B t, B w, B x) {
return c2(rt_slash, w, x);
}
B slicev(B x, usz s, usz ia) {
static B slicev(B x, usz s, usz ia) {
usz xia = a(x)->ia; assert(s+ia <= xia);
Arr* r = TI(x,slice)(x, s, ia); arr_shVec(r);
return taga(r);
@ -504,21 +513,60 @@ extern B rt_take, rt_drop;
B take_c1(B t, B x) { return c1(rt_take, x); }
B drop_c1(B t, B x) { return c1(rt_drop, x); }
B take_c2(B t, B w, B x) {
if (isNum(w) && isArr(x) && rnk(x)==1) {
i64 v = o2i64(w);
usz ia = a(x)->ia;
u64 va = v<0? -v : v;
if (va>ia) {
B xf = getFillE(x);
MAKE_MUT(r, va); mut_init(r, TI(x,elType));
mut_copyG(r, v<0? va-ia : 0, x, 0, ia);
mut_fill(r, v<0? 0 : ia, xf, va-ia);
dec(x); dec(xf);
return mut_fv(r);
if (isNum(w)) {
if (!isArr(x)) x = m_atomUnit(x);
i64 wv = o2i64(w);
ur xr = rnk(x);
usz csz = 1;
usz* xsh;
if (xr>1) {
csz = arr_csz(x);
xsh = a(x)->sh;
ptr_inc(shObjS(xsh)); // we'll look at it at the end and dec there
}
if (v<0) return slicev(x, ia+v, -v);
else return slicev(x, 0, v);
i64 t = wv*csz; // TODO error on overflow somehow
Arr* a;
if (t>=0) {
a = take_impl(t, x);
} else {
t = -t;
usz xia = a(x)->ia;
if (t>xia) {
B xf = getFillE(x);
MAKE_MUT(r, t); mut_init(r, TI(x,elType));
mut_fill(r, 0, xf, t-xia);
mut_copyG(r, t-xia, x, 0, xia);
dec(x); dec(xf);
a = mut_fp(r);
} else {
a = TI(x,slice)(x,xia-t,t);
}
}
if (xr<=1) {
arr_shVec(a);
} else {
usz* rsh = arr_shAlloc(a, xr); // xr>1, don't have to worry about 0
rsh[0] = wv<0?-wv:wv;
for (i32 i = 1; i < xr; i++) rsh[i] = xsh[i];
ptr_dec(shObjS(xsh));
}
return taga(a);
}
// if (isNum(w) && isArr(x) && rnk(x)==1) {
// i64 v = o2i64(w);
// usz ia = a(x)->ia;
// u64 va = v<0? -v : v;
// if (va>ia) {
// B xf = getFillE(x);
// MAKE_MUT(r, va); mut_init(r, TI(x,elType));
// mut_copyG(r, v<0? va-ia : 0, x, 0, ia);
// mut_fill(r, v<0? 0 : ia, xf, va-ia);
// dec(x); dec(xf);
// return mut_fv(r);
// }
// if (v<0) return slicev(x, ia+v, -v);
// else return slicev(x, 0, v);
// }
return c2(rt_take, w, x);
}
B drop_c2(B t, B w, B x) {

View File

@ -16,6 +16,7 @@ typedef struct ShArr {
struct Value;
usz a[];
} ShArr;
static ShArr* shObjS(usz* x) { return RFLD(x, ShArr, a); }
static ShArr* shObj (B x) { return RFLD(a(x)->sh, ShArr, a); }
static ShArr* shObjP(Value* x) { return RFLD(((Arr*)x)->sh, ShArr, a); }
static void decSh(Value* x) { if (RARE(prnk(x)>1)) ptr_dec(shObjP(x));}