This commit is contained in:
Marshall Lochbaum 2023-03-21 22:04:51 -04:00
parent dd141add3f
commit fad7f3aa8b

View File

@ -62,34 +62,16 @@ def for_mult{k}{vars,begin,end,block} = {
@for (i to end/k) exec{k*i, vars, block} @for (i to end/k) exec{k*i, vars, block}
} }
def mat_at{rp,xp,w,h}{x,y} = tup{xp + y*w + x, rp + x*h + y} fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = {
# Scalar transpose defined in C
def transpose_kernels{kw, kh, rp, xp, w, h} = {
@for_mult{kh} (y to h) {
@for_mult{kw} (x to w) {
kernel{...mat_at{rp,xp,w,h}{x,y}, kw, kh, w, h}
}
}
}
# Scalar transpose defined in C
def call_base{T} = {
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64' def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
{...a} => emit{void, merge{'base_transpose_',ts}, ...a} def call_base{...a} = emit{void, merge{'base_transpose_',ts}, ...a}
}
def small_transpose_out{T, k, rp, xp, w, h} = {
if (w<k or h<k) { call_base{T}{rp, xp, w, h, w, h}; return{} }
}
def edge_transpose{T, k, rp, xp, w, h} = {
def tr{...a} = call_base{T}{...a, w, h}
wo := w%k; ws := w-wo; if (wo) tr{rp+h*ws, xp+ ws, wo, h }
ho := h%k; hs := h-ho; if (ho) tr{rp+ hs, xp+w*hs, ws, ho}
}
fn transpose{T, k, d}(r0:*void, x0:*void, w:u64, h:u64) : void = {
rp:*T = *T~~r0 rp:*T = *T~~r0
xp:*T = *T~~x0 xp:*T = *T~~x0
small_transpose_out{T, k, rp, xp, w, h} if (w<k or h<k) { call_base{rp, xp, w, h, w, h}; return{} }
def at{x,y} = tup{xp + y*w + x, rp + x*h + y}
# Cache line info # Cache line info
def line_bytes = 64 def line_bytes = 64
@ -97,11 +79,15 @@ fn transpose{T, k, d}(r0:*void, x0:*void, w:u64, h:u64) : void = {
if (line_elts > 2*k or h&(line_elts-1) != 0) { if (line_elts > 2*k or h&(line_elts-1) != 0) {
# Main transpose # Main transpose
transpose_kernels{k, d, rp, xp, w, h} @for_mult{kh} (y to h) {
# Extra column for uneven i16 case
if (2*k == d and (h & k) != 0) { y := h-h%d
@for_mult{k} (x to w) { @for_mult{k} (x to w) {
kernel{...mat_at{rp,xp,w,h}{x,y}, k, k, w, h} kernel{...at{x,y}, k, kh, w, h}
}
}
# Extra row for uneven i16 case
if (2*k == kh and (h & k) != 0) { y := h-h%kh
@for_mult{k} (x to w) {
kernel{...at{x,y}, k, k, w, h}
} }
} }
} else { } else {
@ -109,7 +95,7 @@ fn transpose{T, k, d}(r0:*void, x0:*void, w:u64, h:u64) : void = {
# write a full cache line at a time # write a full cache line at a time
# This case is here to mitigate cache associativity problems at # This case is here to mitigate cache associativity problems at
# at multiples of 256 or so, but it's faster whenever it applies # at multiples of 256 or so, but it's faster whenever it applies
assert{k == d} assert{k == kh}
def VT = [k]T def VT = [k]T
def line_vecs = line_bytes / (width{VT}/8) def line_vecs = line_bytes / (width{VT}/8)
def store_line{p, vs} = each{bind{store,p}, iota{line_vecs}, vs} def store_line{p, vs} = each{bind{store,p}, iota{line_vecs}, vs}
@ -144,14 +130,16 @@ fn transpose{T, k, d}(r0:*void, x0:*void, w:u64, h:u64) : void = {
} }
@for_mult{line_elts} (y0 to yn) { y := y0 + ro @for_mult{line_elts} (y0 to yn) { y := y0 + ro
@for_mult{k} (x to w) { @for_mult{k} (x to w) {
{xpo,rpo} := mat_at{rp,xp,w,h}{x, y} {xpo,rpo} := at{x, y}
def rls = get_lines{{i} => load{*VT~~(xpo+i*w), 0}} def rls = get_lines{{i} => load{*VT~~(xpo+i*w), 0}}
each{{i,v} => store_line{*VT~~(rpo+i*h), v}, iota{k}, rls} each{{i,v} => store_line{*VT~~(rpo+i*h), v}, iota{k}, rls}
} }
} }
} }
edge_transpose{T, k, rp, xp, w, h} def edge_tr{...a} = call_base{...a, w, h}
wo := w%k; ws := w-wo; if (wo) edge_tr{rp+h*ws, xp+ ws, wo, h }
ho := h%k; hs := h-ho; if (ho) edge_tr{rp+ hs, xp+w*hs, ws, ho}
} }
def transpose{T, k} = transpose{T, k, k} def transpose{T, k} = transpose{T, k, k}