// Group and Group Indices (⊔)

// Group Indices: calls 𝕩⊔↕𝕩 for rank-1 flat 𝕩, otherwise self-hosted

// Group: native code for rank-1 𝕨 only, optimizations for integers
// SHOULD squeeze 𝕨
// All statistics computed in the initial pass that finds ⌈´𝕨
// If 𝕨 is boolean, compute from 𝕨¬⊸/𝕩 and 𝕨/𝕩
// COULD handle small-range 𝕨 with equals-replicate
// If +´»⊸≠𝕨 is small, process in chunks as a separate case
// If +´𝕨<¯1 is large, filter out ¯1s.
// COULD recompute statistics, may have enabled chunked or sorted code
// If ∧´1↓»⊸<𝕨, that is, ∧⊸≡𝕨, each result array is a slice of 𝕩
// COULD use slice types; seems dangerous--when will they be freed?
// Remaining cases copy cells from 𝕩 individually
// Converts 𝕨 to i32, COULD handle smaller types
// CPU-sized cells handled quickly, 1-bit with bitp_get/set
// SHOULD use memcpy and bit_cpy for other sizes
// TRIED separating neg>0 and neg==0 loops, no effect

// NOTE(review): this copy of the file arrived with its line breaks collapsed and with
// non-ASCII glyphs mis-encoded. Comment glyphs have been repaired; string literals are
// runtime text and are left exactly as found. Two spans inside group_simple have lost
// text and are flagged in place; restore them from a clean copy before building.

#include "../core.h"
#include "../utils/talloc.h"
#include "../utils/calls.h"
#include "../builtins.h"
#include "../utils/mut.h"

extern B ud_c1(B, B);
extern B ne_c2(B, B, B);
extern B slash_c1(B, B);
extern B slash_c2(B, B, B);
extern B select_c2(B, B, B);
extern B take_c2(B, B, B);

// Gives array `a` a freshly allocated rank-`r` shape: leading length `len`, followed by
// the trailing r-1 dimensions copied from xsh[1..r). Requires r>1 (asserted). Returns `a`.
static Arr* arr_shChangeLen(Arr* a, ur r, usz* xsh, usz len) {
  assert(r > 1);
  usz* sh = a->sh = m_shArr(r)->a;
  SPRNK(a,r);
  sh[0] = len;
  shcpy(sh+1, xsh+1, r-1);
  return a;
}

// Allocates an array of type `xt` with `l` cells of `cw` units of storage each (l*csz
// elements in total) and the shape of xsh with its leading length replaced by `l`.
static B m_shChangeLen(u8 xt, ur xr, usz* xsh, usz l, usz cw, usz csz) {
  return taga(arr_shChangeLen(m_arr(offsetof(TyArr, a)+l*cw, xt, l*csz), xr, xsh, l));
}

// Fills rp[0..ria) with result groups for non-bit typed data: a group of length 0 shares
// the empty array `z` (one extra reference each); any other group gets a new array of
// type `xt` holding len[j] cells. The rank-1 and higher-rank cases use different allocators.
static void allocGroups(B* rp, usz ria, B z, u8 xt, ur xr, usz* xsh, i32* len, usz width, usz csz) {
  if (xr==1) for (usz j = 0; j < ria; j++) {
    usz l=len[j];
    if (!l) rp[j] = incG(z);
    else m_tyarrv(rp+j, width, l, xt);
  }
  else for (usz j = 0; j < ria; j++) {
    usz l=len[j];
    rp[j] = !l ? incG(z) : m_shChangeLen(xt, xr, xsh, l, width, csz);
  }
}

// Allocates a bit array of `ia` elements via m_arr; the callers below attach the shape
// themselves (presumably what "nop" refers to - TODO confirm).
static Arr* m_bitarr_nop(usz ia) {
  return m_arr(BITARR_SZ(ia), t_bitarr, ia);
}

// Bit-array counterpart of allocGroups. Here `width` is the number of bits per cell, so a
// higher-rank group of l cells is allocated with l*width elements.
static void allocBitGroups(B* rp, usz ria, B z, ur xr, usz* xsh, i32* len, usz width) {
  if (xr==1) for (usz j = 0; j < ria; j++) {
    usz l=len[j];
    rp[j] = !l ? incG(z) : taga(arr_shVec(m_bitarr_nop(l)));
  }
  else for (usz j = 0; j < ria; j++) {
    usz l=len[j];
    rp[j] = !l ? incG(z) : taga(arr_shChangeLen(m_bitarr_nop(l*width), xr, xsh, l));
  }
}

// Integer list w
// Group for a rank-1 𝕨 with integer element type `we` (bit, i8, i16 or i32).
//   w, x : arguments; the dec_ret exit releases one reference to each
//   xr   : rank of x;  xsh: shape of x;  xn: leading length of x (xsh[0] per the caller)
//   wia  : element count of w; the caller has already rejected wia-xn > 1
// Returns a fill array of `ria` groups whose fill is an empty array shaped like a group.
// Throws if 𝕨 has an element below -1, and thrOOM if the group count exceeds USZ_MAX.
static B group_simple(B w, B x, ur xr, usz wia, usz xn, usz* xsh, u8 we) {
  i64 ria = 0;                     // number of result groups
  bool bad = false, sort = true;   // bad: some element < -1; sort: first xn elements are non-decreasing
  usz neg = 0, change = 0;         // neg: count of -1 elements; change: count of positions differing from the previous one
  void *wp0 = tyany_ptr(w);
  // One pass over the first xn elements of 𝕨 gathers max, bad, neg, sort and change
  // (prev starts at -1). A trailing extra element (wia>xn) is a minimum group count.
  // The result is ria = max(max+1, trailing element).
  #define CASE(T) case el_##T: { \
    T max = -1, prev = -1; \
    for (usz i = 0; i < xn; i++) { \
      T n = ((T*)wp0)[i]; \
      if (n>max) max = n; \
      bad |= n < -1; \
      neg += n == -1; \
      sort &= prev <= n; \
      change += prev != n; \
      prev = n; \
    } \
    if (wia>xn) { ria=((T*)wp0)[xn]; bad|=ria<-1; } \
    i64 m=(i64)max+1; if (m>ria) ria=m; \
    break; }
  switch (we) { default:UD;
    CASE(i8) CASE(i16) CASE(i32)
    // Boolean w is special-cased before we would check sort or change
    case el_bit: ria = xn? 1+bit_has(wp0,xn,1) : wia? bitp_get(wp0,0) : 0; break;
  }
  #undef CASE
  // NOTE(review): the non-ASCII text in this and the other string literals of this file
  // looks mis-encoded; it is runtime text, so it is reproduced here unchanged.
  if (bad) thrM("βŠ”: 𝕨 can't contain elements less than Β―1");
  if (ria > (i64)(USZ_MAX)) thrOOM();

  Arr* r = m_fillarr0p(ria);
  B* rp = fillarr_ptr(r);
  B xf = getFillQ(x);
  // z: empty array with the shape of a group (leading length 0); used as the fill of the
  // result and shared by every empty group
  Arr* rf = m_fillarrpEmpty(xf);
  if (xr==1) arr_shVec(rf); else arr_shChangeLen(rf, xr, xsh, 0);
  B z = taga(rf);
  fillarr_setFill(r, z);
  if (ria <= 1) {
    if (ria == 0) goto dec_ret; // Needed so wia>0
    if (neg == 0) { rp[0]=inc(x); goto dec_ret; } // every element is 0, so the only group is x itself
    // else, 𝕨 is a mix of 0 and ¯1 (and maybe trailing 1)
  }
  if (we==el_bit) {
    // Boolean 𝕨: group 1 is 𝕨/𝕩 and group 0 is (¬𝕨)/𝕩, after trimming a trailing extra element
    assert(ria == 2);
    if (wia>xn) w = C2(take, m_f64(xn), w);
    rp[1] = C2(slash, inc(w), inc(x));
    rp[0] = C2(slash, bit_negate(w), x);
    return taga(r);
  }
  // Needed to make sure wia>0 for ip[wia-1] below
  if (neg==xn) { FILL_TO(rp, el_B, 0, z, ria); goto dec_ret; }

  // One allocation holds both tables: pos = [0, ria], len = pos+ria+1, so len[-1] aliases
  // the spare slot pos[ria]. The counting loop further down relies on that for -1 entries.
  TALLOC(i32, pos, 2*ria+1);
  i32* len = pos+ria+1;

  bool notB = TI(x,elType) != el_B;
  u8 xt = arrNewType(TY(x));
  u8 xl = arrTypeBitsLog(TY(x));
  bool bits = xl == 0;
  u64 width = bits ? 1 : 1<<(xl-3); // cell width in bits if bits==1, bytes otherwise
  usz csz = 1;
  if (RARE(xr>1)) {
    // Higher-rank x: scale the width by the cell size and fold the cell size into xl
    width *= csz = arr_csz(x);
    usz cs = csz | (csz==0);
    xl += CTZ(cs);
    if (bits && xl>=3) { bits=0; width>>=3; } // a bit cell of at least 8 bits is measured in bytes from here on
    if ((cs & (cs-1)) || xl>7) xl = 7;        // cs not a power of two, or too large: xl=7 makes xk exceed 3 below
  }

  // Few changes in 𝕨: move in chunks
  if (xn>64 && notB && change<(xn*width)/32) {
    u64* mp; B m = m_bitarrv(&mp, xn);
    u8* wp0 = tyany_ptr(w);
    we = TI(w,elType);
    CMP_AA_IMM(ne, we, mp, wp0-elWidth(we), wp0, xn);
    bitp_set(mp, 0, -1!=o2fG(IGetU(w,0)));

    B ind = C1(slash, m);
    w = C2(select, incG(ind), w);
    ind = toI32Any(ind); w = toI32Any(w);
    wia = IA(ind);

    i32* ip = i32any_ptr(ind);
    i32* wp = i32any_ptr(w);
    usz i0 = ip[0];
    // NOTE(review): SOURCE IS DAMAGED HERE. The text from the `<` of this loop's condition
    // up to a later `>` appears to have been lost, taking with it the remainder of the
    // chunked case and the head of the following `if`; the brace opened by the chunked
    // `if` above is never closed. Tokens are reproduced exactly as found. The surviving
    // tail drops a trailing extra element of w, masks out the -1 entries of w from both
    // w and x, refreshes xn and resets neg.
    for (usz i=0; i32 && neg>(bits?0:xn/4)+xn/8) {
    if (wia>xn) w = C2(take, m_f64(xn), w);
    B m = C2(ne, m_f64(-1), inc(w));
    w = C2(slash, inc(m), w);
    x = C2(slash, m, x);
    xn = *SH(x);
    neg = 0;
  }

  w = toI32Any(w);
  i32* wp = i32any_ptr(w);
  for (usz i = 0; i < ria; i++) len[i] = pos[i] = 0;
  for (usz i = 0; i < xn; i++) len[wp[i]]++; // overallocation makes this safe after n<-1 check

  u8 xk = xl - 3; // xk = xl-3; the byte-cell switch below is indexed by it (its guard is in the damaged span)
  if (notB && csz==0) { // Empty cells, no movement needed
    allocBitGroups(rp, ria, z, xr, xsh, len, width);
  } else if (notB && sort) { // Sorted 𝕨, that is, partition 𝕩
    void* xp = tyany_ptr(x);
    u64 i=neg*width;
    // NOTE(review): SOURCE IS DAMAGED HERE. The body of GROUP_SORT, its uses, and the
    // start of the following branch (up to the middle of `case 0:`) appear to have been
    // lost between a `<` and a later `>`. Tokens are reproduced exactly as found. What
    // survives is the tail of a switch whose cases copy one 1-, 2-, 4- or 8-byte cell
    // per element of x into group wp[i], skipping negative entries.
    #define GROUP_SORT(CPY, ALLOC) \
      for (usz j=0; j=0) ((u8* )tyarr_ptr(rp[n]))[pos[n]++] = ((u8* )xp)[i]; } break;
      case 1: for (usz i = 0; i < xn; i++) { i32 n = wp[i]; if (n>=0) ((u16*)tyarr_ptr(rp[n]))[pos[n]++] = ((u16*)xp)[i]; } break;
      case 2: for (usz i = 0; i < xn; i++) { i32 n = wp[i]; if (n>=0) ((u32*)tyarr_ptr(rp[n]))[pos[n]++] = ((u32*)xp)[i]; } break;
      case 3: for (usz i = 0; i < xn; i++) { i32 n = wp[i]; if (n>=0) ((u64*)tyarr_ptr(rp[n]))[pos[n]++] = ((u64*)xp)[i]; } break;
    }
  } else if (xl == 0) { // 1-bit cells
    u64* xp = bitarr_ptr(x);
    allocBitGroups(rp, ria, z, xr, xsh, len, width);
    for (usz i = 0; i < xn; i++) {
      bool b = bitp_get(xp,i);
      i32 n = wp[i]; if (n>=0) bitp_set(bitarr_ptr(rp[n]), pos[n]++, b);
    }
  } else { // Generic case
    // Allocate every group as a fill array of len[i]*csz elements, but publish it with
    // ia=0 until its elements have been written.
    for (usz i = 0; i < ria; i++) {
      usz l = len[i];
      Arr* c = m_fillarrp(l*csz);
      c->ia = 0;
      fillarr_setFill(c, inc(xf));
      NOGC_E; // ia=0 means that this is a "safe" array
      if (xr==1) arr_shVec(c); else arr_shChangeLen(c, xr, xsh, l);
      rp[i] = taga(c);
    }
    if (csz==0) goto done;
    SLOW2("π•¨βŠ”π•©", w, x);
    SGet(x)
    NOGC_S; // ia==0 of the elements means they're in a sort of invalid state; though it should be fine as `x` references all the items that may be in them, so this NOGC isn't strictly necessary
    if (csz == 1) {
      for (usz i = 0; i < xn; i++) {
        i32 n = wp[i];
        if (n>=0) fillarr_ptr(a(rp[n]))[pos[n]++] = Get(x, i);
      }
    } else {
      // Multi-element cells: pos[n] counts cells, so the element offset is pos[n]*csz
      for (usz i = 0; i < xn; i++) {
        i32 n = wp[i];
        if (n<0) continue;
        usz p = (pos[n]++)*csz;
        B* rnp = fillarr_ptr(a(rp[n])) + p;
        for (usz j = 0; j < csz; j++) rnp[j] = Get(x, i*csz + j);
      }
    }
    // All elements are in place: restore the real element counts
    for (usz i = 0; i < ria; i++) a(rp[i])->ia = len[i]*csz;
    NOGC_E;
  }
  done:
  TFREE(pos);
  dec_ret:
  decG(w); decG(x);
  return taga(r);
}

extern B rt_group;
// Dyadic Group (𝕨⊔𝕩). `t` is not used in this function.
// Throws if x is not an array. When w is a rank-1 array with depth(w)==1 and x has rank
// of at least 1:
//   - integer element type: handled by group_simple;
//   - otherwise, rank-1 x: handled by the element-by-element loop below, unless some
//     element of w fails q_i64, in which case control jumps to `base`.
// Every other input goes to the runtime implementation through c2rt(group, ...).
B group_c2(B t, B w, B x) {
  if (!isArr(x)) thrM("βŠ”: 𝕩 must be an array");
  ur xr = RNK(x);
  if (isArr(w) && RNK(w)==1 && xr>=1 && depth(w)==1) {
    usz wia = IA(w);
    usz* xsh = SH(x);
    usz xn = *xsh;
    // Accepts wia==xn and wia==xn+1 only (presumably usz is unsigned, so wia<xn wraps
    // around and is rejected as well - TODO confirm)
    if (wia-xn > 1) thrF("βŠ”: ≠𝕨 must be either ≠𝕩 or one bigger (%s≑≠𝕨, %s≑≠𝕩)", wia, xn);
    u8 we = TI(w,elType);
    if (elInt(we)) {
      return group_simple(w, x, xr, wia, xn, xsh, we);
    } else if (xr==1) {
      SLOW2("π•¨βŠ”π•©", w, x);
      SGetU(w)
      // Start from the trailing extra element (minimum group count), if there is one
      i64 ria = wia==xn? 0 : o2i64(GetU(w, xn));
      if (ria<0) {
        if (ria<-1) thrM("βŠ”: 𝕨 can't contain elements less than Β―1");
        ria = 0;
      }
      ria--; // from here until the ria++ below, ria is the largest group index seen so far
      for (usz i = 0; i < xn; i++) {
        B cw = GetU(w, i);
        if (!q_i64(cw)) goto base; // nothing has been allocated yet, so falling back is safe
        i64 c = o2i64G(cw);
        if (c>ria) ria = c;
        if (c<-1) thrM("βŠ”: 𝕨 can't contain elements less than Β―1");
      }
      if (ria > (i64)(USZ_MAX-1)) thrOOM();
      ria++;
      // len points one past the start of lenO, so the len[-1] written for -1 entries lands
      // in lenO[0]. That slot is not zeroed below, and nothing in this function reads it.
      TALLOC(i32, lenO, ria+1);
      i32* len = lenO+1;
      TALLOC(i32, pos, ria);
      for (usz i = 0; i < ria; i++) len[i] = pos[i] = 0;
      for (usz i = 0; i < xn; i++) len[o2i64G(GetU(w, i))]++;

      Arr* r = m_fillarr0p(ria);
      B xf = getFillQ(x);
      B* rp = fillarr_ptr(r);
      for (usz i = 0; i < ria; i++) {
        Arr* c = m_fillarrp(len[i]);
        c->ia = 0;
        fillarr_setFill(c, inc(xf));
        NOGC_E; // see comments in group_simple
        arr_shVec(c);
        rp[i] = taga(c);
      }
      fillarr_setFill(r, taga(arr_shVec(m_fillarrpEmpty(xf))));
      SGet(x)
      NOGC_S;
      // Append each element of x to its group; entries with index -1 are skipped
      for (usz i = 0; i < xn; i++) {
        i64 n = o2i64G(GetU(w, i));
        if (n>=0) fillarr_ptr(a(rp[n]))[pos[n]++] = Get(x, i);
      }
      // Restore the real element counts now that every slot has been written
      for (usz i = 0; i < ria; i++) a(rp[i])->ia = len[i];
      NOGC_E;
      decG(w); decG(x); TFREE(lenO); TFREE(pos);
      return taga(r);
    }
  }
  base:
  return c2rt(group, w, x);
}

// Monadic Group Indices (⊔𝕩). For a rank-1 array x for which TI(x,arrD1) holds, builds
// range = ud applied to the element count of x and returns C2(group, x, range); any other
// input goes to the runtime implementation through c1rt(group, x). `t` is not used here.
B group_c1(B t, B x) {
  if (isArr(x) && RNK(x)==1 && TI(x,arrD1)) {
    usz ia = IA(x);
    B range = C1(ud, m_f64(ia));
    return C2(group, x, range);
  }
  return c1rt(group, x);
}