replicate-using leading axis arithmetic case

This commit is contained in:
dzaima 2023-04-30 18:28:58 +03:00
parent 1e11cf93c1
commit c889a07d75

View File

@ -117,21 +117,36 @@ NOINLINE B toKCells(B x, ur k) {
return r;
}
NOINLINE B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr) { // assumes non-equal rank typed array arguments
B slash_c2(B, B, B);
NOINLINE B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr) { // assumes non-equal rank conforming typed array arguments
assert(isArr(w) && isArr(x) && TI(w,elType)!=el_B && TI(x,elType)!=el_B);
ur wr = RNK(w);
ur xr = RNK(x);
#if DEBUG
assert(wr!=RNK(x) && eqShPart(wsh, xsh, mr));
assert(wr!=xr && (mr==wr || mr==xr) && eqShPart(wsh, xsh, mr));
#endif
usz cam2 = shProd(xsh, 0, mr);
usz* lsh = mr==wr? xsh : wsh;
B s = mr==wr? x : w;
M_APD_SH(r, mr, lsh);
S_KSLICES(s, lsh, mr, cam2, 1) usz sp=0;
if (mr==wr) { SGetU(w); for (usz i=0; i<cam2; i++) APDD(r, fc2(m_f64(0), GetU(w,i), SLICEI(s))); }
else { SGetU(x); for (usz i=0; i<cam2; i++) APDD(r, fc2(m_f64(0), SLICEI(s), GetU(x,i))); }
decG(w); decG(x);
return taga(APD_SH_GET(r, 0));
usz cam = shProd(xsh, 0, mr);
B b = mr==wr? x : w; // bigger argument
usz* bsh = mr==wr? xsh : wsh;
ur br = wr>xr? wr : xr;
usz csz = shProd(bsh, mr, br);
if (csz<200) {
B s = mr==wr? w : x; // smaller argument
s = C2(slash, m_usz(csz), taga(arr_shVec(TI(s,slice)(s,0,IA(s)))));
assert(reusable(s) && RNK(s)==1);
arr_shCopy(a(s), b);
if (mr==wr) w=s; else x=s;
return fc2(m_f64(0), w, x);
} else {
M_APD_SH(r, mr, bsh);
S_KSLICES(b, bsh, mr, cam, 1) usz bp=0;
if (mr==wr) { SGetU(w); for (usz i=0; i<cam; i++) APDD(r, fc2(m_f64(0), GetU(w,i), SLICEI(b))); }
else { SGetU(x); for (usz i=0; i<cam; i++) APDD(r, fc2(m_f64(0), SLICEI(b), GetU(x,i))); }
decG(w); decG(x);
return taga(APD_SH_GET(r, 0));
}
}
@ -536,8 +551,7 @@ NOINLINE B for_cells_AA(B f, B w, B x, ur wcr, ur xcr, u32 chr) {
}
}
if (isPervasiveDy(f)) {
if (TI(w,elType)==el_B) goto generic; // for generic arrays it'd ¨ either way
if (TI(x,elType)==el_B) goto generic;
if (TI(w,elType)==el_B || TI(x,elType)==el_B) goto generic;
ur mr = xr<wr? xr : wr;
if ((wk>mr?mr:wk) != (xk>mr?mr:xk) || !eqShPart(wsh, xsh, mr)) goto generic;
return c2(f, w, x);