baseline inds⊸⊏˘ mat

This commit is contained in:
dzaima 2024-07-26 22:15:59 +03:00
parent f7dd900b3a
commit 0d7bf86182
2 changed files with 138 additions and 4 deletions

View File

@ -9,6 +9,7 @@ B shape_c2(B, B, B);
B transp_c2(B, B, B);
B take_c2(B, B, B);
B join_c2(B, B, B);
B select_c2(B, B, B);
// from fold.c:
B fold_rows(Md1D* d, B x, usz n, usz m);
@ -19,6 +20,8 @@ B insert_cells_identity(B x, B f, usz* xsh, ur xr, ur k, u8 rtid);
B scan_rows_bit(u8, B x, usz m); // from scan.c
B takedrop_highrank(bool take, B w, B x); // from sfns.c
B rotate_highrank(bool inv, B w, B x); // from sfns.c
B select_rows_B(B x, ux csz, ux cam, B inds); // from select.c
B try_interleave_cells(B w, B x, ur xr, ur xk, usz* xsh); // from transpose.c
// X - variable name; XSH - its shape; K - number of leading axes that get iterated over; SLN - number of slices that will be made; DX - additional refcount count to add to x
@ -47,6 +50,12 @@ B try_interleave_cells(B w, B x, ur xr, ur xk, usz* xsh); // from transpose.c
#define SLICEI(X) ({ B r = SLICE(X, X##p); X##p+= X##_csz; r; })
Arr* customizeShape(B x) { // potentially copy array for shape customizing
if (reusable(x) && RNK(x)<=1) return a(x);
return TI(x,slice)(x,0,IA(x));
}
B insert_base(B f, B x, bool has_w, B w) { // Used by Insert in fold.c
assert(isArr(x) && RNK(x)>0);
@ -68,6 +77,21 @@ B insert_base(B f, B x, bool has_w, B w) { // Used by Insert in fold.c
return r;
}
B select_cells_base(B inds, B x0, ux csz, ux cam) { // consumes inds,x0; Used by select.c
assert(cam!=0);
Arr* xa = customizeShape(x0);
usz* xsh = arr_shAlloc(xa, 2);
xsh[0] = cam;
xsh[1] = csz;
B x = taga(xa);
assert(RNK(x)==2);
S_KSLICES(x, xsh, 1, cam, 0) incBy(inds, cam-1);
usz shBuf[] = {cam};
M_APD_SH_N(r, 1, shBuf, cam);
for (usz i=0,xp=0; i<cam; i++) APDD(r, C2(select, inds, SLICEI(x)));
return taga(APD_SH_GET(r, '\0'));
}
B scan_arith(B f, B w, B x, usz* xsh) { // Used by scan.c
bool has_w = w.u != m_f64(0).u;
assert(isArr(x) && (!has_w || isArr(w)));
@ -667,10 +691,22 @@ NOINLINE B for_cells_SA(B f, B w, B x, ur xcr, ur xr, u32 chr) { // w⊸F⎉xcr
switch(rtid) {
case n_rtack: dec(w); return x;
case n_ltack: return const_cells(x, xk, xsh, w, chr);
case n_select: if (isF64(w) && xcr>=1) {
usz l = xsh[xk];
return select_cells(WRAP(o2i64(w), l, thrF("⊏: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, l)), x, cam, xk, false);
} break;
case n_select:
if (isArr(w) && RNK(w)==1 && xcr==1) {
assert(xr > 1);
ux wia = IA(w);
ShArr* rsh = m_shArr(xr);
shcpy(rsh->a, xsh, xk);
rsh->a[xk] = wia;
Arr* r = customizeShape(select_rows_B(x, shProd(xsh,xk,xr), cam, w));
arr_shSetUG(r, xr, rsh);
return taga(r);
}
if (isF64(w) && xcr>=1) {
usz l = xsh[xk];
return select_cells(WRAP(o2i64(w), l, thrF("⊏: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, l)), x, cam, xk, false);
}
break;
case n_pick: if (isF64(w) && xcr==1 && TI(x,arrD1)) {
usz l = xsh[xk];
return select_cells(WRAP(o2i64(w), l, thrF("⊑: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, l)), x, cam, xk, true);

View File

@ -536,6 +536,104 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz cam, usz csz) { // consu
#undef FREE_CHECK
}
static void* m_tyarrv_same(B* r, usz ia, B src) { // makes a new typed array with same element type as src, but new ia
u8 se = TI(src,elType); assert(se!=el_bit && se!=el_B);
return m_tyarrlv(r, arrTypeWidthLog(TY(src)), ia, arrNewType(TY(src)));
}
B slash_c2(B, B, B);
Arr* customizeShape(B x); // from cells.c
B select_cells_base(B inds, B x0, ux csz, ux cam);
B select_rows_typed(B x, ux csz, ux cam, void* inds, ux indn, u8 ie, bool shouldBoundsCheck) { // ⥊ (indn↑inds As ie)⊸⊏˘ cam‿csz⥊z; xe cannot be el_bit or el_B, unless csz==1; ie must be ≤el_i8 if csz≤128
assert(csz!=0 && cam!=0);
assert(csz*cam == IA(x));
assert(ie<=el_i32);
if (csz==1) { // TODO maybe move to select_rows_B and require csz>=2 here?
i64 bounds[2];
if (!getRange_fns[ie](inds, bounds, indn) || bounds[0]<-1 || bounds[1]>0) goto generic; // could put under shouldBoundsCheck but ideally things setting that to false should handle size-1 cells themselves
return C2(slash, m_f64(indn), taga(arr_shVec(customizeShape(x))));
}
u8 xe = TI(x,elType);
assert(xe!=el_bit && xe!=el_B);
assert(csz>=2);
B r;
u8 lb = arrTypeWidthLog(TY(x));
ux xbump = csz<<lb;
ux rbump = indn<<lb;
u8* xp = tyany_ptr(x);
if (ie==el_bit) {
if (true /*csz>=8 || indn>=32*/) { // TODO enable & properly tune
u8* rp = m_tyarrv_same(&r, indn * cam, x);
for (ux i = 0; i < cam; i++) {
bitselFns[lb](rp, inds, loadu_u64(xp), loadu_u64(xp + (1<<lb)), indn);
xp+= xbump;
rp+= rbump;
}
goto decG_ret;
} else {
thrM("TODO widen");
}
}
#if SINGELI
{
u8* rp = m_tyarrv_same(&r, indn * cam, x);
ux slow_cam = cam;
SimdSelectFn fn = SIMD_SELECT(ie, lb+3);
for (ux i = 0; i < slow_cam; i++) {
fn(inds, xp, rp, indn, csz);
xp+= xbump;
rp+= rbump;
}
goto decG_ret;
}
#endif
generic:;
B indo = taga(arr_shVec(m_tyslice(inds, a(emptyIVec()), ie, indn)));
r = select_cells_base(indo, x, csz, cam);
return r;
decG_ret:;
decG(x);
return r;
}
B select_rows_B(B x, ux csz, ux cam, B inds) { // consumes inds,x; ⥊ inds⊸⊏˘ cam‿csz⥊x
assert(csz*cam == IA(x));
if (csz==0) goto generic;
if (cam<=1) {
if (cam==0) return taga(emptyArr(x, 1));
return C2(select, inds, taga(arr_shVec(TI(x,slice)(x, 0, IA(x)))));
}
ux in = IA(inds);
u8 ie = TI(inds,elType);
if (csz<=2? ie!=el_bit : csz<128? ie>el_i8 : !elInt(ie)) {
inds = num_squeeze(inds);
ie = TI(inds,elType);
if (!elInt(ie)) goto generic;
}
void* ip = tyany_ptr(inds);
u8 xe = TI(x,elType);
if ((xe!=el_bit && xe!=el_B) || csz==1) {
B r = select_rows_typed(x, csz, cam, (u8*)ip, in, ie, 1);
decG(inds);
return r;
}
generic:;
return select_cells_base(inds, x, csz, cam);
}
B select_ucw(B t, B o, B w, B x) {
if (isAtm(x) || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); }
usz xia = IA(x);