Use strided width in transpose-based Reorder Axes

This commit is contained in:
Marshall Lochbaum 2023-03-29 08:22:15 -04:00
parent 814a677676
commit e923a71881

View File

@ -258,23 +258,23 @@ B transp_c2(B t, B w, B x) {
decG(x); goto ret;
}
if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<<xlw>=8 // CPU-sized cells
&& na>=2 && st[na-1]==rsh[na-2] && st[na-2]==1 // Last two axes transposed
&& rsh[na-2]*rsh[na-1] >= (256*8) >> xlw // And large-ish
&& na>=2 && st[na-2]==1 // Last axis ends up second-to-last
&& rsh[na-2]*rsh[na-1] >= (256*8) >> xlw // Large-ish axes
&& xe!=el_B) {
TranspFn tran = transposeFns[CTZ(csz<<xlw)-3];
usz rf = shProd(rsh, 0, na);
Arr* ra;
u8* rp = m_tyarrlbp(&ra,xlw,rf*csz,el2t(xe));
u8* xp = tyany_ptr(x);
usz w = rsh[na-2];
usz w = rsh[na-2]; usz ws = st[na-1];
usz h = rsh[na-1];
if (na == 2) {
tran(rp, xp, w, h, w, h);
tran(rp, xp, w, h, ws, h);
} else {
csz = (csz<<xlw) / 8; // Convert to bytes
usz ria = rf*csz;
usz hw = h*w*csz;
AXIS_LOOP(na-2, hw, tran(rp+i, xp+j, w, h, w, h));
AXIS_LOOP(na-2, hw, tran(rp+i, xp+j, w, h, ws, h));
}
shSet(ra, rr, sh);
r = taga(ra);