handle more cases of ⊏⎉n & ⊑⎉n

This commit is contained in:
dzaima 2023-04-22 17:03:09 +03:00
parent fe071b641f
commit 36b99d3505

View File

@ -110,54 +110,54 @@ NOINLINE B toKCells(B x, ur k) {
// fast special-case implementations
static NOINLINE B select_cells(usz n, B x, ur xr) {
static NOINLINE B select_cells(usz n, B x, usz cam, usz k, bool leaf) { // n {leaf? <∘⊑; ⊏}⎉¯k x; TODO probably can share some parts with takedrop_highrank and/or call ⊏?
ur xr = RNK(x);
assert(xr>1 && k<xr);
usz* xsh = SH(x);
B r;
usz cam = xsh[0];
if (xr==2) {
usz csz = xsh[1];
if (csz==1) return taga(arr_shVec(TI(x,slice)(x,0,IA(x))));
usz csz = shProd(xsh, k+1, xr);
usz take = leaf? 1 : csz;
usz jump = xsh[k] * csz;
assert(cam*jump == IA(x));
Arr* ra;
if (take==jump) {
ra = cpyWithShape(incG(x));
arr_shErase(ra, 1);
} else if (take==1) {
u8 xe = TI(x,elType);
if (xe==el_B) {
SGet(x)
HArr_p rp = m_harrUv(cam);
for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*csz+n);
NOGC_E; r=rp.b;
for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*jump+n);
NOGC_E; ra = (Arr*)rp.c;
} else {
void* rp = m_tyarrv(&r, elWidth(xe), cam, el2t(xe));
void* rp = m_tyarrp(&ra, elWidth(xe), cam, el2t(xe));
void* xp = tyany_ptr(x);
switch(xe) {
case el_bit: for (usz i=0; i<cam; i++) bitp_set(rp, i, bitp_get(xp, i*csz+n)); break;
case el_i8: case el_c8: PLAINLOOP for (usz i=0; i<cam; i++) ((u8* )rp)[i] = ((u8* )xp)[i*csz+n]; break;
case el_i16: case el_c16: PLAINLOOP for (usz i=0; i<cam; i++) ((u16*)rp)[i] = ((u16*)xp)[i*csz+n]; break;
case el_i32: case el_c32: PLAINLOOP for (usz i=0; i<cam; i++) ((u32*)rp)[i] = ((u32*)xp)[i*csz+n]; break;
case el_f64: PLAINLOOP for (usz i=0; i<cam; i++) ((f64*)rp)[i] = ((f64*)xp)[i*csz+n]; break;
case el_bit: for (usz i=0; i<cam; i++) bitp_set(rp, i, bitp_get(xp, i*jump+n)); break;
case el_i8: case el_c8: PLAINLOOP for (usz i=0; i<cam; i++) ((u8* )rp)[i] = ((u8* )xp)[i*jump+n]; break;
case el_i16: case el_c16: PLAINLOOP for (usz i=0; i<cam; i++) ((u16*)rp)[i] = ((u16*)xp)[i*jump+n]; break;
case el_i32: case el_c32: PLAINLOOP for (usz i=0; i<cam; i++) ((u32*)rp)[i] = ((u32*)xp)[i*jump+n]; break;
case el_f64: PLAINLOOP for (usz i=0; i<cam; i++) ((f64*)rp)[i] = ((f64*)xp)[i*jump+n]; break;
}
}
} else {
Arr* ra;
if (xsh[1]==1) {
ra = TI(x,slice)(incG(x), 0, IA(x));
} else {
usz rs = shProd(xsh, 2, xr);
usz xs = rs*xsh[1]; // aka csz
MAKE_MUT_INIT(rm, cam*rs, TI(x,elType)); MUTG_INIT(rm);
usz xi = rs*n;
usz ri = 0;
for (usz i = 0; i < cam; i++) {
mut_copyG(rm, ri, x, xi, rs);
xi+= xs;
ri+= rs;
}
ra = mut_fp(rm);
MAKE_MUT_INIT(rm, cam*take, TI(x,elType)); MUTG_INIT(rm);
usz xi = take*n;
usz ri = 0;
for (usz i = 0; i < cam; i++) {
mut_copyG(rm, ri, x, xi, take);
xi+= jump;
ri+= take;
}
usz* rsh = arr_shAlloc(ra, xr-1);
shcpy(rsh+1, xsh+2, xr-2);
rsh[0] = cam;
r = taga(ra);
ra = mut_fp(rm);
}
usz* rsh = arr_shAlloc(ra, leaf? k : xr-1);
if (rsh) {
shcpy(rsh, xsh, k);
if (!leaf) shcpy(rsh+k, xsh+k+1, xr-1-k);
}
decG(x);
return r;
return taga(ra);
}
static NOINLINE B shift_cells(B f, B x, usz cam, usz csz, u8 e, u8 rtid) { // »⎉1 or «⎉1
@ -289,13 +289,12 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x, with arr
return k==1 && RNK(x)>1? toCells(x) : k==0? m_atomUnit(x) : toKCells(x, k);
case n_select:
if (IA(x)==0) goto noSpecial;
if (cr==0 || k!=1) goto base; // TODO handle more ranks
selectCells:;
return select_cells(0, x, xr);
if (cr==0) goto base;
return select_cells(0, x, cam, k, false);
case n_pick:
if (IA(x)==0) goto noSpecial;
if (k!=1 || cr!=1 || !TI(x,arrD1)) goto base; // TODO handle more ranks
goto selectCells;
if (cr==0 || !TI(x,arrD1)) goto base;
return select_cells(0, x, cam, k, true);
case n_couple: {
Arr* r = cpyWithShape(x); xsh=PSH(r);
if (xr==UR_MAX) thrF("≍%c: Result rank too large (%i≡=𝕩)", chr, xr);
@ -333,7 +332,7 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x, with arr
if (rtid==n_const) { f=fd->f; goto const_f; }
if ((rtid==n_fold || rtid==n_insert) && TI(x,elType)!=el_B && k==1 && xr==2 && isPervasiveDyExt(fd->f)) {
usz *sh = SH(x); usz m = sh[1];
if (m == 1) return select_cells(0, x, 2);
if (m == 1) return select_cells(0, x, cam, k, false);
if (m <= 64 && m < sh[0]) return fold_rows(fd, x);
}
}
@ -435,8 +434,8 @@ B cell_c2(Md1D* d, B w, B x) { B f = d->f;
if (cam==0) return cell2_empty(f, w, x, wr, xr);
if (isFun(f)) {
u8 rtid = v(f)->flags-1;
if (rtid==n_select && isF64(w) && xr>1) return select_cells(WRAP(o2i64(w), SH(x)[1], thrF("⊏: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, cam)), x, xr);
if (rtid==n_pick && TI(x,arrD1) && xr>1 && isF64(w)) return select_cells(WRAP(o2i64(w), SH(x)[1], thrF("⊑: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, cam)), x, xr);
if (rtid==n_select && isF64(w) && xr==2) return select_cells(WRAP(o2i64(w), SH(x)[1], thrF("⊏: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, cam)), x, cam, 1, false);
if (rtid==n_pick && TI(x,arrD1) && xr==2 && isF64(w)) return select_cells(WRAP(o2i64(w), SH(x)[1], thrF("⊑: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, cam)), x, cam, 1, true);
if ((rtid==n_shifta || rtid==n_shiftb) && xr==2) {
if (isArr(w)) { B w0=w; w = IGet(w,0); decG(w0); }
return shift_cells(w, x, SH(x)[0], SH(x)[1], el_or(TI(x,elType), selfElType(w)), rtid);