From 814a6776764cfaab2a58e8c347fdb7043e440b88 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Tue, 28 Mar 2023 21:52:13 -0400 Subject: [PATCH] Make transpose movement functions strided --- src/builtins/transpose.c | 29 ++++++++----------- src/singeli/src/transpose.singeli | 46 +++++++++++++++---------------- 2 files changed, 35 insertions(+), 40 deletions(-) diff --git a/src/builtins/transpose.c b/src/builtins/transpose.c index 8bb00681..900b08a7 100644 --- a/src/builtins/transpose.c +++ b/src/builtins/transpose.c @@ -32,25 +32,20 @@ #endif #endif -typedef void (*TranspFn)(void*,void*,u64,u64); +#define DECL_BASE(T) \ + static NOINLINE void transpose_##T(void* rv, void* xv, u64 bw, u64 bh, u64 w, u64 h) { \ + T* rp=rv; T* xp=xv; \ + PLAINLOOP for(usz y=0;y 2*k or h&(line_elts-1) != 0) { + if (line_elts > 2*k or h&(line_elts-1) != 0 or h != hs) { ho := h%k # Effective height, like we for w he := h; if (use_overlap{ho}) he += k - ho @@ -88,7 +88,7 @@ def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h} = { # Main transpose @for_mult_max{kh, h-kh} (y to he) { @for_mult_max{k, wm} (x to we) { - kernel{...at{x,y}, k, kh, w, h} + kernel{...at{x,y}, k, kh, ws, hs} } } # Half-row(s) for non-square i16 case @@ -98,12 +98,12 @@ def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h} = { @for (yi to n) { y:u64 = 0; if (yi == n-1) y = h - e @for_mult_max{k, wm} (x to we) { - kernel{...at{x,y}, k, k, w, h} + kernel{...at{x,y}, k, k, ws, hs} } } } # Base transpose used if overlap wasn't - if (ho!=0 and he==h) { hs := h-ho; call_base{rp+hs, xp+w*hs, w, ho} } + if (ho!=0 and he==h) { hd := h-ho; call_base{rp+hd, xp+ws*hd, w, ho} } } else { # Result rows are aligned with each other so it's possible to # write a full cache line at a time @@ -118,57 +118,57 @@ def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h} = { each{tup, ...each{vt, iota{line_vecs}}} } ro := tail{6, -u64~~rp} / (width{T}/8) # Offset to align within cache line; assume elt-aligned - wh := w*h + wh := ws*h yn := h if (ro != 0) { ra := line_elts - ro y := h - ra - rpe := rp + y + (w-1)*h # Cache aligned + rpe := rp + y + (w-1)*hs # Cache aligned # Part of first and last result row aren't covered by the split loop - def trtail{dst, src, len} = @for (i to len) store{dst, i, load{src, w*i}} + def trtail{dst, src, len} = @for (i to len) store{dst, i, load{src, ws*i}} trtail{rp, xp, ro} - trtail{rpe, xp + y*w + w-1, ra} + trtail{rpe, xp + y*ws + w-1, ra} # Transpose first few rows and last few rows together @for_mult_max{k, wm} (x to we) { {xpo,rpo} := at{x, y} - o := w*y + x + o := ws*y + x def loadx{_} = { l:=load{*VT~~(xp+o)} - o+=w; if (o>wh-k) o -= wh-1 # Jump from last source row to first, shifting right 1 + o+=ws; if (o>wh-k) o -= wh-1 # Jump from last source row to first, shifting right 1 l } def rls = get_lines{loadx} # 4 rows of 2 vectors each - each{{i,v} => {p:=rpo+i*h; if (i<3 or p {p:=rpo+i*hs; if (i<3 or p load{*VT~~(xpo+i*w), 0}} - each{{i,v} => store_line{*VT~~(rpo+i*h), v}, iota{k}, rls} + def rls = get_lines{{i} => load{*VT~~(xpo+i*ws), 0}} + each{{i,v} => store_line{*VT~~(rpo+i*hs), v}, iota{k}, rls} } } } - if (we==w) @for(ws from w-wo to w) { - xpo:=xp+ws; rpo:=rp+h*ws - @for (i to h) store{rpo, i, load{xpo, w*i}} + if (we==w) @for(wd from w-wo to w) { + xpo:=xp+wd; rpo:=rp+hs*wd + @for (i to h) store{rpo, i, load{xpo, ws*i}} } } -fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = { +fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = { # Scalar transpose defined in C def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64' - def call_base{...a} = emit{void, merge{'base_transpose_',ts}, ...a, w, h} + def call_base{...a} = emit{void, merge{'transpose_',ts}, ...a, ws, hs} rp:*T = *T~~r0 xp:*T = *T~~x0 if (hasarch{'X86_64'} and w>=k and h>=k) { - transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h} + transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs} } else { - if (h==2) @for (x0 in xp, x1 in xp+w over i to w) { store{rp, i*2, x0}; store{rp, i*2+1, x1} } - else if (w==2) @for (r0 in rp, r1 in rp+h over i to h) { r0 = load{xp, i*2}; r1 = load{xp, i*2+1} } + if (h==2 and h==hs) @for (x0 in xp, x1 in xp+ws over i to w) { store{rp, i*2, x0}; store{rp, i*2+1, x1} } + else if (w==2 and w==ws) @for (r0 in rp, r1 in rp+hs over i to h) { r0 = load{xp, i*2}; r1 = load{xp, i*2+1} } else call_base{rp, xp, w, h} } }