faster leading axis arithmetic

This commit is contained in:
dzaima 2023-04-30 00:53:01 +03:00
parent 2c207dbebc
commit 1e11cf93c1
3 changed files with 43 additions and 6 deletions

View File

@ -20,6 +20,7 @@ B atan2_c2(B, B, B);
typedef void (*AndBytesFn)(u8*, u8*, u64, u64);
B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr);
#if SINGELI_SIMD
#include "../singeli/c/arithdDispatch.c"
@ -276,8 +277,13 @@ static B modint_AS(B w, B xv) { return modint_AA(w, C2(shape, C1(fne, incG(w))
}
#define AR_I_AA(CHR, NAME, EXPR, BIT, EXTRA) NOINLINE B NAME##_AA(B t, B w, B x) { \
if (RNK(w)!=RNK(x)) goto bad; \
if (!eqShPart(SH(w), SH(x), RNK(w))) thrF(CHR ": Expected equal shape prefix (%H ≡ ≢𝕨, %H ≡ ≢𝕩)", w, x); \
ur wr=RNK(w); usz* xsh=SH(x); \
ur xr=RNK(x); usz* wsh=SH(w); ur mr=wr<xr?wr:xr; \
if (!eqShPart(wsh, xsh, mr)) thrF(CHR ": Expected equal shape prefix (%H ≡ ≢𝕨, %H ≡ ≢𝕩)", w, x); \
if (wr!=xr) { \
if (TI(w,elType)!=el_B && TI(x,elType)!=el_B) return leading_axis_arith(NAME##_c2, w, x, wsh, xsh, mr); \
else goto bad; \
} \
usz ia = IA(x); B r; \
u8 we = TI(w,elType); \
u8 xe = TI(x,elType); \

View File

@ -117,6 +117,23 @@ NOINLINE B toKCells(B x, ur k) {
return r;
}
NOINLINE B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr) { // assumes non-equal rank typed array arguments
assert(isArr(w) && isArr(x) && TI(w,elType)!=el_B && TI(x,elType)!=el_B);
ur wr = RNK(w);
#if DEBUG
assert(wr!=RNK(x) && eqShPart(wsh, xsh, mr));
#endif
usz cam2 = shProd(xsh, 0, mr);
usz* lsh = mr==wr? xsh : wsh;
B s = mr==wr? x : w;
M_APD_SH(r, mr, lsh);
S_KSLICES(s, lsh, mr, cam2, 1) usz sp=0;
if (mr==wr) { SGetU(w); for (usz i=0; i<cam2; i++) APDD(r, fc2(m_f64(0), GetU(w,i), SLICEI(s))); }
else { SGetU(x); for (usz i=0; i<cam2; i++) APDD(r, fc2(m_f64(0), SLICEI(s), GetU(x,i))); }
decG(w); decG(x);
return taga(APD_SH_GET(r, 0));
}
// fast special-case implementations
@ -518,6 +535,13 @@ NOINLINE B for_cells_AA(B f, B w, B x, ur wcr, ur xcr, u32 chr) {
decG(w); decG(x); return taga(r);
}
}
if (isPervasiveDy(f)) {
if (TI(w,elType)==el_B) goto generic; // for generic arrays it'd ¨ either way
if (TI(x,elType)==el_B) goto generic;
ur mr = xr<wr? xr : wr;
if ((wk>mr?mr:wk) != (xk>mr?mr:xk) || !eqShPart(wsh, xsh, mr)) goto generic;
return c2(f, w, x);
}
}
generic:;

View File

@ -62,7 +62,14 @@ NOINLINE B dyArith_AA(DyTableAA* table, B w, B x) {
u8 xe = TI(x, elType); if (xe==el_B) goto rec;
ur wr = RNK(w);
ur xr = RNK(x);
if (wr!=xr || !eqShPart(SH(w), SH(x), wr)) goto rec;
if (wr!=xr) {
usz* xsh=SH(x);
usz* wsh=SH(w);
ur mr=wr<xr?wr:xr;
if (!eqShPart(wsh, xsh, mr)) goto rec;
return leading_axis_arith(table->mainFn, w, x, wsh, xsh, mr);
}
if (!eqShPart(SH(w), SH(x), wr)) goto rec;
B r, t;
usz ia = IA(w);