Split out optimized Group cases and add comments

This commit is contained in:
Marshall Lochbaum 2022-11-09 17:51:25 -05:00
parent 801472d6d4
commit 6eb504118e

View File

@ -1,3 +1,22 @@
// Group and Group Indices (⊔)
// Group Indices: calls 𝕩⊔↕𝕩 for rank-1 flat 𝕩, otherwise self-hosted
// Group: native code for rank-1 𝕨 only, optimizations for integers
// SHOULD squeeze 𝕨
// SHOULD handle boolean 𝕨 with replicate
// COULD handle small-range 𝕨 with equals-replicate
// All statistics computed in the initial pass that finds ⌈´𝕨
// If +´»⊸≠𝕨 is small, process in chunks as a separate case
// If +´𝕨<¯1 is large, filter out ¯1s.
// COULD recompute statistics, may have enabled chunked or sorted code
// If ∧´1↓»⊸<𝕨, that is, ∧⊸≡𝕨, each result array is a slice of 𝕩
// COULD use slice types; seems dangerous--when will they be freed?
// Remaining cases copy cells from 𝕩 individually
// CPU-sized cells handled quickly
// SHOULD use bit ops for 1-bit cells
// SHOULD use memcpy and bit_cpy for other sizes
#include "../core.h" #include "../core.h"
#include "../utils/talloc.h" #include "../utils/talloc.h"
#include "../utils/calls.h" #include "../utils/calls.h"
@ -32,17 +51,8 @@ static void allocBitGroups(B* rp, usz ria, B z, ur xr, usz* xsh, i32* len, usz w
else for (usz j = 0; j < ria; j++) { usz l=len[j]; rp[j] = !l ? inc(z) : taga(arr_shChangeLen(m_bitarr_nop(l*width), xr, xsh, l)); } else for (usz j = 0; j < ria; j++) { usz l=len[j]; rp[j] = !l ? inc(z) : taga(arr_shChangeLen(m_bitarr_nop(l*width), xr, xsh, l)); }
} }
extern B rt_group; // Integer list w
B group_c2(B t, B w, B x) { static B group_simple(B w, B x, ur xr, usz wia, usz xia, usz* xsh, u8 we) {
if (!isArr(x)) thrM("⊔: 𝕩 must be an array");
ur xr = RNK(x);
if (isArr(w) && RNK(w)==1 && xr>=1 && depth(w)==1) {
usz wia = IA(w);
usz* xsh = SH(x);
usz xia = *xsh;
if (wia-xia > 1) thrF("⊔: ≠𝕨 must be either ≠𝕩 or one bigger (%s≡≠𝕨, %s≡≠𝕩)", wia, xia);
u8 we = TI(w,elType);
if (elInt(we)) {
if (we==el_bit) w = taga(cpyI8Arr(w)); if (we==el_bit) w = taga(cpyI8Arr(w));
i64 ria = 0; i64 ria = 0;
bool bad = false, sort = true; bool bad = false, sort = true;
@ -97,6 +107,8 @@ B group_c2(B t, B w, B x) {
if (bits && xl>=3) { bits=0; width>>=3; } if (bits && xl>=3) { bits=0; width>>=3; }
if ((csz & (csz-1)) || xl>7) xl = 7; if ((csz & (csz-1)) || xl>7) xl = 7;
} }
// Few changes in 𝕨: move in chunks
if (xia>64 && notB && change<(xia*width)/32) { if (xia>64 && notB && change<(xia*width)/32) {
#define C1(F,X ) F##_c1(m_f64(0),X ) #define C1(F,X ) F##_c1(m_f64(0),X )
#define C2(F,X,W) F##_c2(m_f64(0),X,W) #define C2(F,X,W) F##_c2(m_f64(0),X,W)
@ -142,7 +154,10 @@ B group_c2(B t, B w, B x) {
} }
#undef GROUP_CHUNKED #undef GROUP_CHUNKED
decG(ind); decG(ind);
} else { goto done;
}
// Many ¯1s: filter out, then continue
if (xia>32 && neg>xia/4+xia/8) { if (xia>32 && neg>xia/4+xia/8) {
if (wia>xia) w = take_c2(m_f64(0), m_f64(xia), w); if (wia>xia) w = take_c2(m_f64(0), m_f64(xia), w);
B m = ne_c2(m_f64(0), m_f64(-1), inc(w)); B m = ne_c2(m_f64(0), m_f64(-1), inc(w));
@ -156,7 +171,7 @@ B group_c2(B t, B w, B x) {
for (usz i = 0; i < xia; i++) len[wp[i]]++; // overallocation makes this safe after n<-1 check for (usz i = 0; i < xia; i++) len[wp[i]]++; // overallocation makes this safe after n<-1 check
u8 xk = xl - 3; u8 xk = xl - 3;
if (notB && sort) { if (notB && sort) { // Sorted 𝕨, that is, partition 𝕩
void* xp = tyany_ptr(x); void* xp = tyany_ptr(x);
u64 i=neg*width; u64 i=neg*width;
#define GROUP_SORT(CPY, ALLOC) \ #define GROUP_SORT(CPY, ALLOC) \
@ -176,7 +191,7 @@ B group_c2(B t, B w, B x) {
else GROUP_SORT(bit_cpy, rp[j] = taga(arr_shChangeLen(m_bitarr_nop(l*width), xr, xsh, l))) else GROUP_SORT(bit_cpy, rp[j] = taga(arr_shChangeLen(m_bitarr_nop(l*width), xr, xsh, l)))
} }
#undef GROUP_SORT #undef GROUP_SORT
} else if (notB && xk <= 3) { } else if (notB && xk <= 3) { // Cells of 𝕩 are CPU-sized
void* xp = tyany_ptr(x); void* xp = tyany_ptr(x);
allocGroups(rp, ria, z, xt, xr, xsh, len, width, csz); allocGroups(rp, ria, z, xt, xr, xsh, len, width, csz);
switch(xk) { default: UD; switch(xk) { default: UD;
@ -185,7 +200,7 @@ B group_c2(B t, B w, B x) {
case 2: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u32*)tyarr_ptr(rp[n]))[pos[n]++] = ((u32*)xp)[i]; } break; case 2: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u32*)tyarr_ptr(rp[n]))[pos[n]++] = ((u32*)xp)[i]; } break;
case 3: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u64*)tyarr_ptr(rp[n]))[pos[n]++] = ((u64*)xp)[i]; } break; case 3: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u64*)tyarr_ptr(rp[n]))[pos[n]++] = ((u64*)xp)[i]; } break;
} }
} else { } else { // Generic case
for (usz i = 0; i < ria; i++) { for (usz i = 0; i < ria; i++) {
usz l = len[i]; usz l = len[i];
Arr* c = m_fillarrp(l*csz); Arr* c = m_fillarrp(l*csz);
@ -212,12 +227,27 @@ B group_c2(B t, B w, B x) {
} }
for (usz i = 0; i < ria; i++) a(rp[i])->ia = len[i]*csz; for (usz i = 0; i < ria; i++) a(rp[i])->ia = len[i]*csz;
} }
}
done:
TFREE(pos); TFREE(pos);
setfill_dec_ret: setfill_dec_ret:
fillarr_setFill(rf, xf); fillarr_setFill(rf, xf);
decG(w); decG(x); decG(w); decG(x);
return taga(r); return taga(r);
}
extern B rt_group;
B group_c2(B t, B w, B x) {
if (!isArr(x)) thrM("⊔: 𝕩 must be an array");
ur xr = RNK(x);
if (isArr(w) && RNK(w)==1 && xr>=1 && depth(w)==1) {
usz wia = IA(w);
usz* xsh = SH(x);
usz xia = *xsh;
if (wia-xia > 1) thrF("⊔: ≠𝕨 must be either ≠𝕩 or one bigger (%s≡≠𝕨, %s≡≠𝕩)", wia, xia);
u8 we = TI(w,elType);
if (elInt(we)) {
return group_simple(w, x, xr, wia, xia, xsh, we);
} else if (xr==1) { } else if (xr==1) {
SLOW2("𝕨⊔𝕩", w, x); SLOW2("𝕨⊔𝕩", w, x);
SGetU(w) SGetU(w)