Split out optimized Group cases and add comments

This commit is contained in:
Marshall Lochbaum 2022-11-09 17:51:25 -05:00
parent 801472d6d4
commit 6eb504118e


@@ -1,3 +1,22 @@
// Group and Group Indices (⊔)
// Group Indices: calls 𝕩⊔↕≠𝕩 for rank-1 flat 𝕩, otherwise self-hosted
// Group: native code for rank-1 𝕨 only, optimizations for integers
// SHOULD squeeze 𝕨
// SHOULD handle boolean 𝕨 with replicate
// COULD handle small-range 𝕨 with equals-replicate
// All statistics computed in the initial pass that finds ⌈´𝕨
// If +´»⊸≠𝕨 is small, process in chunks as a separate case
// If +´𝕨=¯1 is large, filter out ¯1s
// COULD recompute statistics, may have enabled chunked or sorted code
// If ∧´1↓»⊸≤𝕨, that is, ∧⊸≡𝕨, each result array is a slice of 𝕩
// COULD use slice types; seems dangerous--when will they be freed?
// Remaining cases copy cells from 𝕩 individually
// CPU-sized cells handled quickly
// SHOULD use bit ops for 1-bit cells
// SHOULD use memcpy and bit_cpy for other sizes
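// Worked example: 0‿1‿1‿¯1‿2 ⊔ "abcde" is ⟨"a","bc","e"⟩; a ¯1 in 𝕨 drops its cell.
// The result length is 1+⌈´𝕨, or more if a trailing element of 𝕨 (when ≠𝕨 exceeds ≠𝕩 by one) asks for it.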
#include "../core.h"
#include "../utils/talloc.h"
#include "../utils/calls.h"
@@ -32,6 +51,191 @@ static void allocBitGroups(B* rp, usz ria, B z, ur xr, usz* xsh, i32* len, usz w
else for (usz j = 0; j < ria; j++) { usz l=len[j]; rp[j] = !l ? inc(z) : taga(arr_shChangeLen(m_bitarr_nop(l*width), xr, xsh, l)); }
}
// Integer list w
static B group_simple(B w, B x, ur xr, usz wia, usz xia, usz* xsh, u8 we) {
if (we==el_bit) w = taga(cpyI8Arr(w));
i64 ria = 0;
bool bad = false, sort = true;
usz neg = 0, change = 0;
void *wp0 = tyany_ptr(w);
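// Single statistics pass over 𝕨: max drives the result length, bad flags anything below ¯1,
// neg counts ¯1s, sort tracks whether 𝕨 is non-decreasing, change counts positions where 𝕨 differs from its predecessor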
#define CASE(T) case el_##T: { \
T max = -1, prev = -1; \
for (usz i = 0; i < xia; i++) { \
T n = ((T*)wp0)[i]; \
if (n>max) max = n; \
bad |= n < -1; \
neg += n == -1; \
sort &= prev <= n; \
change += prev != n; \
prev = n; \
} \
if (wia>xia) { ria=((T*)wp0)[xia]; bad|=ria<-1; } \
i64 m=(i64)max+1; if (m>ria) ria=m; \
break; }
switch (we) { default:UD; case el_bit: CASE(i8) CASE(i16) CASE(i32) }
#undef CASE
if (bad) thrM("⊔: 𝕨 can't contain elements less than ¯1");
if (ria > (i64)(USZ_MAX)) thrOOM();
Arr* r = arr_shVec(m_fillarrp(ria)); fillarr_setFill(r, m_f64(0));
B* rp = fillarr_ptr(r);
for (usz i = 0; i < ria; i++) rp[i] = m_f64(0); // don't break if allocation errors
B xf = getFillQ(x);
Arr* rf = m_fillarrp(0); if (xr==1) arr_shVec(rf); else arr_shChangeLen(rf, xr, xsh, 0);
fillarr_setFill(rf, m_f64(0));
B z = taga(rf);
fillarr_setFill(r, z);
// Both cases needed to make sure wia>0 for ip[wia-1] below
if (ria==0) goto setfill_dec_ret;
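// All of 𝕨 is ¯1: every group is empty, so each result slot is a copy of the shared empty cell z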
if (neg==xia) {
for (usz i = 0; i < ria; i++) rp[i] = inc(z);
goto setfill_dec_ret;
}
TALLOC(i32, pos, 2*ria+1); i32* len = pos+ria+1;
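// len[j] is the final length of group j, pos[j] its write cursor; the one extra slot means
// a ¯1 index incrementing len[-1] still lands in allocated memory (see the counting loop below)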
bool notB = TI(x,elType) != el_B;
u8 xt = arrNewType(TY(x));
u8 xl = arrTypeBitsLog(TY(x));
bool bits = xl == 0;
u64 width = bits ? 1 : 1<<(xl-3); // cell width in bits if bits==1, bytes otherwise
usz csz = 1;
if (RARE(xr>1)) {
width *= csz = arr_csz(x);
xl += CTZ(csz);
if (bits && xl>=3) { bits=0; width>>=3; }
if ((csz & (csz-1)) || xl>7) xl = 7;
}
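// width now covers a whole cell of 𝕩 (in bits when bits is set, bytes otherwise);
// xl is forced to 7 when the cell isn't a power-of-two CPU size, which skips the per-cell switch below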
// Few changes in 𝕨: move in chunks
if (xia>64 && notB && change<(xia*width)/32) {
#define C1(F,X ) F##_c1(m_f64(0),X )
#define C2(F,X,W) F##_c2(m_f64(0),X,W)
u64* mp; B m = m_bitarrv(&mp, xia);
u8* wp0 = tyany_ptr(w);
we = TI(w,elType);
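// Build m ≡ »⊸≠𝕨; bit 0 is set only if 𝕨 doesn't start with ¯1, so a leading run of ¯1s
// produces no boundary and is skipped entirely via i0 below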
CMP_AA_IMM(ne, we, mp, wp0-elWidth(we), wp0, xia);
bitp_set(mp, 0, -1!=o2fG(IGetU(w,0)));
B ind = C1(slash, m);
w = C2(select, inc(ind), w);
#undef C1
#undef C2
if (TI(ind,elType)!=el_i32) ind = taga(cpyI32Arr(ind));
if (TI(w ,elType)!=el_i32) w = taga(cpyI32Arr(w ));
wia = IA(ind);
i32* ip = i32any_ptr(ind);
i32* wp = i32any_ptr(w);
usz i0 = ip[0];
for (usz i=0; i<wia-1; i++) ip[i] = ip[i+1]-ip[i];
ip[wia-1] = xia-ip[wia-1];
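// ip[i] is now the length of run i of equal 𝕨 values, wp[i] the group that run belongs to,
// and i0 the offset in 𝕩 of the first kept run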
for (usz i = 0; i < ria; i++) len[i] = pos[i] = 0;
for (usz i = 0; i < wia; i++) len[wp[i]]+=ip[i];
void* xp = tyany_ptr(x);
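// Copy each run as one contiguous block onto the tail of its group;
// offsets and lengths are in bits on the bit-array path, bytes otherwise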
#define GROUP_CHUNKED(CPY) \
for (u64 i=0, k=i0*width; i<wia; i++) { \
u64 k0 = k; \
u64 l = ip[i]*width; k += l; \
i32 n = wp[i]; if (n<0) continue; \
CPY(tyarr_ptr(rp[n]), pos[n], xp, k0, l); \
pos[n] += l; \
}
if (!bits) {
allocGroups(rp, ria, z, xt, xr, xsh, len, width, csz);
GROUP_CHUNKED(MEM_CPY)
} else {
allocBitGroups(rp, ria, z, xr, xsh, len, width);
GROUP_CHUNKED(bit_cpy)
}
#undef GROUP_CHUNKED
decG(ind);
goto done;
}
// Many ¯1s: filter out, then continue
if (xia>32 && neg>xia/4+xia/8) {
if (wia>xia) w = take_c2(m_f64(0), m_f64(xia), w);
B m = ne_c2(m_f64(0), m_f64(-1), inc(w));
w = slash_c2(m_f64(0), inc(m), w);
x = slash_c2(m_f64(0), m, x); xia = IA(x);
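// After filtering, 𝕨 contains no ¯1s, so neg is reset and the sorted case below starts at offset 0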
neg = 0;
}
if (TI(w,elType)!=el_i32) w = taga(cpyI32Arr(w));
i32* wp = i32any_ptr(w);
for (usz i = 0; i < ria; i++) len[i] = pos[i] = 0;
for (usz i = 0; i < xia; i++) len[wp[i]]++; // overallocation makes this safe after n<-1 check
u8 xk = xl - 3;
if (notB && sort) { // Sorted 𝕨, that is, partition 𝕩
void* xp = tyany_ptr(x);
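// Since 𝕨 is sorted, all ¯1s come first, so copying starts just past those neg cells of 𝕩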
u64 i=neg*width;
#define GROUP_SORT(CPY, ALLOC) \
for (usz j=0; j<ria; j++) { \
usz l = len[j]; \
if (!l) { rp[j]=inc(z); continue; } \
ALLOC; \
u64 lw = l*width; \
CPY(tyarr_ptr(rp[j]), 0, xp, i, lw); \
i += lw; \
}
if (!bits) {
if (xr==1) GROUP_SORT(MEM_CPY, m_tyarrv(rp+j, width, l, xt))
else GROUP_SORT(MEM_CPY, rp[j] = m_shChangeLen(xt, xr, xsh, l, width, csz))
} else {
if (xr==1) GROUP_SORT(bit_cpy, rp[j] = taga(arr_shVec(m_bitarr_nop(l))))
else GROUP_SORT(bit_cpy, rp[j] = taga(arr_shChangeLen(m_bitarr_nop(l*width), xr, xsh, l)))
}
#undef GROUP_SORT
} else if (notB && xk <= 3) { // Cells of 𝕩 are CPU-sized
void* xp = tyany_ptr(x);
allocGroups(rp, ria, z, xt, xr, xsh, len, width, csz);
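// Scatter cells one at a time; a cell is 1<<xk bytes and moves as a single load/store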
switch(xk) { default: UD;
case 0: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u8* )tyarr_ptr(rp[n]))[pos[n]++] = ((u8* )xp)[i]; } break;
case 1: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u16*)tyarr_ptr(rp[n]))[pos[n]++] = ((u16*)xp)[i]; } break;
case 2: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u32*)tyarr_ptr(rp[n]))[pos[n]++] = ((u32*)xp)[i]; } break;
case 3: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u64*)tyarr_ptr(rp[n]))[pos[n]++] = ((u64*)xp)[i]; } break;
}
} else { // Generic case
for (usz i = 0; i < ria; i++) {
usz l = len[i];
Arr* c = m_fillarrp(l*csz);
c->ia = 0;
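// ia stays 0 until the group is filled, apparently so a throw while reading 𝕩 can't expose uninitialized slots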
fillarr_setFill(c, inc(xf));
if (xr==1) arr_shVec(c); else arr_shChangeLen(c, xr, xsh, l);
rp[i] = taga(c);
}
SLOW2("𝕨⊔𝕩", w, x);
SGet(x)
if (csz == 1) {
for (usz i = 0; i < xia; i++) {
i32 n = wp[i];
if (n>=0) fillarr_ptr(a(rp[n]))[pos[n]++] = Get(x, i);
}
} else {
for (usz i = 0; i < xia; i++) {
i32 n = wp[i];
if (n<0) continue;
usz p = (pos[n]++)*csz;
B* rnp = fillarr_ptr(a(rp[n])) + p;
for (usz j = 0; j < csz; j++) rnp[j] = Get(x, i*csz + j);
}
}
for (usz i = 0; i < ria; i++) a(rp[i])->ia = len[i]*csz;
}
done:
TFREE(pos);
setfill_dec_ret:
fillarr_setFill(rf, xf);
decG(w); decG(x);
return taga(r);
}
extern B rt_group;
B group_c2(B t, B w, B x) {
if (!isArr(x)) thrM("⊔: 𝕩 must be an array");
@@ -43,181 +247,7 @@ B group_c2(B t, B w, B x) {
if (wia-xia > 1) thrF("⊔: ≠𝕨 must be either ≠𝕩 or one bigger (%s≡≠𝕨, %s≡≠𝕩)", wia, xia);
u8 we = TI(w,elType);
if (elInt(we)) {
if (we==el_bit) w = taga(cpyI8Arr(w));
i64 ria = 0;
bool bad = false, sort = true;
usz neg = 0, change = 0;
void *wp0 = tyany_ptr(w);
#define CASE(T) case el_##T: { \
T max = -1, prev = -1; \
for (usz i = 0; i < xia; i++) { \
T n = ((T*)wp0)[i]; \
if (n>max) max = n; \
bad |= n < -1; \
neg += n == -1; \
sort &= prev <= n; \
change += prev != n; \
prev = n; \
} \
if (wia>xia) { ria=((T*)wp0)[xia]; bad|=ria<-1; } \
i64 m=(i64)max+1; if (m>ria) ria=m; \
break; }
switch (we) { default:UD; case el_bit: CASE(i8) CASE(i16) CASE(i32) }
#undef CASE
if (bad) thrM("⊔: 𝕨 can't contain elements less than ¯1");
if (ria > (i64)(USZ_MAX)) thrOOM();
Arr* r = arr_shVec(m_fillarrp(ria)); fillarr_setFill(r, m_f64(0));
B* rp = fillarr_ptr(r);
for (usz i = 0; i < ria; i++) rp[i] = m_f64(0); // don't break if allocation errors
B xf = getFillQ(x);
Arr* rf = m_fillarrp(0); if (xr==1) arr_shVec(rf); else arr_shChangeLen(rf, xr, xsh, 0);
fillarr_setFill(rf, m_f64(0));
B z = taga(rf);
fillarr_setFill(r, z);
// Both cases needed to make sure wia>0 for ip[wia-1] below
if (ria==0) goto setfill_dec_ret;
if (neg==xia) {
for (usz i = 0; i < ria; i++) rp[i] = inc(z);
goto setfill_dec_ret;
}
TALLOC(i32, pos, 2*ria+1); i32* len = pos+ria+1;
bool notB = TI(x,elType) != el_B;
u8 xt = arrNewType(TY(x));
u8 xl = arrTypeBitsLog(TY(x));
bool bits = xl == 0;
u64 width = bits ? 1 : 1<<(xl-3); // cell width in bits if bits==1, bytes otherwise
usz csz = 1;
if (RARE(xr>1)) {
width *= csz = arr_csz(x);
xl += CTZ(csz);
if (bits && xl>=3) { bits=0; width>>=3; }
if ((csz & (csz-1)) || xl>7) xl = 7;
}
if (xia>64 && notB && change<(xia*width)/32) {
#define C1(F,X ) F##_c1(m_f64(0),X )
#define C2(F,X,W) F##_c2(m_f64(0),X,W)
u64* mp; B m = m_bitarrv(&mp, xia);
u8* wp0 = tyany_ptr(w);
we = TI(w,elType);
CMP_AA_IMM(ne, we, mp, wp0-elWidth(we), wp0, xia);
bitp_set(mp, 0, -1!=o2fG(IGetU(w,0)));
B ind = C1(slash, m);
w = C2(select, inc(ind), w);
#undef C1
#undef C2
if (TI(ind,elType)!=el_i32) ind = taga(cpyI32Arr(ind));
if (TI(w ,elType)!=el_i32) w = taga(cpyI32Arr(w ));
wia = IA(ind);
i32* ip = i32any_ptr(ind);
i32* wp = i32any_ptr(w);
usz i0 = ip[0];
for (usz i=0; i<wia-1; i++) ip[i] = ip[i+1]-ip[i];
ip[wia-1] = xia-ip[wia-1];
for (usz i = 0; i < ria; i++) len[i] = pos[i] = 0;
for (usz i = 0; i < wia; i++) len[wp[i]]+=ip[i];
void* xp = tyany_ptr(x);
#define GROUP_CHUNKED(CPY) \
for (u64 i=0, k=i0*width; i<wia; i++) { \
u64 k0 = k; \
u64 l = ip[i]*width; k += l; \
i32 n = wp[i]; if (n<0) continue; \
CPY(tyarr_ptr(rp[n]), pos[n], xp, k0, l); \
pos[n] += l; \
}
if (!bits) {
allocGroups(rp, ria, z, xt, xr, xsh, len, width, csz);
GROUP_CHUNKED(MEM_CPY)
} else {
allocBitGroups(rp, ria, z, xr, xsh, len, width);
GROUP_CHUNKED(bit_cpy)
}
#undef GROUP_CHUNKED
decG(ind);
} else {
if (xia>32 && neg>xia/4+xia/8) {
if (wia>xia) w = take_c2(m_f64(0), m_f64(xia), w);
B m = ne_c2(m_f64(0), m_f64(-1), inc(w));
w = slash_c2(m_f64(0), inc(m), w);
x = slash_c2(m_f64(0), m, x); xia = IA(x);
neg = 0;
}
if (TI(w,elType)!=el_i32) w = taga(cpyI32Arr(w));
i32* wp = i32any_ptr(w);
for (usz i = 0; i < ria; i++) len[i] = pos[i] = 0;
for (usz i = 0; i < xia; i++) len[wp[i]]++; // overallocation makes this safe after n<-1 check
u8 xk = xl - 3;
if (notB && sort) {
void* xp = tyany_ptr(x);
u64 i=neg*width;
#define GROUP_SORT(CPY, ALLOC) \
for (usz j=0; j<ria; j++) { \
usz l = len[j]; \
if (!l) { rp[j]=inc(z); continue; } \
ALLOC; \
u64 lw = l*width; \
CPY(tyarr_ptr(rp[j]), 0, xp, i, lw); \
i += lw; \
}
if (!bits) {
if (xr==1) GROUP_SORT(MEM_CPY, m_tyarrv(rp+j, width, l, xt))
else GROUP_SORT(MEM_CPY, rp[j] = m_shChangeLen(xt, xr, xsh, l, width, csz))
} else {
if (xr==1) GROUP_SORT(bit_cpy, rp[j] = taga(arr_shVec(m_bitarr_nop(l))))
else GROUP_SORT(bit_cpy, rp[j] = taga(arr_shChangeLen(m_bitarr_nop(l*width), xr, xsh, l)))
}
#undef GROUP_SORT
} else if (notB && xk <= 3) {
void* xp = tyany_ptr(x);
allocGroups(rp, ria, z, xt, xr, xsh, len, width, csz);
switch(xk) { default: UD;
case 0: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u8* )tyarr_ptr(rp[n]))[pos[n]++] = ((u8* )xp)[i]; } break;
case 1: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u16*)tyarr_ptr(rp[n]))[pos[n]++] = ((u16*)xp)[i]; } break;
case 2: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u32*)tyarr_ptr(rp[n]))[pos[n]++] = ((u32*)xp)[i]; } break;
case 3: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u64*)tyarr_ptr(rp[n]))[pos[n]++] = ((u64*)xp)[i]; } break;
}
} else {
for (usz i = 0; i < ria; i++) {
usz l = len[i];
Arr* c = m_fillarrp(l*csz);
c->ia = 0;
fillarr_setFill(c, inc(xf));
if (xr==1) arr_shVec(c); else arr_shChangeLen(c, xr, xsh, l);
rp[i] = taga(c);
}
SLOW2("𝕨⊔𝕩", w, x);
SGet(x)
if (csz == 1) {
for (usz i = 0; i < xia; i++) {
i32 n = wp[i];
if (n>=0) fillarr_ptr(a(rp[n]))[pos[n]++] = Get(x, i);
}
} else {
for (usz i = 0; i < xia; i++) {
i32 n = wp[i];
if (n<0) continue;
usz p = (pos[n]++)*csz;
B* rnp = fillarr_ptr(a(rp[n])) + p;
for (usz j = 0; j < csz; j++) rnp[j] = Get(x, i*csz + j);
}
}
for (usz i = 0; i < ria; i++) a(rp[i])->ia = len[i]*csz;
}
}
TFREE(pos);
setfill_dec_ret:
fillarr_setFill(rf, xf);
decG(w); decG(x);
return taga(r);
return group_simple(w, x, xr, wia, xia, xsh, we);
} else if (xr==1) {
SLOW2("𝕨⊔𝕩", w, x);
SGetU(w)