Extend Group where 𝕨 is an integer list to handle higher-rank 𝕩
This commit is contained in:
parent
21033fa355
commit
3505e1515e
@ -3,6 +3,7 @@
|
||||
#include "../utils/mut.h"
|
||||
#include "../builtins.h"
|
||||
|
||||
extern B ud_c1(B, B);
|
||||
extern B ne_c2(B, B, B);
|
||||
extern B slash_c1(B, B);
|
||||
extern B slash_c2(B, B, B);
|
||||
@ -11,11 +12,29 @@ extern B take_c2(B, B, B);
|
||||
extern B drop_c2(B, B, B);
|
||||
extern B join_c2(B, B, B);
|
||||
|
||||
static Arr* arr_shChangeLen(Arr* a, ur r, usz* xsh, usz len) {
|
||||
assert(r > 1);
|
||||
usz* sh = a->sh = m_shArr(r)->a;
|
||||
SPRNK(a,r);
|
||||
sh[0] = len;
|
||||
shcpy(sh+1, xsh+1, r-1);
|
||||
return a;
|
||||
}
|
||||
static B m_shChangeLen(u8 xt, ur xr, usz* xsh, usz l, usz cw, usz csz) {
|
||||
return taga(arr_shChangeLen(m_arr(offsetof(TyArr, a)+l*cw, xt, l*csz), xr, xsh, l));
|
||||
}
|
||||
static void allocGroups(B* rp, usz ria, B z, u8 xt, ur xr, usz* xsh, i32* len, usz width, usz csz) {
|
||||
if (xr==1) for (usz j = 0; j < ria; j++) { usz l=len[j]; if (!l) rp[j] = inc(z); else m_tyarrv(rp+j, width, l, xt); }
|
||||
else for (usz j = 0; j < ria; j++) { usz l=len[j]; rp[j] = !l ? inc(z) : m_shChangeLen(xt, xr, xsh, l, width, csz); }
|
||||
}
|
||||
|
||||
extern B rt_group;
|
||||
B group_c2(B t, B w, B x) {
|
||||
if (isArr(w)&isArr(x) && RNK(w)==1 && RNK(x)==1 && depth(w)==1) {
|
||||
ur xr = RNK(x);
|
||||
if (isArr(w)&isArr(x) && RNK(w)==1 && xr>=1 && depth(w)==1) {
|
||||
usz wia = IA(w);
|
||||
usz xia = IA(x);
|
||||
usz* xsh = SH(x);
|
||||
usz xia = *xsh;
|
||||
if (wia-xia > 1) thrF("⊔: ≠𝕨 must be either ≠𝕩 or one bigger (%s≡≠𝕨, %s≡≠𝕩)", wia, xia);
|
||||
u8 we = TI(w,elType);
|
||||
if (elInt(we)) {
|
||||
@ -48,7 +67,8 @@ B group_c2(B t, B w, B x) {
|
||||
for (usz i = 0; i < ria; i++) rp[i] = m_f64(0); // don't break if allocation errors
|
||||
B xf = getFillQ(x);
|
||||
|
||||
Arr* rf = arr_shVec(m_fillarrp(0)); fillarr_setFill(rf, m_f64(0));
|
||||
Arr* rf = m_fillarrp(0); if (xr==1) arr_shVec(rf); else arr_shChangeLen(rf, xr, xsh, 0);
|
||||
fillarr_setFill(rf, m_f64(0));
|
||||
B z = taga(rf);
|
||||
fillarr_setFill(r, z);
|
||||
|
||||
@ -60,10 +80,19 @@ B group_c2(B t, B w, B x) {
|
||||
goto intvec_ret;
|
||||
}
|
||||
|
||||
u8 xe = TI(x,elType);
|
||||
u8 width = elWidth(xe);
|
||||
u64 xw;
|
||||
if (xia>64 && (xw=(u64)xia*width)<=I32_MAX && change<xw/32 && xe!=el_bit && xe!=el_B) {
|
||||
bool notB = TI(x,elType) != el_B;
|
||||
u8 xt = arrNewType(TY(x));
|
||||
u8 xl = arrTypeBitsLog(TY(x));
|
||||
bool bits = xl == 0;
|
||||
u64 width = bits ? 1 : 1<<(xl-3); // cell width in bits if bits==1, bytes otherwise
|
||||
usz csz = 1;
|
||||
if (RARE(xr>1)) {
|
||||
width *= csz = arr_csz(x);
|
||||
xl += CTZ(csz);
|
||||
if (bits && xl>=3) { bits=1; width>>=3; }
|
||||
if ((csz & (csz-1)) || xl>7) xl = 7;
|
||||
}
|
||||
if (xia>64 && notB && !bits && change<(xia*width)/32) {
|
||||
#define C1(F,X ) F##_c1(m_f64(0),X )
|
||||
#define C2(F,X,W) F##_c2(m_f64(0),X,W)
|
||||
if (wia>xia) w = C2(take, m_f64(xia), w);
|
||||
@ -86,86 +115,83 @@ B group_c2(B t, B w, B x) {
|
||||
for (usz i = 0; i < wia; i++) len[wp[i]]+=ip[i];
|
||||
|
||||
void* xp = tyany_ptr(x);
|
||||
u8 xt = el2t(xe);
|
||||
|
||||
for (usz j = 0; j < ria; j++) {
|
||||
usz l=len[j];
|
||||
if (!l) rp[j]=inc(z); else m_tyarrv(rp+j, width, l, xt);
|
||||
}
|
||||
for (usz i=0, k=i0*width; i<wia; i++) {
|
||||
usz k0 = k;
|
||||
usz l = ip[i]*width; k += l;
|
||||
allocGroups(rp, ria, z, xt, xr, xsh, len, width, csz);
|
||||
for (u64 i=0, k=i0*width; i<wia; i++) {
|
||||
u64 k0 = k;
|
||||
u64 l = ip[i]*width; k += l;
|
||||
i32 n = wp[i]; if (n<0) continue;
|
||||
memcpy((u8*)tyarr_ptr(rp[n])+pos[n], (u8*)xp+k0, l);
|
||||
pos[n] += l;
|
||||
}
|
||||
decG(ind); goto intvec_ret;
|
||||
}
|
||||
if (xia>32 && neg>xia/4+xia/8) {
|
||||
if (wia>xia) w = take_c2(m_f64(0), m_f64(xia), w);
|
||||
B m = ne_c2(m_f64(0), m_f64(-1), inc(w));
|
||||
w = slash_c2(m_f64(0), inc(m), w);
|
||||
x = slash_c2(m_f64(0), m, x); xia = IA(x);
|
||||
neg = 0;
|
||||
}
|
||||
if (TI(w,elType)!=el_i32) w = taga(cpyI32Arr(w));
|
||||
i32* wp = i32any_ptr(w);
|
||||
for (usz i = 0; i < ria; i++) len[i] = pos[i] = 0;
|
||||
for (usz i = 0; i < xia; i++) len[wp[i]]++; // overallocation makes this safe after n<-1 check
|
||||
|
||||
switch (xe) { default: UD;
|
||||
case el_i8: case el_c8:
|
||||
case el_i16: case el_c16:
|
||||
case el_i32: case el_c32: case el_f64: {
|
||||
void* xp = tyany_ptr(x);
|
||||
u8 xt = el2t(xe);
|
||||
if (sort) {
|
||||
for (usz j=0, i=neg*width; j<ria; j++) {
|
||||
usz l = len[j];
|
||||
if (!l) { rp[j]=inc(z); continue; }
|
||||
m_tyarrv(rp+j, width, l, xt);
|
||||
usz lw = l*width;
|
||||
memcpy(tyarr_ptr(rp[j]), (u8*)xp+i, lw);
|
||||
i += lw;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
for (usz j = 0; j < ria; j++) {
|
||||
usz l=len[j];
|
||||
if (!l) rp[j]=inc(z); else m_tyarrv(rp+j, width, l, xt);
|
||||
}
|
||||
switch(width) { default: UD;
|
||||
case 1: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u8* )tyarr_ptr(rp[n]))[pos[n]++] = ((u8* )xp)[i]; } break;
|
||||
case 2: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u16*)tyarr_ptr(rp[n]))[pos[n]++] = ((u16*)xp)[i]; } break;
|
||||
case 4: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u32*)tyarr_ptr(rp[n]))[pos[n]++] = ((u32*)xp)[i]; } break;
|
||||
case 8: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((f64*)tyarr_ptr(rp[n]))[pos[n]++] = ((f64*)xp)[i]; } break;
|
||||
}
|
||||
break;
|
||||
decG(ind);
|
||||
} else {
|
||||
if (xia>32 && neg>xia/4+xia/8) {
|
||||
if (wia>xia) w = take_c2(m_f64(0), m_f64(xia), w);
|
||||
B m = ne_c2(m_f64(0), m_f64(-1), inc(w));
|
||||
w = slash_c2(m_f64(0), inc(m), w);
|
||||
x = slash_c2(m_f64(0), m, x); xia = IA(x);
|
||||
neg = 0;
|
||||
}
|
||||
case el_bit: case el_B: {
|
||||
if (TI(w,elType)!=el_i32) w = taga(cpyI32Arr(w));
|
||||
i32* wp = i32any_ptr(w);
|
||||
for (usz i = 0; i < ria; i++) len[i] = pos[i] = 0;
|
||||
for (usz i = 0; i < xia; i++) len[wp[i]]++; // overallocation makes this safe after n<-1 check
|
||||
|
||||
u8 xk = xl - 3;
|
||||
if (notB && !bits && sort) {
|
||||
void* xp = tyany_ptr(x);
|
||||
u64 i=neg*width; for (usz j=0; j<ria; j++) {
|
||||
usz l = len[j];
|
||||
if (!l) { rp[j]=inc(z); continue; }
|
||||
if (xr==1) m_tyarrv(rp+j, width, l, xt);
|
||||
else rp[j] = m_shChangeLen(xt, xr, xsh, l, width, csz);
|
||||
u64 lw = l*width;
|
||||
memcpy(tyarr_ptr(rp[j]), (u8*)xp+i, lw);
|
||||
i += lw;
|
||||
}
|
||||
} else if (notB && xk <= 3) {
|
||||
void* xp = tyany_ptr(x);
|
||||
allocGroups(rp, ria, z, xt, xr, xsh, len, width, csz);
|
||||
switch(xk) { default: UD;
|
||||
case 0: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u8* )tyarr_ptr(rp[n]))[pos[n]++] = ((u8* )xp)[i]; } break;
|
||||
case 1: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u16*)tyarr_ptr(rp[n]))[pos[n]++] = ((u16*)xp)[i]; } break;
|
||||
case 2: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((u32*)tyarr_ptr(rp[n]))[pos[n]++] = ((u32*)xp)[i]; } break;
|
||||
case 3: for (usz i = 0; i < xia; i++) { i32 n = wp[i]; if (n>=0) ((f64*)tyarr_ptr(rp[n]))[pos[n]++] = ((f64*)xp)[i]; } break;
|
||||
}
|
||||
} else {
|
||||
for (usz i = 0; i < ria; i++) {
|
||||
Arr* c = m_fillarrp(len[i]);
|
||||
usz l = len[i];
|
||||
Arr* c = m_fillarrp(l*csz);
|
||||
c->ia = 0;
|
||||
fillarr_setFill(c, inc(xf));
|
||||
arr_shVec(c);
|
||||
if (xr==1) arr_shVec(c); else arr_shChangeLen(c, xr, xsh, l);
|
||||
rp[i] = taga(c);
|
||||
}
|
||||
SLOW2("𝕨⊔𝕩", w, x);
|
||||
SGet(x)
|
||||
for (usz i = 0; i < xia; i++) {
|
||||
i32 n = wp[i];
|
||||
if (n>=0) fillarr_ptr(a(rp[n]))[pos[n]++] = Get(x, i);
|
||||
if (csz == 1) {
|
||||
for (usz i = 0; i < xia; i++) {
|
||||
i32 n = wp[i];
|
||||
if (n>=0) fillarr_ptr(a(rp[n]))[pos[n]++] = Get(x, i);
|
||||
}
|
||||
} else {
|
||||
for (usz i = 0; i < xia; i++) {
|
||||
i32 n = wp[i];
|
||||
if (n<0) continue;
|
||||
usz p = (pos[n]++)*csz;
|
||||
B* rnp = fillarr_ptr(a(rp[n])) + p;
|
||||
for (usz j = 0; j < csz; j++) rnp[j] = Get(x, i*csz + j);
|
||||
}
|
||||
}
|
||||
for (usz i = 0; i < ria; i++) a(rp[i])->ia = len[i];
|
||||
break;
|
||||
for (usz i = 0; i < ria; i++) a(rp[i])->ia = len[i]*csz;
|
||||
}
|
||||
}
|
||||
intvec_ret:
|
||||
fillarr_setFill(rf, xf);
|
||||
decG(w); decG(x); TFREE(pos);
|
||||
return taga(r);
|
||||
} else {
|
||||
} else if (xr==1) {
|
||||
SLOW2("𝕨⊔𝕩", w, x);
|
||||
SGetU(w)
|
||||
i64 ria = wia==xia? 0 : o2i64(GetU(w, xia));
|
||||
@ -213,7 +239,6 @@ B group_c2(B t, B w, B x) {
|
||||
base:
|
||||
return c2(rt_group, w, x);
|
||||
}
|
||||
B ud_c1(B, B);
|
||||
B group_c1(B t, B x) {
|
||||
if (isArr(x) && RNK(x)==1 && TI(x,arrD1)) {
|
||||
usz ia = IA(x);
|
||||
@ -222,4 +247,3 @@ B group_c1(B t, B x) {
|
||||
}
|
||||
return c1(rt_group, x);
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user