diff --git a/src/builtins/sfns.c b/src/builtins/sfns.c index d994d4b9..7f7e1f54 100644 --- a/src/builtins/sfns.c +++ b/src/builtins/sfns.c @@ -205,8 +205,17 @@ static void shape_c2_prim0(B c) { } #define SHAPE_C2_PRIM1(ID, GET) if (ID!=n_atop & ID!=n_floor & ID!=n_reverse & ID!=n_take) thrF("𝕨⥊𝕩: 𝕨 must consist of natural numbers or ∘ ⌊ ⌽ ↑ (contained %B)", GET) -B reshape_cycle(usz nia, usz xia, B x, ur nr, ShArr* sh); -B shape_c2(B t, B w, B x); +Arr* reshape_cycle(usz nia, usz xia, B x); +SHOULD_INLINE Arr* reshape_unshaped(usz nia, B x) { + if (isArr(x)) { + usz xia = IA(x); + if (nia <= xia) return take_head(nia, x); + else return reshape_cycle(nia, xia, x); + } else { + return reshape_one(nia, x); + } +} + B shape_c2_01(usz wia, B w, B x) { switch (wia) { default: UD; case 0: // ⟨⟩⥊x @@ -214,114 +223,102 @@ B shape_c2_01(usz wia, B w, B x) { if (isAtm(x)) return m_unit(x); if (RARE(IA(x) == 0)) thrM("𝕨⥊𝕩: Empty 𝕩 and non-empty result"); return taga(arr_rnk01(take_impl(1, x), 0)); + case 1: // ⟨x⟩⥊1 w = TO_GET(w,0); // fallthrough + if (q_usz(w)) { + return taga(arr_shVec(reshape_unshaped(o2sG(w), x))); + } case 2: // atom - if (!q_usz(w)) { - shape_c2_prim0(w); - u8 id = RTID(w); - SHAPE_C2_PRIM1(id, w); - decG(w); - return C1(shape, x); - } - return C2(shape, w, x); + shape_c2_prim0(w); + u8 id = RTID(w); + SHAPE_C2_PRIM1(id, w); + decG(w); + return C1(shape, x); } } +B shape_c2_listw(B t, B w, B x); B shape_c2(B t, B w, B x) { - usz nia = 1; - ur nr; - ShArr* sh; if (q_usz(w)) { - nia = o2sG(w); - nr = 1; - sh = NULL; + return taga(arr_shVec(reshape_unshaped(o2sG(w), x))); } else { - if (RARE(isAtm(w))) return shape_c2_01(2, w, x); - if (RNK(w) > 1) thrF("𝕨⥊𝕩: 𝕨 must be a list or unit (%i ≡ =𝕩)", RNK(w)); - usz wia = IA(w); - if (wia <= 1) return shape_c2_01(wia, w, x); - if (wia > UR_MAX) thrF("𝕨⥊𝕩: Result rank too large (%i ≡ ≠𝕨)", wia); - nr = wia; - sh = m_shArr(nr); - - SGetU(w) - i32 unkPos = -1; - i32 unkID ONLY_GCC(=0); - usz xia ONLY_GCC(=0); - bool bad=false, good=false; - for (i32 i = 0; i < nr; i++) { - B c = GetU(w, i); - if (q_usz(c)) { - usz v = o2sG(c); - sh->a[i] = v; - if (RARE(mulOn(nia, v))) bad = true; - good|= v==0; - } else { - shape_c2_prim0(c); - if (unkPos!=-1) thrM("𝕨⥊𝕩: 𝕨 contained multiple computed axes"); - unkPos = i; - unkID = RTID(c); - xia = isArr(x)? IA(x) : 1; - good|= xia==0 | unkID==n_floor; - } - } - if (bad && !good) thrM("𝕨⥊𝕩: 𝕨 too large"); - - if (unkPos!=-1) { - SHAPE_C2_PRIM1(unkID, GetU(w,unkPos)); - if (nia==0) thrM("𝕨⥊𝕩: Can't compute axis when the rest of the shape is empty"); - usz div = xia/nia; - usz mod = xia%nia; - usz item; - bool fill = false; - if (unkID == n_atop) { - if (mod!=0) thrF("𝕨⥊𝕩: Shape must be exact when reshaping with ∘ (%H ≡ ≢𝕩, %s is the product of non-computed axis)", x, nia); - item = div; - } else if (unkID == n_floor) { - item = div; - } else if (unkID == n_reverse) { - item = mod? div+1 : div; - } else if (unkID == n_take) { - item = mod? div+1 : div; - fill = true; - } else UD; - sh->a[unkPos] = item; - nia = uszMul(nia, item); - if (fill) { - if (!isArr(x)) x = m_unit(x); - x = taga(arr_shVec(take_impl(nia, x))); - decG(w); - return truncReshape(x, nia, nia, nr, sh); // could be improved - } - } - decG(w); + return shape_c2_listw(t, w, x); } +} +NOINLINE B shape_c2_listw(B t, B w, B x) { + if (RARE(isAtm(w))) return shape_c2_01(2, w, x); + if (RNK(w) > 1) thrF("𝕨⥊𝕩: 𝕨 must be a list or unit (%i ≡ =𝕩)", RNK(w)); + usz wia = IA(w); + if (wia <= 1) return shape_c2_01(wia, w, x); + if (wia > UR_MAX) thrF("𝕨⥊𝕩: Result rank too large (%i ≡ ≠𝕨)", wia); - Arr* r; - if (isArr(x)) { - usz xia = IA(x); - if (nia <= xia) { - return truncReshape(x, xia, nia, nr, sh); + usz nia = 1; + ur nr = wia; + ShArr* sh = m_shArr(nr); + + SGetU(w) + i32 unkPos = -1; + i32 unkID ONLY_GCC(=0); + usz xia ONLY_GCC(=0); + bool bad=false, good=false; + for (i32 i = 0; i < nr; i++) { + B c = GetU(w, i); + if (q_usz(c)) { + usz v = o2sG(c); + sh->a[i] = v; + if (RARE(mulOn(nia, v))) bad = true; + good|= v==0; } else { - return reshape_cycle(nia, xia, x, nr, sh); + shape_c2_prim0(c); + if (unkPos!=-1) thrM("𝕨⥊𝕩: 𝕨 contained multiple computed axes"); + unkPos = i; + unkID = RTID(c); + xia = isArr(x)? IA(x) : 1; + good|= xia==0 | unkID==n_floor; } - } else { - r = reshape_one(nia, x); } - return taga(arr_shSetUO(r,nr,sh)); + if (bad && !good) thrM("𝕨⥊𝕩: 𝕨 too large"); + + if (unkPos!=-1) { + SHAPE_C2_PRIM1(unkID, GetU(w,unkPos)); + if (nia==0) thrM("𝕨⥊𝕩: Can't compute axis when the rest of the shape is empty"); + usz div = xia/nia; + usz mod = xia%nia; + usz item; + bool fill = false; + if (unkID == n_atop) { + if (mod!=0) thrF("𝕨⥊𝕩: Shape must be exact when reshaping with ∘ (%H ≡ ≢𝕩, %s is the product of non-computed axis)", x, nia); + item = div; + } else if (unkID == n_floor) { + item = div; + } else if (unkID == n_reverse) { + item = mod? div+1 : div; + } else if (unkID == n_take) { + item = mod? div+1 : div; + fill = true; + } else UD; + sh->a[unkPos] = item; + nia = uszMul(nia, item); + if (fill) { + if (!isArr(x)) x = m_unit(x); + x = taga(arr_shVec(take_impl(nia, x))); + decG(w); + return truncReshape(x, nia, nia, nr, sh); // could be improved + } + } + decG(w); + return taga(arr_shSetUO(reshape_unshaped(nia, x), nr, sh)); } -B reshape_cycle(usz nia, usz xia, B x, ur nr, ShArr* sh) { +Arr* reshape_cycle(usz nia, usz xia, B x) { assert(nia > xia); Arr* r; if (xia <= 1) { if (RARE(xia == 0)) thrM("𝕨⥊𝕩: Empty 𝕩 and non-empty result"); x = TO_GET(x, 0); - - r = reshape_one(nia, x); - return taga(arr_shSetUO(r,nr,sh)); + return reshape_one(nia, x); } if (xia <= nia/2) x = squeeze_any(x); @@ -361,14 +358,14 @@ B reshape_cycle(usz nia, usz xia, B x, ur nr, ShArr* sh) { if (bi == 1) { memset(rp, rp[0], bf); bi=bf; } } else { if (TI(x,elType) == el_B) { - MAKE_MUT_INIT(m, nia, el_B); MUTG_INIT(m); + UntaggedArr r = m_barrp_fill(x, nia); i64 div = nia/xia; i64 mod = nia%xia; - for (i64 i = 0; i < div; i++) mut_copyG(m, i*xia, x, 0, xia); - mut_copyG(m, div*xia, x, 0, mod); - B xf = getFillR(x); + for (i64 i = 0; i < div; i++) COPY_TO(r.data, el_B, i*xia, x, 0, xia); + COPY_TO(r.data, el_B, div*xia, x, 0, mod); + NOGC_E; decG(x); - return withFill(taga(arr_shSetUO(mut_fp(m), nr, sh)), xf); + return r.obj; } u8 xk = xl - 3; if (nia >= USZ_MAX) thrOOM(); @@ -392,7 +389,7 @@ B reshape_cycle(usz nia, usz xia, B x, ur nr, ShArr* sh) { u64 e=bi; for (; e+bi<=bf; e+=bi) memcpy(rp+e, rp, bi); if (e