And strided height, so all orders are fast if axes are long

This commit is contained in:
Marshall Lochbaum 2023-03-29 11:31:52 -04:00
parent e923a71881
commit d492cd0865

View File

@ -257,30 +257,50 @@ B transp_c2(B t, B w, B x) {
r = withFill(taga(ra), getFillQ(x));
decG(x); goto ret;
}
#undef AXIS_LOOP
if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<<xlw>=8 // CPU-sized cells
&& na>=2 && st[na-2]==1 // Last axis ends up second-to-last
&& rsh[na-2]*rsh[na-1] >= (256*8) >> xlw // Large-ish axes
&& xe!=el_B) {
&& xe!=el_B && na>=2) {
// If some result axis has stride 1 (guaranteed if dup==0), then it
// corresponds to the last argument axis and we have a strided
// transpose swapping that with the last result axis
usz rai = na-1;
usz xai=rai; while (st[--xai]!=1) if (xai==0) goto skip_2d;
if (rsh[xai]*rsh[rai] < (256*8) >> xlw) goto skip_2d;
TranspFn tran = transposeFns[CTZ(csz<<xlw)-3];
usz rf = shProd(rsh, 0, na);
Arr* ra;
u8* rp = m_tyarrlbp(&ra,xlw,rf*csz,el2t(xe));
u8* xp = tyany_ptr(x);
usz w = rsh[na-2]; usz ws = st[na-1];
usz h = rsh[na-1];
usz w = rsh[xai]; usz ws = st[rai];
usz h = rsh[rai]; usz hs = shProd(rsh, xai+1, rai) * h;
if (na == 2) {
tran(rp, xp, w, h, ws, h);
tran(rp, xp, w, h, ws, hs);
} else {
csz = (csz<<xlw) / 8; // Convert to bytes
usz ria = rf*csz;
usz hw = h*w*csz;
AXIS_LOOP(na-2, hw, tran(rp+i, xp+j, w, h, ws, h));
usz i_skip = (w-1)*hs*csz;
usz end = rf*csz - i_skip;
ur a0 = na - 1;
for (usz i=0; i<na; i++) st[i] *= csz;
usz* ri = st+ar; for (usz i=0; i<a0; i++) ri[i]=0;
for (usz i=0, j=0;;) {
tran(rp+i, xp+j, w, h, ws, hs);
i += h*csz;
if (i == end) break;
for (ur a = a0;;) {
if (--a == xai) { assert(a!=0); i+=i_skip; --a; }
usz str = st[a];
j += str;
if (LIKELY(++ri[a] < rsh[a])) break;
ri[a] = 0;
j -= rsh[a] * str;
}
}
}
shSet(ra, rr, sh);
r = taga(ra);
decG(x); goto ret;
}
#undef AXIS_LOOP
skip_2d:;
// Reshape x for selection
ShArr* zsh = m_shArr(2);