Make transpose movement functions strided

This commit is contained in:
Marshall Lochbaum 2023-03-28 21:52:13 -04:00
parent de18fb996b
commit 814a677676
2 changed files with 35 additions and 40 deletions

View File

@ -32,25 +32,20 @@
#endif #endif
#endif #endif
typedef void (*TranspFn)(void*,void*,u64,u64);
#if SINGELI
#define transposeFns simd_transpose
#define DECL_BASE(T) \ #define DECL_BASE(T) \
static NOINLINE void base_transpose_##T(T* rp, T* xp, u64 bw, u64 bh, u64 w, u64 h) { \ static NOINLINE void transpose_##T(void* rv, void* xv, u64 bw, u64 bh, u64 w, u64 h) { \
T* rp=rv; T* xp=xv; \
PLAINLOOP for(usz y=0;y<bh;y++) NOVECTORIZE for(usz x=0;x<bw;x++) rp[x*h+y] = xp[y*w+x]; \ PLAINLOOP for(usz y=0;y<bh;y++) NOVECTORIZE for(usz x=0;x<bw;x++) rp[x*h+y] = xp[y*w+x]; \
} }
DECL_BASE(i8) DECL_BASE(i16) DECL_BASE(i32) DECL_BASE(i64) DECL_BASE(i8) DECL_BASE(i16) DECL_BASE(i32) DECL_BASE(i64)
#undef DECL_BASE #undef DECL_BASE
typedef void (*TranspFn)(void*,void*,u64,u64,u64,u64);
#if SINGELI
#define transposeFns simd_transpose
#define SINGELI_FILE transpose #define SINGELI_FILE transpose
#include "../utils/includeSingeli.h" #include "../utils/includeSingeli.h"
#else #else
#define DECL_BASE(T) \
static NOINLINE void transpose_##T(void* rv, void* xv, u64 w, u64 h) { \
T* rp=rv; T* xp=xv; usz xi=0; \
PLAINLOOP for(usz y=0;y< h;y++) NOVECTORIZE for(usz x=0;x< w;x++) rp[x*h+y] = xp[xi++]; \
}
DECL_BASE(i8) DECL_BASE(i16) DECL_BASE(i32) DECL_BASE(i64)
#undef DECL_BASE
static TranspFn transposeFns[] = { static TranspFn transposeFns[] = {
transpose_i8, transpose_i16, transpose_i32, transpose_i64 transpose_i8, transpose_i16, transpose_i32, transpose_i64
}; };
@ -59,7 +54,7 @@ typedef void (*TranspFn)(void*,void*,u64,u64);
static void transpose_move(void* rv, void* xv, u8 xe, usz w, usz h) { static void transpose_move(void* rv, void* xv, u8 xe, usz w, usz h) {
assert(xe!=el_bit); assert(xe!=el_B); assert(xe!=el_bit); assert(xe!=el_B);
transposeFns[elWidthLogBits(xe)-3](rv, xv, w, h); transposeFns[elWidthLogBits(xe)-3](rv, xv, w, h, w, h);
} }
// Return an array with data from x transposed as though it's shape h,w // Return an array with data from x transposed as though it's shape h,w
// Shape of result needs to be set afterwards! // Shape of result needs to be set afterwards!
@ -274,12 +269,12 @@ B transp_c2(B t, B w, B x) {
usz w = rsh[na-2]; usz w = rsh[na-2];
usz h = rsh[na-1]; usz h = rsh[na-1];
if (na == 2) { if (na == 2) {
tran(rp, xp, w, h); tran(rp, xp, w, h, w, h);
} else { } else {
csz = (csz<<xlw) / 8; // Convert to bytes csz = (csz<<xlw) / 8; // Convert to bytes
usz ria = rf*csz; usz ria = rf*csz;
usz hw = h*w*csz; usz hw = h*w*csz;
AXIS_LOOP(na-2, hw, tran(rp+i, xp+j, w, h)); AXIS_LOOP(na-2, hw, tran(rp+i, xp+j, w, h, w, h));
} }
shSet(ra, rr, sh); shSet(ra, rr, sh);
r = taga(ra); r = taga(ra);

View File

@ -65,8 +65,8 @@ def for_mult_max{k, m}{vars,begin,end,block} = {
} }
} }
def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h} = { def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h, ws, hs} = {
def at{x,y} = tup{xp + y*w + x, rp + x*h + y} def at{x,y} = tup{xp + y*ws + x, rp + x*hs + y}
# Cache line info # Cache line info
def line_bytes = 64 def line_bytes = 64
@ -79,7 +79,7 @@ def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h} = {
we := w; if (use_overlap{wo}) we += k - wo we := w; if (use_overlap{wo}) we += k - wo
wm := w - k wm := w - k
if (line_elts > 2*k or h&(line_elts-1) != 0) { if (line_elts > 2*k or h&(line_elts-1) != 0 or h != hs) {
ho := h%k ho := h%k
# Effective height, like we for w # Effective height, like we for w
he := h; if (use_overlap{ho}) he += k - ho he := h; if (use_overlap{ho}) he += k - ho
@ -88,7 +88,7 @@ def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h} = {
# Main transpose # Main transpose
@for_mult_max{kh, h-kh} (y to he) { @for_mult_max{kh, h-kh} (y to he) {
@for_mult_max{k, wm} (x to we) { @for_mult_max{k, wm} (x to we) {
kernel{...at{x,y}, k, kh, w, h} kernel{...at{x,y}, k, kh, ws, hs}
} }
} }
# Half-row(s) for non-square i16 case # Half-row(s) for non-square i16 case
@ -98,12 +98,12 @@ def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h} = {
@for (yi to n) { @for (yi to n) {
y:u64 = 0; if (yi == n-1) y = h - e y:u64 = 0; if (yi == n-1) y = h - e
@for_mult_max{k, wm} (x to we) { @for_mult_max{k, wm} (x to we) {
kernel{...at{x,y}, k, k, w, h} kernel{...at{x,y}, k, k, ws, hs}
} }
} }
} }
# Base transpose used if overlap wasn't # Base transpose used if overlap wasn't
if (ho!=0 and he==h) { hs := h-ho; call_base{rp+hs, xp+w*hs, w, ho} } if (ho!=0 and he==h) { hd := h-ho; call_base{rp+hd, xp+ws*hd, w, ho} }
} else { } else {
# Result rows are aligned with each other so it's possible to # Result rows are aligned with each other so it's possible to
# write a full cache line at a time # write a full cache line at a time
@ -118,57 +118,57 @@ def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h} = {
each{tup, ...each{vt, iota{line_vecs}}} each{tup, ...each{vt, iota{line_vecs}}}
} }
ro := tail{6, -u64~~rp} / (width{T}/8) # Offset to align within cache line; assume elt-aligned ro := tail{6, -u64~~rp} / (width{T}/8) # Offset to align within cache line; assume elt-aligned
wh := w*h wh := ws*h
yn := h yn := h
if (ro != 0) { if (ro != 0) {
ra := line_elts - ro ra := line_elts - ro
y := h - ra y := h - ra
rpe := rp + y + (w-1)*h # Cache aligned rpe := rp + y + (w-1)*hs # Cache aligned
# Part of first and last result row aren't covered by the split loop # Part of first and last result row aren't covered by the split loop
def trtail{dst, src, len} = @for (i to len) store{dst, i, load{src, w*i}} def trtail{dst, src, len} = @for (i to len) store{dst, i, load{src, ws*i}}
trtail{rp, xp, ro} trtail{rp, xp, ro}
trtail{rpe, xp + y*w + w-1, ra} trtail{rpe, xp + y*ws + w-1, ra}
# Transpose first few rows and last few rows together # Transpose first few rows and last few rows together
@for_mult_max{k, wm} (x to we) { @for_mult_max{k, wm} (x to we) {
{xpo,rpo} := at{x, y} {xpo,rpo} := at{x, y}
o := w*y + x o := ws*y + x
def loadx{_} = { def loadx{_} = {
l:=load{*VT~~(xp+o)} l:=load{*VT~~(xp+o)}
o+=w; if (o>wh-k) o -= wh-1 # Jump from last source row to first, shifting right 1 o+=ws; if (o>wh-k) o -= wh-1 # Jump from last source row to first, shifting right 1
l l
} }
def rls = get_lines{loadx} # 4 rows of 2 vectors each def rls = get_lines{loadx} # 4 rows of 2 vectors each
each{{i,v} => {p:=rpo+i*h; if (i<3 or p<rpe) store_line{*VT~~p, v}}, iota{k}, rls} each{{i,v} => {p:=rpo+i*hs; if (i<3 or p<rpe) store_line{*VT~~p, v}}, iota{k}, rls}
} }
--yn # One strip handled --yn # One strip handled
} }
@for_mult{line_elts} (y0 to yn) { y := y0 + ro @for_mult{line_elts} (y0 to yn) { y := y0 + ro
@for_mult_max{k, wm} (x to we) { @for_mult_max{k, wm} (x to we) {
{xpo,rpo} := at{x, y} {xpo,rpo} := at{x, y}
def rls = get_lines{{i} => load{*VT~~(xpo+i*w), 0}} def rls = get_lines{{i} => load{*VT~~(xpo+i*ws), 0}}
each{{i,v} => store_line{*VT~~(rpo+i*h), v}, iota{k}, rls} each{{i,v} => store_line{*VT~~(rpo+i*hs), v}, iota{k}, rls}
} }
} }
} }
if (we==w) @for(ws from w-wo to w) { if (we==w) @for(wd from w-wo to w) {
xpo:=xp+ws; rpo:=rp+h*ws xpo:=xp+wd; rpo:=rp+hs*wd
@for (i to h) store{rpo, i, load{xpo, w*i}} @for (i to h) store{rpo, i, load{xpo, ws*i}}
} }
} }
fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = { fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
# Scalar transpose defined in C # Scalar transpose defined in C
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64' def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
def call_base{...a} = emit{void, merge{'base_transpose_',ts}, ...a, w, h} def call_base{...a} = emit{void, merge{'transpose_',ts}, ...a, ws, hs}
rp:*T = *T~~r0 rp:*T = *T~~r0
xp:*T = *T~~x0 xp:*T = *T~~x0
if (hasarch{'X86_64'} and w>=k and h>=k) { if (hasarch{'X86_64'} and w>=k and h>=k) {
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h} transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
} else { } else {
if (h==2) @for (x0 in xp, x1 in xp+w over i to w) { store{rp, i*2, x0}; store{rp, i*2+1, x1} } if (h==2 and h==hs) @for (x0 in xp, x1 in xp+ws over i to w) { store{rp, i*2, x0}; store{rp, i*2+1, x1} }
else if (w==2) @for (r0 in rp, r1 in rp+h over i to h) { r0 = load{xp, i*2}; r1 = load{xp, i*2+1} } else if (w==2 and w==ws) @for (r0 in rp, r1 in rp+hs over i to h) { r0 = load{xp, i*2}; r1 = load{xp, i*2+1} }
else call_base{rp, xp, w, h} else call_base{rp, xp, w, h}
} }
} }