SIMD kernel for Reorder Axes transposing last two axes
This commit is contained in:
parent
8413461074
commit
5b17ee44d6
@ -32,6 +32,7 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
typedef void (*TranspFn)(void*,void*,u64,u64);
|
||||
#if SINGELI
|
||||
#define transposeFns simd_transpose
|
||||
#define DECL_BASE(T) \
|
||||
@ -50,7 +51,7 @@
|
||||
}
|
||||
DECL_BASE(i8) DECL_BASE(i16) DECL_BASE(i32) DECL_BASE(i64)
|
||||
#undef DECL_BASE
|
||||
static void (*transposeFns[])(void*,void*,u64,u64) = {
|
||||
static TranspFn transposeFns[] = {
|
||||
transpose_i8, transpose_i16, transpose_i32, transpose_i64
|
||||
};
|
||||
#endif
|
||||
@ -227,36 +228,65 @@ B transp_c2(B t, B w, B x) {
|
||||
|
||||
u8 xe = TI(x,elType);
|
||||
usz csz = shProd(xsh, ar, xr);
|
||||
if (csz >= (32*8) >> elWidthLogBits(xe)) { // cell >= 32 bytes
|
||||
usz ria = csz * shProd(rsh, 0, na);
|
||||
MAKE_MUT_INIT(rm, ria, xe); MUTG_INIT(rm);
|
||||
for (usz i=0; i<na; i++) st[i] *= csz;
|
||||
TALLOC(usz, ri, na-1); for (usz i=0; i<na-1; i++) ri[i]=0;
|
||||
for (usz i=0, j=0;;) {
|
||||
// Hardcode one innermost loop: we know a>=0
|
||||
ur a = na - 1;
|
||||
usz str = st[a];
|
||||
usz l = rsh[a];
|
||||
for (usz k=0; k<l; k++) {
|
||||
mut_copyG(rm, i, x, j+k*str, csz);
|
||||
i += csz;
|
||||
}
|
||||
if (i == ria) break;
|
||||
// Update result index, starting with last axis finished
|
||||
while (1) {
|
||||
str = st[--a];
|
||||
j += str;
|
||||
if (LIKELY(++ri[a] < rsh[a])) break;
|
||||
ri[a] = 0;
|
||||
j -= rsh[a] * str;
|
||||
}
|
||||
}
|
||||
#define AXIS_LOOP(INIT, N_AX, I_INC, DO_INNER) \
|
||||
ur a0 = N_AX - 1; \
|
||||
for (usz i=0; i<na; i++) st[i] *= csz; \
|
||||
TALLOC(usz, ri, a0); for (usz i=0; i<a0; i++) ri[i]=0; \
|
||||
INIT \
|
||||
for (usz i=0, j0=0;;) { \
|
||||
/* Hardcode one innermost loop: assume N_AX>=1 */ \
|
||||
ur a = a0; \
|
||||
usz str = st[a]; \
|
||||
usz l = rsh[a]; \
|
||||
for (usz k=0; k<l; k++) { \
|
||||
usz j=j0+k*str; DO_INNER; \
|
||||
i += I_INC; \
|
||||
} \
|
||||
if (i == ria) break; \
|
||||
/* Update result index starting with last axis finished */\
|
||||
while (1) { \
|
||||
str = st[--a]; \
|
||||
j0 += str; \
|
||||
if (LIKELY(++ri[a] < rsh[a])) break; \
|
||||
ri[a] = 0; \
|
||||
j0 -= rsh[a] * str; \
|
||||
} \
|
||||
} \
|
||||
TFREE(ri);
|
||||
u8 xlw = elWidthLogBits(xe);
|
||||
if (csz >= (32*8) >> xlw) { // cell >= 32 bytes
|
||||
usz ria = csz * shProd(rsh, 0, na);
|
||||
AXIS_LOOP(MAKE_MUT_INIT(rm, ria, xe); MUTG_INIT(rm); ,
|
||||
na, csz, mut_copyG(rm, i, x, j, csz));
|
||||
Arr* ra = mut_fp(rm);
|
||||
shSet_alloc(ra, rr, rsh);
|
||||
r = withFill(taga(ra), getFillQ(x));
|
||||
decG(x); goto ret_decst;
|
||||
}
|
||||
if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<<xlw>=8 // CPU-sized cells
|
||||
&& !dup && na>=2 && p[na-1]==ar-2 && p[na-2]==ar-1 // Last two axes transposed
|
||||
&& rsh[na-2]*rsh[na-1] >= (256*8) >> xlw // And large-ish
|
||||
&& xe!=el_B) {
|
||||
TranspFn tran = transposeFns[CTZ(csz<<xlw)-3];
|
||||
usz rf = shProd(rsh, 0, na);
|
||||
Arr* ra;
|
||||
u8* rp = m_tyarrlbp(&ra,xlw,rf*csz,el2t(xe));
|
||||
u8* xp = tyany_ptr(x);
|
||||
usz w = rsh[na-2];
|
||||
usz h = rsh[na-1];
|
||||
if (na == 2) {
|
||||
tran(rp, xp, w, h);
|
||||
} else {
|
||||
csz = (csz<<xlw) / 8; // Convert to bytes
|
||||
usz ria = rf*csz;
|
||||
usz hw = h*w*csz;
|
||||
AXIS_LOOP(, na-2, hw, tran(rp+i, xp+j, w, h));
|
||||
}
|
||||
shSet_alloc(ra, rr, rsh);
|
||||
r = taga(ra);
|
||||
decG(x); goto ret_decst;
|
||||
}
|
||||
#undef AXIS_LOOP
|
||||
|
||||
// Reshape x for selection, collapsing ar axes
|
||||
if (ar != 1) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user