And strided height, so all orders are fast if axes are long
This commit is contained in:
parent
e923a71881
commit
d492cd0865
@ -257,30 +257,50 @@ B transp_c2(B t, B w, B x) {
|
||||
r = withFill(taga(ra), getFillQ(x));
|
||||
decG(x); goto ret;
|
||||
}
|
||||
#undef AXIS_LOOP
|
||||
if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<<xlw>=8 // CPU-sized cells
|
||||
&& na>=2 && st[na-2]==1 // Last axis ends up second-to-last
|
||||
&& rsh[na-2]*rsh[na-1] >= (256*8) >> xlw // Large-ish axes
|
||||
&& xe!=el_B) {
|
||||
&& xe!=el_B && na>=2) {
|
||||
// If some result axis has stride 1 (guaranteed if dup==0), then it
|
||||
// corresponds to the last argument axis and we have a strided
|
||||
// transpose swapping that with the last result axis
|
||||
usz rai = na-1;
|
||||
usz xai=rai; while (st[--xai]!=1) if (xai==0) goto skip_2d;
|
||||
if (rsh[xai]*rsh[rai] < (256*8) >> xlw) goto skip_2d;
|
||||
TranspFn tran = transposeFns[CTZ(csz<<xlw)-3];
|
||||
usz rf = shProd(rsh, 0, na);
|
||||
Arr* ra;
|
||||
u8* rp = m_tyarrlbp(&ra,xlw,rf*csz,el2t(xe));
|
||||
u8* xp = tyany_ptr(x);
|
||||
usz w = rsh[na-2]; usz ws = st[na-1];
|
||||
usz h = rsh[na-1];
|
||||
usz w = rsh[xai]; usz ws = st[rai];
|
||||
usz h = rsh[rai]; usz hs = shProd(rsh, xai+1, rai) * h;
|
||||
if (na == 2) {
|
||||
tran(rp, xp, w, h, ws, h);
|
||||
tran(rp, xp, w, h, ws, hs);
|
||||
} else {
|
||||
csz = (csz<<xlw) / 8; // Convert to bytes
|
||||
usz ria = rf*csz;
|
||||
usz hw = h*w*csz;
|
||||
AXIS_LOOP(na-2, hw, tran(rp+i, xp+j, w, h, ws, h));
|
||||
usz i_skip = (w-1)*hs*csz;
|
||||
usz end = rf*csz - i_skip;
|
||||
ur a0 = na - 1;
|
||||
for (usz i=0; i<na; i++) st[i] *= csz;
|
||||
usz* ri = st+ar; for (usz i=0; i<a0; i++) ri[i]=0;
|
||||
for (usz i=0, j=0;;) {
|
||||
tran(rp+i, xp+j, w, h, ws, hs);
|
||||
i += h*csz;
|
||||
if (i == end) break;
|
||||
for (ur a = a0;;) {
|
||||
if (--a == xai) { assert(a!=0); i+=i_skip; --a; }
|
||||
usz str = st[a];
|
||||
j += str;
|
||||
if (LIKELY(++ri[a] < rsh[a])) break;
|
||||
ri[a] = 0;
|
||||
j -= rsh[a] * str;
|
||||
}
|
||||
}
|
||||
}
|
||||
shSet(ra, rr, sh);
|
||||
r = taga(ra);
|
||||
decG(x); goto ret;
|
||||
}
|
||||
#undef AXIS_LOOP
|
||||
skip_2d:;
|
||||
|
||||
// Reshape x for selection
|
||||
ShArr* zsh = m_shArr(2);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user