more 𝕨⥊𝕩 refactoring

This commit is contained in:
dzaima 2025-05-29 23:29:37 +03:00
parent dd7c21ed86
commit c0cb1a9f77

View File

@ -205,8 +205,17 @@ static void shape_c2_prim0(B c) {
} }
#define SHAPE_C2_PRIM1(ID, GET) if (ID!=n_atop & ID!=n_floor & ID!=n_reverse & ID!=n_take) thrF("𝕨⥊𝕩: 𝕨 must consist of natural numbers or ∘ ⌊ ⌽ ↑ (contained %B)", GET) #define SHAPE_C2_PRIM1(ID, GET) if (ID!=n_atop & ID!=n_floor & ID!=n_reverse & ID!=n_take) thrF("𝕨⥊𝕩: 𝕨 must consist of natural numbers or ∘ ⌊ ⌽ ↑ (contained %B)", GET)
B reshape_cycle(usz nia, usz xia, B x, ur nr, ShArr* sh); Arr* reshape_cycle(usz nia, usz xia, B x);
B shape_c2(B t, B w, B x); SHOULD_INLINE Arr* reshape_unshaped(usz nia, B x) {
if (isArr(x)) {
usz xia = IA(x);
if (nia <= xia) return take_head(nia, x);
else return reshape_cycle(nia, xia, x);
} else {
return reshape_one(nia, x);
}
}
B shape_c2_01(usz wia, B w, B x) { B shape_c2_01(usz wia, B w, B x) {
switch (wia) { default: UD; switch (wia) { default: UD;
case 0: // ⟨⟩⥊x case 0: // ⟨⟩⥊x
@ -214,114 +223,102 @@ B shape_c2_01(usz wia, B w, B x) {
if (isAtm(x)) return m_unit(x); if (isAtm(x)) return m_unit(x);
if (RARE(IA(x) == 0)) thrM("𝕨⥊𝕩: Empty 𝕩 and non-empty result"); if (RARE(IA(x) == 0)) thrM("𝕨⥊𝕩: Empty 𝕩 and non-empty result");
return taga(arr_rnk01(take_impl(1, x), 0)); return taga(arr_rnk01(take_impl(1, x), 0));
case 1: // ⟨x⟩⥊1 case 1: // ⟨x⟩⥊1
w = TO_GET(w,0); w = TO_GET(w,0);
// fallthrough // fallthrough
if (q_usz(w)) {
return taga(arr_shVec(reshape_unshaped(o2sG(w), x)));
}
case 2: // atom case 2: // atom
if (!q_usz(w)) { shape_c2_prim0(w);
shape_c2_prim0(w); u8 id = RTID(w);
u8 id = RTID(w); SHAPE_C2_PRIM1(id, w);
SHAPE_C2_PRIM1(id, w); decG(w);
decG(w); return C1(shape, x);
return C1(shape, x);
}
return C2(shape, w, x);
} }
} }
B shape_c2_listw(B t, B w, B x);
B shape_c2(B t, B w, B x) { B shape_c2(B t, B w, B x) {
usz nia = 1;
ur nr;
ShArr* sh;
if (q_usz(w)) { if (q_usz(w)) {
nia = o2sG(w); return taga(arr_shVec(reshape_unshaped(o2sG(w), x)));
nr = 1;
sh = NULL;
} else { } else {
if (RARE(isAtm(w))) return shape_c2_01(2, w, x); return shape_c2_listw(t, w, x);
if (RNK(w) > 1) thrF("𝕨⥊𝕩: 𝕨 must be a list or unit (%i ≡ =𝕩)", RNK(w));
usz wia = IA(w);
if (wia <= 1) return shape_c2_01(wia, w, x);
if (wia > UR_MAX) thrF("𝕨⥊𝕩: Result rank too large (%i ≡ ≠𝕨)", wia);
nr = wia;
sh = m_shArr(nr);
SGetU(w)
i32 unkPos = -1;
i32 unkID ONLY_GCC(=0);
usz xia ONLY_GCC(=0);
bool bad=false, good=false;
for (i32 i = 0; i < nr; i++) {
B c = GetU(w, i);
if (q_usz(c)) {
usz v = o2sG(c);
sh->a[i] = v;
if (RARE(mulOn(nia, v))) bad = true;
good|= v==0;
} else {
shape_c2_prim0(c);
if (unkPos!=-1) thrM("𝕨⥊𝕩: 𝕨 contained multiple computed axes");
unkPos = i;
unkID = RTID(c);
xia = isArr(x)? IA(x) : 1;
good|= xia==0 | unkID==n_floor;
}
}
if (bad && !good) thrM("𝕨⥊𝕩: 𝕨 too large");
if (unkPos!=-1) {
SHAPE_C2_PRIM1(unkID, GetU(w,unkPos));
if (nia==0) thrM("𝕨⥊𝕩: Can't compute axis when the rest of the shape is empty");
usz div = xia/nia;
usz mod = xia%nia;
usz item;
bool fill = false;
if (unkID == n_atop) {
if (mod!=0) thrF("𝕨⥊𝕩: Shape must be exact when reshaping with ∘ (%H ≡ ≢𝕩, %s is the product of non-computed axis)", x, nia);
item = div;
} else if (unkID == n_floor) {
item = div;
} else if (unkID == n_reverse) {
item = mod? div+1 : div;
} else if (unkID == n_take) {
item = mod? div+1 : div;
fill = true;
} else UD;
sh->a[unkPos] = item;
nia = uszMul(nia, item);
if (fill) {
if (!isArr(x)) x = m_unit(x);
x = taga(arr_shVec(take_impl(nia, x)));
decG(w);
return truncReshape(x, nia, nia, nr, sh); // could be improved
}
}
decG(w);
} }
}
NOINLINE B shape_c2_listw(B t, B w, B x) {
if (RARE(isAtm(w))) return shape_c2_01(2, w, x);
if (RNK(w) > 1) thrF("𝕨⥊𝕩: 𝕨 must be a list or unit (%i ≡ =𝕩)", RNK(w));
usz wia = IA(w);
if (wia <= 1) return shape_c2_01(wia, w, x);
if (wia > UR_MAX) thrF("𝕨⥊𝕩: Result rank too large (%i ≡ ≠𝕨)", wia);
Arr* r; usz nia = 1;
if (isArr(x)) { ur nr = wia;
usz xia = IA(x); ShArr* sh = m_shArr(nr);
if (nia <= xia) {
return truncReshape(x, xia, nia, nr, sh); SGetU(w)
i32 unkPos = -1;
i32 unkID ONLY_GCC(=0);
usz xia ONLY_GCC(=0);
bool bad=false, good=false;
for (i32 i = 0; i < nr; i++) {
B c = GetU(w, i);
if (q_usz(c)) {
usz v = o2sG(c);
sh->a[i] = v;
if (RARE(mulOn(nia, v))) bad = true;
good|= v==0;
} else { } else {
return reshape_cycle(nia, xia, x, nr, sh); shape_c2_prim0(c);
if (unkPos!=-1) thrM("𝕨⥊𝕩: 𝕨 contained multiple computed axes");
unkPos = i;
unkID = RTID(c);
xia = isArr(x)? IA(x) : 1;
good|= xia==0 | unkID==n_floor;
} }
} else {
r = reshape_one(nia, x);
} }
return taga(arr_shSetUO(r,nr,sh)); if (bad && !good) thrM("𝕨⥊𝕩: 𝕨 too large");
if (unkPos!=-1) {
SHAPE_C2_PRIM1(unkID, GetU(w,unkPos));
if (nia==0) thrM("𝕨⥊𝕩: Can't compute axis when the rest of the shape is empty");
usz div = xia/nia;
usz mod = xia%nia;
usz item;
bool fill = false;
if (unkID == n_atop) {
if (mod!=0) thrF("𝕨⥊𝕩: Shape must be exact when reshaping with ∘ (%H ≡ ≢𝕩, %s is the product of non-computed axis)", x, nia);
item = div;
} else if (unkID == n_floor) {
item = div;
} else if (unkID == n_reverse) {
item = mod? div+1 : div;
} else if (unkID == n_take) {
item = mod? div+1 : div;
fill = true;
} else UD;
sh->a[unkPos] = item;
nia = uszMul(nia, item);
if (fill) {
if (!isArr(x)) x = m_unit(x);
x = taga(arr_shVec(take_impl(nia, x)));
decG(w);
return truncReshape(x, nia, nia, nr, sh); // could be improved
}
}
decG(w);
return taga(arr_shSetUO(reshape_unshaped(nia, x), nr, sh));
} }
B reshape_cycle(usz nia, usz xia, B x, ur nr, ShArr* sh) { Arr* reshape_cycle(usz nia, usz xia, B x) {
assert(nia > xia); assert(nia > xia);
Arr* r; Arr* r;
if (xia <= 1) { if (xia <= 1) {
if (RARE(xia == 0)) thrM("𝕨⥊𝕩: Empty 𝕩 and non-empty result"); if (RARE(xia == 0)) thrM("𝕨⥊𝕩: Empty 𝕩 and non-empty result");
x = TO_GET(x, 0); x = TO_GET(x, 0);
return reshape_one(nia, x);
r = reshape_one(nia, x);
return taga(arr_shSetUO(r,nr,sh));
} }
if (xia <= nia/2) x = squeeze_any(x); if (xia <= nia/2) x = squeeze_any(x);
@ -361,14 +358,14 @@ B reshape_cycle(usz nia, usz xia, B x, ur nr, ShArr* sh) {
if (bi == 1) { memset(rp, rp[0], bf); bi=bf; } if (bi == 1) { memset(rp, rp[0], bf); bi=bf; }
} else { } else {
if (TI(x,elType) == el_B) { if (TI(x,elType) == el_B) {
MAKE_MUT_INIT(m, nia, el_B); MUTG_INIT(m); UntaggedArr r = m_barrp_fill(x, nia);
i64 div = nia/xia; i64 div = nia/xia;
i64 mod = nia%xia; i64 mod = nia%xia;
for (i64 i = 0; i < div; i++) mut_copyG(m, i*xia, x, 0, xia); for (i64 i = 0; i < div; i++) COPY_TO(r.data, el_B, i*xia, x, 0, xia);
mut_copyG(m, div*xia, x, 0, mod); COPY_TO(r.data, el_B, div*xia, x, 0, mod);
B xf = getFillR(x); NOGC_E;
decG(x); decG(x);
return withFill(taga(arr_shSetUO(mut_fp(m), nr, sh)), xf); return r.obj;
} }
u8 xk = xl - 3; u8 xk = xl - 3;
if (nia >= USZ_MAX) thrOOM(); if (nia >= USZ_MAX) thrOOM();
@ -392,7 +389,7 @@ B reshape_cycle(usz nia, usz xia, B x, ur nr, ShArr* sh) {
u64 e=bi; for (; e+bi<=bf; e+=bi) memcpy(rp+e, rp, bi); u64 e=bi; for (; e+bi<=bf; e+=bi) memcpy(rp+e, rp, bi);
if (e<bf) memcpy(rp+e, rp, bf-e); if (e<bf) memcpy(rp+e, rp, bf-e);
} }
return taga(arr_shSetUO(r,nr,sh)); return r;
} }
B pick_c1(B t, B x) { B pick_c1(B t, B x) {