And strided height, so all orders are fast if axes are long
This commit is contained in:
parent
e923a71881
commit
d492cd0865
@ -257,30 +257,50 @@ B transp_c2(B t, B w, B x) {
|
|||||||
r = withFill(taga(ra), getFillQ(x));
|
r = withFill(taga(ra), getFillQ(x));
|
||||||
decG(x); goto ret;
|
decG(x); goto ret;
|
||||||
}
|
}
|
||||||
|
#undef AXIS_LOOP
|
||||||
if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<<xlw>=8 // CPU-sized cells
|
if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<<xlw>=8 // CPU-sized cells
|
||||||
&& na>=2 && st[na-2]==1 // Last axis ends up second-to-last
|
&& xe!=el_B && na>=2) {
|
||||||
&& rsh[na-2]*rsh[na-1] >= (256*8) >> xlw // Large-ish axes
|
// If some result axis has stride 1 (guaranteed if dup==0), then it
|
||||||
&& xe!=el_B) {
|
// corresponds to the last argument axis and we have a strided
|
||||||
|
// transpose swapping that with the last result axis
|
||||||
|
usz rai = na-1;
|
||||||
|
usz xai=rai; while (st[--xai]!=1) if (xai==0) goto skip_2d;
|
||||||
|
if (rsh[xai]*rsh[rai] < (256*8) >> xlw) goto skip_2d;
|
||||||
TranspFn tran = transposeFns[CTZ(csz<<xlw)-3];
|
TranspFn tran = transposeFns[CTZ(csz<<xlw)-3];
|
||||||
usz rf = shProd(rsh, 0, na);
|
usz rf = shProd(rsh, 0, na);
|
||||||
Arr* ra;
|
Arr* ra;
|
||||||
u8* rp = m_tyarrlbp(&ra,xlw,rf*csz,el2t(xe));
|
u8* rp = m_tyarrlbp(&ra,xlw,rf*csz,el2t(xe));
|
||||||
u8* xp = tyany_ptr(x);
|
u8* xp = tyany_ptr(x);
|
||||||
usz w = rsh[na-2]; usz ws = st[na-1];
|
usz w = rsh[xai]; usz ws = st[rai];
|
||||||
usz h = rsh[na-1];
|
usz h = rsh[rai]; usz hs = shProd(rsh, xai+1, rai) * h;
|
||||||
if (na == 2) {
|
if (na == 2) {
|
||||||
tran(rp, xp, w, h, ws, h);
|
tran(rp, xp, w, h, ws, hs);
|
||||||
} else {
|
} else {
|
||||||
csz = (csz<<xlw) / 8; // Convert to bytes
|
csz = (csz<<xlw) / 8; // Convert to bytes
|
||||||
usz ria = rf*csz;
|
usz i_skip = (w-1)*hs*csz;
|
||||||
usz hw = h*w*csz;
|
usz end = rf*csz - i_skip;
|
||||||
AXIS_LOOP(na-2, hw, tran(rp+i, xp+j, w, h, ws, h));
|
ur a0 = na - 1;
|
||||||
|
for (usz i=0; i<na; i++) st[i] *= csz;
|
||||||
|
usz* ri = st+ar; for (usz i=0; i<a0; i++) ri[i]=0;
|
||||||
|
for (usz i=0, j=0;;) {
|
||||||
|
tran(rp+i, xp+j, w, h, ws, hs);
|
||||||
|
i += h*csz;
|
||||||
|
if (i == end) break;
|
||||||
|
for (ur a = a0;;) {
|
||||||
|
if (--a == xai) { assert(a!=0); i+=i_skip; --a; }
|
||||||
|
usz str = st[a];
|
||||||
|
j += str;
|
||||||
|
if (LIKELY(++ri[a] < rsh[a])) break;
|
||||||
|
ri[a] = 0;
|
||||||
|
j -= rsh[a] * str;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
shSet(ra, rr, sh);
|
shSet(ra, rr, sh);
|
||||||
r = taga(ra);
|
r = taga(ra);
|
||||||
decG(x); goto ret;
|
decG(x); goto ret;
|
||||||
}
|
}
|
||||||
#undef AXIS_LOOP
|
skip_2d:;
|
||||||
|
|
||||||
// Reshape x for selection
|
// Reshape x for selection
|
||||||
ShArr* zsh = m_shArr(2);
|
ShArr* zsh = m_shArr(2);
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user