Overlap SIMD transpose on width too

This commit is contained in:
Marshall Lochbaum 2023-03-22 15:13:18 -04:00
parent 61eefe0776
commit 5a2bc15f2a

View File

@ -61,6 +61,12 @@ def for_mult{k}{vars,begin,end,block} = {
assert{begin == 0} assert{begin == 0}
@for (i to end/k) exec{k*i, vars, block} @for (i to end/k) exec{k*i, vars, block}
} }
def for_mult_max{k, m}{vars,begin,end,block} = {
@for_mult{k} (i0 to end) {
i:=i0; if (i>m) i = m
exec{i, vars, block}
}
}
fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = { fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = {
# Scalar transpose defined in C # Scalar transpose defined in C
@ -78,19 +84,21 @@ fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = {
def line_elts = line_bytes / (width{T}/8) def line_elts = line_bytes / (width{T}/8)
def use_overlap{o} = o >= 2 # For overlapped SIMD instead of scalar def use_overlap{o} = o >= 2 # For overlapped SIMD instead of scalar
wo := w%k
# Effective width: number of columns read, counting overlap twice
# Just use base transpose for short overhang; otherwise round up
we := w; if (use_overlap{wo}) we += k - wo
wm := w - k
# Handle uneven height (extra rows) here, but not uneven width
if (line_elts > 2*k or h&(line_elts-1) != 0) { if (line_elts > 2*k or h&(line_elts-1) != 0) {
ho := h%k ho := h%k
# Effective height: number of rows read, counting overlap twice # Effective height, like we for w
# Just use base transpose for short overhang; otherwise round up
he := h; if (use_overlap{ho}) he += k - ho he := h; if (use_overlap{ho}) he += k - ho
def has_half = 2*k == kh def has_half = 2*k == kh
if (has_half and he==kh and h<he) he = k # Skip main loop; caught with he<h tests later if (has_half and he==kh and h<he) he = k # Skip main loop; caught with he<h tests later
# Main transpose # Main transpose
hm := h - kh @for_mult_max{kh, h-kh} (y to he) {
@for_mult{kh} (y0 to he) { y:=y0; if (y>hm) y = hm @for_mult_max{k, wm} (x to we) {
@for_mult{k} (x to w) {
kernel{...at{x,y}, k, kh, w, h} kernel{...at{x,y}, k, kh, w, h}
} }
} }
@ -100,13 +108,13 @@ fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = {
e := h%kh; if (he<h or e<k) e = k e := h%kh; if (he<h or e<k) e = k
@for (yi to n) { @for (yi to n) {
y:u64 = 0; if (yi == n-1) y = h - e y:u64 = 0; if (yi == n-1) y = h - e
@for_mult{k} (x to w) { @for_mult_max{k, wm} (x to we) {
kernel{...at{x,y}, k, k, w, h} kernel{...at{x,y}, k, k, w, h}
} }
} }
} }
# Base transpose used if overlap wasn't # Base transpose used if overlap wasn't
if (ho!=0 and he==h) { hs := h-ho; call_base{rp+hs, xp+w*hs, w-w%k, ho} } if (ho!=0 and he==h) { hs := h-ho; call_base{rp+hs, xp+w*hs, w, ho} }
} else { } else {
# Result rows are aligned with each other so it's possible to # Result rows are aligned with each other so it's possible to
# write a full cache line at a time # write a full cache line at a time
@ -126,14 +134,14 @@ fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = {
if (ro != 0) { if (ro != 0) {
ra := line_elts - ro ra := line_elts - ro
y := h - ra y := h - ra
rpo := rp + y # Cache aligned rpe := rp + y + (w-1)*h # Cache aligned
rpe := rpo + (w-1)*h
# Part of first and last result row aren't covered by the split loop # Part of first and last result row aren't covered by the split loop
def trtail{dst, src, len} = @for (i to len) store{dst, i, load{src, w*i}} def trtail{dst, src, len} = @for (i to len) store{dst, i, load{src, w*i}}
trtail{rp, xp, ro} trtail{rp, xp, ro}
trtail{rpe, xp + y*w + w-1, ra} trtail{rpe, xp + y*w + w-1, ra}
# Transpose first few rows and last few rows together # Transpose first few rows and last few rows together
@for_mult{k} (x to w) { @for_mult_max{k, wm} (x to we) {
{xpo,rpo} := at{x, y}
o := w*y + x o := w*y + x
def loadx{_} = { def loadx{_} = {
l:=load{*VT~~(xp+o)} l:=load{*VT~~(xp+o)}
@ -141,12 +149,12 @@ fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = {
l l
} }
def rls = get_lines{loadx} # 4 rows of 2 vectors each def rls = get_lines{loadx} # 4 rows of 2 vectors each
each{{i,v} => {if (i<3 or rpo<rpe) store_line{*VT~~rpo, v}; rpo+=h}, iota{k}, rls} each{{i,v} => {p:=rpo+i*h; if (i<3 or p<rpe) store_line{*VT~~p, v}}, iota{k}, rls}
} }
--yn # One strip handled --yn # One strip handled
} }
@for_mult{line_elts} (y0 to yn) { y := y0 + ro @for_mult{line_elts} (y0 to yn) { y := y0 + ro
@for_mult{k} (x to w) { @for_mult_max{k, wm} (x to we) {
{xpo,rpo} := at{x, y} {xpo,rpo} := at{x, y}
def rls = get_lines{{i} => load{*VT~~(xpo+i*w), 0}} def rls = get_lines{{i} => load{*VT~~(xpo+i*w), 0}}
each{{i,v} => store_line{*VT~~(rpo+i*h), v}, iota{k}, rls} each{{i,v} => store_line{*VT~~(rpo+i*h), v}, iota{k}, rls}
@ -154,7 +162,7 @@ fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = {
} }
} }
wo := w%k; if (wo!=0) { ws := w-wo; call_base{rp+h*ws, xp+ws, wo, h } } if (wo!=0 and we==w) { ws := w-wo; call_base{rp+h*ws, xp+ws, wo, h } }
} }
def transpose{T, k} = transpose{T, k, k} def transpose{T, k} = transpose{T, k, k}