shape replacement utility

This commit is contained in:
dzaima 2023-01-17 03:47:14 +02:00
parent bc4079b126
commit 540b37ae6a
7 changed files with 67 additions and 33 deletions

View File

@ -283,6 +283,7 @@ static NOINLINE B shift_cells(B f, B x, u8 e, u8 rtid) {
return mut_fcd(r, x);
}
B shape_c1(B, B);
B cell_c1(Md1D* d, B x) { B f = d->f;
if (isAtm(x) || RNK(x)==0) {
B r = c1(f, x);
@ -297,22 +298,24 @@ B cell_c1(Md1D* d, B x) { B f = d->f;
if (rtid==n_select && xr>1) return select_cells(0, x, xr);
if (rtid==n_pick && xr>1 && TI(x,arrD1)) return select_cells(0, x, xr);
if (rtid==n_couple) {
if (xr==0) return C1(shape, x);
Arr* r = cpyWithShape(x);
usz* xsh = PSH(r);
ShArr* rsh = m_shArr(xr+1);
usz* xsh = SH(x);
rsh->a[0] = xsh[0];
rsh->a[1] = 1;
shcpy(rsh->a+2, xsh+1, xr-1);
Arr* r = TI(x,slice)(x, 0, IA(x));
return taga(arr_shSetU(r, xr+1, rsh));
return taga(arr_shReplace(r, xr+1, rsh));
}
if (rtid==n_shape) {
usz cam = SH(x)[0];
usz csz = arr_csz(x);
Arr* ra = TI(x,slice)(x,0,IA(x));
usz* rsh = arr_shAlloc(ra, 2);
rsh[0] = cam;
rsh[1] = csz;
return taga(ra);
if (xr==2) return x;
Arr* r = cpyWithShape(x);
usz cam = PSH(r)[0];
usz csz = shProd(PSH(r), 1, xr);
ShArr* rsh = m_shArr(2);
rsh->a[0] = cam;
rsh->a[1] = csz;
return taga(arr_shReplace(r, 2, rsh));
}
if ((rtid==n_shifta || rtid==n_shiftb) && xr==2) {
B xf = getFillR(x);

View File

@ -269,6 +269,6 @@ B count_c2(B t, B w, B x) {
void search_init() {
{ u64* p; Arr* a=m_bitarrp(&p, 1); arr_shAlloc(a,0); *p= 0; gc_add(enclosed_0=taga(a)); }
{ u64* p; Arr* a=m_bitarrp(&p, 1); arr_shAlloc(a,0); *p=~0ULL; gc_add(enclosed_1=taga(a)); }
{ u64* p; Arr* a=m_bitarrp(&p, 1); arr_shAtm(a); *p= 0; gc_add(enclosed_0=taga(a)); }
{ u64* p; Arr* a=m_bitarrp(&p, 1); arr_shAtm(a); *p=~0ULL; gc_add(enclosed_1=taga(a)); }
}

View File

@ -16,10 +16,7 @@ static NOINLINE Arr* emptyArr(B x, ur xr) { // returns an empty array with the f
else if (noFill(xf)) { r = (Arr*) m_harrUp(0).c; }
else if (isC32(xf)) { u8* rp; r = m_c8arrp(&rp, 0); }
else { r = m_fillarrp(0); fillarr_setFill(r, xf); }
if (xr<=1) {
if (LIKELY(xr==1)) arr_shVec(r);
else arr_shAlloc(r, 0);
}
if (xr<=1) arr_rnk01(r, xr);
return r;
}
@ -101,6 +98,24 @@ B m_vec2(B a, B b) { return m_vec2Base(a, b, false); }
B pair_c1(B t, B x) { return m_vec1(x); }
B pair_c2(B t, B w, B x) { return m_vec2Base(w, x, true); }
Arr* cpyWithShape(B x) {
Arr* xv = a(x);
if (reusable(x)) return xv;
ur xr = PRNK(xv);
Arr* r;
if (xr<=1) {
r = TIv(xv,slice)(x, 0, PIA(xv));
arr_rnk01(r, xr);
} else {
usz* sh = PSH(xv);
ptr_inc(shObjS(sh));
r = TIv(xv,slice)(x, 0, PIA(xv));
r->sh = sh;
}
SPRNK(r, xr);
return r;
}
B shape_c1(B t, B x) {
if (isAtm(x)) return m_vec1(x);
if (RNK(x)==1) return x;

View File

@ -405,14 +405,12 @@ B rand_range_c2(B t, B w, B x) {
RAND_END;
if (isArr(w)) {
usz wia = IA(w);
switch (wia) {
case 0: { arr_shAlloc(r, 0); break; }
case 1: { arr_shVec(r); break; }
default: {
usz* sh = arr_shAlloc(r, wia);
SGetU(w);
for (usz i = 0; i < wia; i++) sh[i] = o2sG(GetU(w, i));
}
if (wia<2) {
arr_rnk01(r, wia);
} else {
usz* sh = arr_shAlloc(r, wia);
SGetU(w);
for (usz i = 0; i < wia; i++) sh[i] = o2sG(GetU(w, i));
}
} else {
arr_shVec(r);

View File

@ -166,7 +166,7 @@ NOINLINE B m_unit(B x) {
B xf = asFill(inc(x));
if (noFill(xf)) return m_hunit(x);
FillArr* r = m_arr(fsizeof(FillArr,a,B,1), t_fillarr, 1);
arr_shAlloc((Arr*)r, 0);
arr_shAtm((Arr*)r);
r->fill = xf;
r->a[0] = x;
return taga(r);
@ -190,6 +190,6 @@ NOINLINE B m_atomUnit(B x) {
TyArr* r = m_arr(offsetof(TyArr,a) + sizeof(u64), t, 1);
*((u64*)r->a) = data;
FINISH_OVERALLOC(r, offsetof(TyArr,a)+sz, offsetof(TyArr,a)+sizeof(u64));
arr_shAlloc((Arr*)r, 0);
arr_shAtm((Arr*)r);
return taga(r);
}

View File

@ -96,7 +96,7 @@ static HArr_p m_harrUp(usz ia) {
static B m_hunit(B x) { // consumes
HArr_p r = m_harrUp(1);
arr_shAlloc((Arr*)r.c, 0);
arr_shAtm((Arr*)r.c);
r.a[0] = x;
return r.b;
}

View File

@ -47,21 +47,24 @@ static ShArr* m_shArr(ur r) {
return ((ShArr*)mm_alloc(fsizeof(ShArr, a, usz, r), t_shape));
}
static Arr* arr_shVec(Arr* x) {
SPRNK(x, 1);
FORCE_INLINE Arr* arr_rnk01(Arr* x, ur xr) {
SPRNK(x, xr);
x->sh = &x->ia;
return x;
}
static Arr* arr_shAtm(Arr* x) { return arr_rnk01(x, 0); }
static Arr* arr_shVec(Arr* x) { return arr_rnk01(x, 1); }
static usz* arr_shAlloc(Arr* x, ur r) { // sets rank, allocates & returns shape (or null if r<2); assumes x has rank≤1 (which will be the case for new allocations)
assert(PRNK(x)<=1);
if (r>1) {
if (r<=1) {
arr_rnk01(x, r);
return NULL;
} else {
usz* sh = x->sh = m_shArr(r)->a; // if m_shArr fails, the assumed rank≤1 guarantees the uninitialized x->sh won't break
SPRNK(x,r);
return sh;
}
SPRNK(x,r);
x->sh = &x->ia;
return NULL;
}
static Arr* arr_shSetI(Arr* x, ur r, ShArr* sh) { // set rank and assign and increment shape if needed
SPRNK(x,r);
@ -75,6 +78,12 @@ static Arr* arr_shSetU(Arr* x, ur r, ShArr* sh) { // set rank and assign shape
else x->sh = &x->ia;
return x;
}
static Arr* arr_shSetUG(Arr* x, ur r, ShArr* sh) { // arr_shSetU but guaranteed r>1
assert(r>1);
SPRNK(x,r);
x->sh = sh->a;
return x;
}
static Arr* arr_shCopyUnchecked(Arr* n, B o) {
ur r = SPRNK(n,RNK(o));
if (r<=1) {
@ -86,6 +95,14 @@ static Arr* arr_shCopyUnchecked(Arr* n, B o) {
}
return n;
}
static Arr* arr_shReplace(Arr* x, ur r, ShArr* sh) { // replace x's shape with a new one
usz* prevsh = x->sh;
u8 xr = PRNK(x);
SPRNK(x, r);
x->sh = sh->a;
if (xr>1) decShObj(shObjS(prevsh));
return x;
}
static Arr* arr_shCopy(Arr* n, B o) { // copy shape & rank from o to n
assert(isArr(o));
assert(IA(o)==n->ia);
@ -123,6 +140,7 @@ B bit_sel(B b, B e0, B e1); // consumes b; b must be bitarr; b⊏e0‿e1
Arr* allZeroes(usz ia);
Arr* allOnes(usz ia);
B bit_negate(B x); // consumes
Arr* cpyWithShape(B x); // consumes; returns array with refcount 1 with the same shape as x; to allocate a new shape in its place, the previous one needs to be freed, rank set to 1, and then shape & rank set to the new ones
static B m_hVec1(B a ); // consumes all
static B m_hVec2(B a, B b ); // consumes all