faster leading axis arithmetic

This commit is contained in:
dzaima 2023-04-30 00:53:01 +03:00
parent 2c207dbebc
commit 1e11cf93c1
3 changed files with 43 additions and 6 deletions

View File

@ -20,6 +20,7 @@ B atan2_c2(B, B, B);
typedef void (*AndBytesFn)(u8*, u8*, u64, u64);
B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr);
#if SINGELI_SIMD
#include "../singeli/c/arithdDispatch.c"
@ -181,7 +182,7 @@ static B modint_AS(B w, B xv) { return modint_AA(w, C2(shape, C1(fne, incG(w))
[el_i16] = 0x0001000100010001ULL,
[el_i32] = 0x0000000100000001ULL,
};
GC2f("|", stile, pfmod(x.f, w.f), NOUNROLL,
GC2f("|", stile, pfmod(x.f, w.f), NOUNROLL,
/*INT_SA*/
if (q_i32(w)) {
i32 wi32 = o2iG(w);
@ -276,8 +277,13 @@ static B modint_AS(B w, B xv) { return modint_AA(w, C2(shape, C1(fne, incG(w))
}
#define AR_I_AA(CHR, NAME, EXPR, BIT, EXTRA) NOINLINE B NAME##_AA(B t, B w, B x) { \
if (RNK(w)!=RNK(x)) goto bad; \
if (!eqShPart(SH(w), SH(x), RNK(w))) thrF(CHR ": Expected equal shape prefix (%H ≡ ≢𝕨, %H ≡ ≢𝕩)", w, x); \
ur wr=RNK(w); usz* xsh=SH(x); \
ur xr=RNK(x); usz* wsh=SH(w); ur mr=wr<xr?wr:xr; \
if (!eqShPart(wsh, xsh, mr)) thrF(CHR ": Expected equal shape prefix (%H ≡ ≢𝕨, %H ≡ ≢𝕩)", w, x); \
if (wr!=xr) { \
if (TI(w,elType)!=el_B && TI(x,elType)!=el_B) return leading_axis_arith(NAME##_c2, w, x, wsh, xsh, mr); \
else goto bad; \
} \
usz ia = IA(x); B r; \
u8 we = TI(w,elType); \
u8 xe = TI(x,elType); \

View File

@ -117,6 +117,23 @@ NOINLINE B toKCells(B x, ur k) {
return r;
}
NOINLINE B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr) { // assumes non-equal rank typed array arguments
assert(isArr(w) && isArr(x) && TI(w,elType)!=el_B && TI(x,elType)!=el_B);
ur wr = RNK(w);
#if DEBUG
assert(wr!=RNK(x) && eqShPart(wsh, xsh, mr));
#endif
usz cam2 = shProd(xsh, 0, mr);
usz* lsh = mr==wr? xsh : wsh;
B s = mr==wr? x : w;
M_APD_SH(r, mr, lsh);
S_KSLICES(s, lsh, mr, cam2, 1) usz sp=0;
if (mr==wr) { SGetU(w); for (usz i=0; i<cam2; i++) APDD(r, fc2(m_f64(0), GetU(w,i), SLICEI(s))); }
else { SGetU(x); for (usz i=0; i<cam2; i++) APDD(r, fc2(m_f64(0), SLICEI(s), GetU(x,i))); }
decG(w); decG(x);
return taga(APD_SH_GET(r, 0));
}
// fast special-case implementations
@ -518,12 +535,19 @@ NOINLINE B for_cells_AA(B f, B w, B x, ur wcr, ur xcr, u32 chr) {
decG(w); decG(x); return taga(r);
}
}
if (isPervasiveDy(f)) {
if (TI(w,elType)==el_B) goto generic; // for generic arrays it'd ¨ either way
if (TI(x,elType)==el_B) goto generic;
ur mr = xr<wr? xr : wr;
if ((wk>mr?mr:wk) != (xk>mr?mr:xk) || !eqShPart(wsh, xsh, mr)) goto generic;
return c2(f, w, x);
}
}
generic:;
M_APD_SH(r, zk, zsh);
S_KSLICES(w, wsh, wk, xkM? cam0 : cam, 1) usz wp = 0;
S_KSLICES(x, xsh, xk, xkM? cam : cam0, 1) usz xp = 0;
S_KSLICES(w, wsh, wk, xkM? cam0 : cam, 1) usz wp=0;
S_KSLICES(x, xsh, xk, xkM? cam : cam0, 1) usz xp=0;
FC2 fc2 = c2fn(f);
if (ext==1) { for (usz i=0; i<cam; i++) APDD(r, fc2(f, SLICEI(w), SLICEI(x))); }
else if (xkM) { for (usz i=0; i<cam; ) { B wb=incByG(SLICEI(w), ext-1); for (usz e = i+ext; i < e; i++) APDD(r, fc2(f, wb, SLICEI(x))); } }

View File

@ -62,7 +62,14 @@ NOINLINE B dyArith_AA(DyTableAA* table, B w, B x) {
u8 xe = TI(x, elType); if (xe==el_B) goto rec;
ur wr = RNK(w);
ur xr = RNK(x);
if (wr!=xr || !eqShPart(SH(w), SH(x), wr)) goto rec;
if (wr!=xr) {
usz* xsh=SH(x);
usz* wsh=SH(w);
ur mr=wr<xr?wr:xr;
if (!eqShPart(wsh, xsh, mr)) goto rec;
return leading_axis_arith(table->mainFn, w, x, wsh, xsh, mr);
}
if (!eqShPart(SH(w), SH(x), wr)) goto rec;
B r, t;
usz ia = IA(w);