From 3505e1515e8005e3c2794d00c396ea2c063d3967 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Fri, 16 Sep 2022 11:42:55 -0400 Subject: [PATCH] =?UTF-8?q?Extend=20Group=20where=20=F0=9D=95=A8=20is=20an?= =?UTF-8?q?=20integer=20list=20to=20handle=20higher-rank=20=F0=9D=95=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/builtins/group.c | 162 +++++++++++++++++++++++++------------------ 1 file changed, 93 insertions(+), 69 deletions(-) diff --git a/src/builtins/group.c b/src/builtins/group.c index 7cca8b1a..aa84021c 100644 --- a/src/builtins/group.c +++ b/src/builtins/group.c @@ -3,6 +3,7 @@ #include "../utils/mut.h" #include "../builtins.h" +extern B ud_c1(B, B); extern B ne_c2(B, B, B); extern B slash_c1(B, B); extern B slash_c2(B, B, B); @@ -11,11 +12,29 @@ extern B take_c2(B, B, B); extern B drop_c2(B, B, B); extern B join_c2(B, B, B); +static Arr* arr_shChangeLen(Arr* a, ur r, usz* xsh, usz len) { + assert(r > 1); + usz* sh = a->sh = m_shArr(r)->a; + SPRNK(a,r); + sh[0] = len; + shcpy(sh+1, xsh+1, r-1); + return a; +} +static B m_shChangeLen(u8 xt, ur xr, usz* xsh, usz l, usz cw, usz csz) { + return taga(arr_shChangeLen(m_arr(offsetof(TyArr, a)+l*cw, xt, l*csz), xr, xsh, l)); +} +static void allocGroups(B* rp, usz ria, B z, u8 xt, ur xr, usz* xsh, i32* len, usz width, usz csz) { + if (xr==1) for (usz j = 0; j < ria; j++) { usz l=len[j]; if (!l) rp[j] = inc(z); else m_tyarrv(rp+j, width, l, xt); } + else for (usz j = 0; j < ria; j++) { usz l=len[j]; rp[j] = !l ? inc(z) : m_shChangeLen(xt, xr, xsh, l, width, csz); } +} + extern B rt_group; B group_c2(B t, B w, B x) { - if (isArr(w)&isArr(x) && RNK(w)==1 && RNK(x)==1 && depth(w)==1) { + ur xr = RNK(x); + if (isArr(w)&isArr(x) && RNK(w)==1 && xr>=1 && depth(w)==1) { usz wia = IA(w); - usz xia = IA(x); + usz* xsh = SH(x); + usz xia = *xsh; if (wia-xia > 1) thrF("โŠ”: โ‰ ๐•จ must be either โ‰ ๐•ฉ or one bigger (%sโ‰กโ‰ ๐•จ, %sโ‰กโ‰ ๐•ฉ)", wia, xia); u8 we = TI(w,elType); if (elInt(we)) { @@ -48,7 +67,8 @@ B group_c2(B t, B w, B x) { for (usz i = 0; i < ria; i++) rp[i] = m_f64(0); // don't break if allocation errors B xf = getFillQ(x); - Arr* rf = arr_shVec(m_fillarrp(0)); fillarr_setFill(rf, m_f64(0)); + Arr* rf = m_fillarrp(0); if (xr==1) arr_shVec(rf); else arr_shChangeLen(rf, xr, xsh, 0); + fillarr_setFill(rf, m_f64(0)); B z = taga(rf); fillarr_setFill(r, z); @@ -60,10 +80,19 @@ B group_c2(B t, B w, B x) { goto intvec_ret; } - u8 xe = TI(x,elType); - u8 width = elWidth(xe); - u64 xw; - if (xia>64 && (xw=(u64)xia*width)<=I32_MAX && change1)) { + width *= csz = arr_csz(x); + xl += CTZ(csz); + if (bits && xl>=3) { bits=1; width>>=3; } + if ((csz & (csz-1)) || xl>7) xl = 7; + } + if (xia>64 && notB && !bits && change<(xia*width)/32) { #define C1(F,X ) F##_c1(m_f64(0),X ) #define C2(F,X,W) F##_c2(m_f64(0),X,W) if (wia>xia) w = C2(take, m_f64(xia), w); @@ -86,86 +115,83 @@ B group_c2(B t, B w, B x) { for (usz i = 0; i < wia; i++) len[wp[i]]+=ip[i]; void* xp = tyany_ptr(x); - u8 xt = el2t(xe); - for (usz j = 0; j < ria; j++) { - usz l=len[j]; - if (!l) rp[j]=inc(z); else m_tyarrv(rp+j, width, l, xt); - } - for (usz i=0, k=i0*width; i32 && neg>xia/4+xia/8) { - if (wia>xia) w = take_c2(m_f64(0), m_f64(xia), w); - B m = ne_c2(m_f64(0), m_f64(-1), inc(w)); - w = slash_c2(m_f64(0), inc(m), w); - x = slash_c2(m_f64(0), m, x); xia = IA(x); - neg = 0; - } - if (TI(w,elType)!=el_i32) w = taga(cpyI32Arr(w)); - i32* wp = i32any_ptr(w); - for (usz i = 0; i < ria; i++) len[i] = pos[i] = 0; - for (usz i = 0; i < xia; i++) len[wp[i]]++; // overallocation makes this safe after n<-1 check - - switch (xe) { default: UD; - case el_i8: case el_c8: - case el_i16: case el_c16: - case el_i32: case el_c32: case el_f64: { - void* xp = tyany_ptr(x); - u8 xt = el2t(xe); - if (sort) { - for (usz j=0, i=neg*width; j=0) ((u8* )tyarr_ptr(rp[n]))[pos[n]++] = ((u8* )xp)[i]; } break; - case 2: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u16*)tyarr_ptr(rp[n]))[pos[n]++] = ((u16*)xp)[i]; } break; - case 4: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u32*)tyarr_ptr(rp[n]))[pos[n]++] = ((u32*)xp)[i]; } break; - case 8: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((f64*)tyarr_ptr(rp[n]))[pos[n]++] = ((f64*)xp)[i]; } break; - } - break; + decG(ind); + } else { + if (xia>32 && neg>xia/4+xia/8) { + if (wia>xia) w = take_c2(m_f64(0), m_f64(xia), w); + B m = ne_c2(m_f64(0), m_f64(-1), inc(w)); + w = slash_c2(m_f64(0), inc(m), w); + x = slash_c2(m_f64(0), m, x); xia = IA(x); + neg = 0; } - case el_bit: case el_B: { + if (TI(w,elType)!=el_i32) w = taga(cpyI32Arr(w)); + i32* wp = i32any_ptr(w); + for (usz i = 0; i < ria; i++) len[i] = pos[i] = 0; + for (usz i = 0; i < xia; i++) len[wp[i]]++; // overallocation makes this safe after n<-1 check + + u8 xk = xl - 3; + if (notB && !bits && sort) { + void* xp = tyany_ptr(x); + u64 i=neg*width; for (usz j=0; j=0) ((u8* )tyarr_ptr(rp[n]))[pos[n]++] = ((u8* )xp)[i]; } break; + case 1: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u16*)tyarr_ptr(rp[n]))[pos[n]++] = ((u16*)xp)[i]; } break; + case 2: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u32*)tyarr_ptr(rp[n]))[pos[n]++] = ((u32*)xp)[i]; } break; + case 3: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((f64*)tyarr_ptr(rp[n]))[pos[n]++] = ((f64*)xp)[i]; } break; + } + } else { for (usz i = 0; i < ria; i++) { - Arr* c = m_fillarrp(len[i]); + usz l = len[i]; + Arr* c = m_fillarrp(l*csz); c->ia = 0; fillarr_setFill(c, inc(xf)); - arr_shVec(c); + if (xr==1) arr_shVec(c); else arr_shChangeLen(c, xr, xsh, l); rp[i] = taga(c); } SLOW2("๐•จโŠ”๐•ฉ", w, x); SGet(x) - for (usz i = 0; i < xia; i++) { - i32 n = wp[i]; - if (n>=0) fillarr_ptr(a(rp[n]))[pos[n]++] = Get(x, i); + if (csz == 1) { + for (usz i = 0; i < xia; i++) { + i32 n = wp[i]; + if (n>=0) fillarr_ptr(a(rp[n]))[pos[n]++] = Get(x, i); + } + } else { + for (usz i = 0; i < xia; i++) { + i32 n = wp[i]; + if (n<0) continue; + usz p = (pos[n]++)*csz; + B* rnp = fillarr_ptr(a(rp[n])) + p; + for (usz j = 0; j < csz; j++) rnp[j] = Get(x, i*csz + j); + } } - for (usz i = 0; i < ria; i++) a(rp[i])->ia = len[i]; - break; + for (usz i = 0; i < ria; i++) a(rp[i])->ia = len[i]*csz; } } intvec_ret: fillarr_setFill(rf, xf); decG(w); decG(x); TFREE(pos); return taga(r); - } else { + } else if (xr==1) { SLOW2("๐•จโŠ”๐•ฉ", w, x); SGetU(w) i64 ria = wia==xia? 0 : o2i64(GetU(w, xia)); @@ -213,7 +239,6 @@ B group_c2(B t, B w, B x) { base: return c2(rt_group, w, x); } -B ud_c1(B, B); B group_c1(B t, B x) { if (isArr(x) && RNK(x)==1 && TI(x,arrD1)) { usz ia = IA(x); @@ -222,4 +247,3 @@ B group_c1(B t, B x) { } return c1(rt_group, x); } -