Kernel transpose with overreading and skipped writes, for short rows

This commit is contained in:
Marshall Lochbaum 2024-10-31 19:16:50 -04:00
parent b6e418ed5b
commit 780bfdfa0b

View File

@ -47,19 +47,46 @@ def store2{a:*T, b:*T, v:T2 if 2*width{T} == width{T2}} = match (width{T}) {
{ 64} => each{{p, v} => storeLow{*u64~~p, 64, [2]u64~~v}, tup{a,b}, tup{v, shuf{u64, v, 1,0}}} { 64} => each{{p, v} => storeLow{*u64~~p, 64, [2]u64~~v}, tup{a,b}, tup{v, shuf{u64, v, 1,0}}}
{128} => each{{p, i} => store{p, 0, T~~half{v,i}}, tup{a,b}, iota{2}} {128} => each{{p, i} => store{p, 0, T~~half{v,i}}, tup{a,b}, iota{2}}
} }
def store1of2{a:*T, v:T2 if 2*width{T} == width{T2}} = match (width{T}) {
{ 64} => storeLow{*u64~~a, 64, [2]u64~~v}
{128} => store{a, 0, T~~half{v,0}}
}
def load_k {VT, src, l, w} = each{{i} =>load {*VT~~(src+i*w), 0 }, iota{l}} def load_k {VT, src, l, w} = each{{i} =>load {*VT~~(src+i*w), 0 }, iota{l}}
def store_k{VT, dst, x, l, h} = each{{i,v}=>store{*VT~~(dst+i*h), 0, VT~~v}, iota{l}, x} def store_k{VT, dst, x, l, h} = each{{i,v}=>store{*VT~~(dst+i*h), 0, VT~~v}, iota{l}, x}
def load_k {VT, src, l, w if width{VT} < arch_defvw} = each{{i} =>{p:=src+ i*w; load2 {*VT~~p, *VT~~(p+l*w) }}, iota{l}} def load_k {VT, src, l, w if width{VT} < arch_defvw} = each{{i} =>{p:=src+ i*w; load2 {*VT~~p, *VT~~(p+l*w) }}, iota{l}}
def store_k{VT, dst, x, l, h if width{VT} < arch_defvw} = each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+ h), v}}, iota{l}, x} def store_k{VT, dst, x, l, h if width{VT} < arch_defvw} = each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+ h), v}}, iota{l}, x}
# Transpose kernel of size kw,kh in size w,h array # Transpose kernel of size kw,kh in size w,h array
def kernel{src:*T, dst:*T, kw, kh, w, h} = { def kernel_part{part_w}{src:*T, dst:*T, kw, kh, w, h} = {
def n = (kw*kh*width{T}) / arch_defvw # Number of vectors def n = (kw*kh*width{T}) / arch_defvw # Number of vectors
def xvs = load_k{[kw]T, src, n, w} def xvs = load_k{[kw]T, src, n, w}
def xt = unpack_to{n==kh, n/2, xvs} # Transpose n by n def xt = unpack_to{n==kh, n/2, xvs} # Transpose n by n
def rvs = if (n==kw) xt else halved_pass{n,xt} # To kh by kh for packed square def rvs = if (n==kw) xt else halved_pass{n,xt} # To kh by kh for packed square
store_k{[kh]T, dst, rvs, n, h} def stores = store_k{[kh]T, ..., h}
if (same{part_w, 0}) {
stores{dst, rvs, n}
} else {
# Write w results, kw/2 <= n < kw
d := dst
def vd = kw / n # Number of writes for each output vector (1 or 2)
def store_slice{rv, len} = {
stores{d, slice{rv,0,len}, len}
d += len*vd*h
}
store_slice{rvs, n/2} # Unconditionally store first half
rt := slice{rvs,n/2} # Remaining tail
def wtail{b} = {
if ((part_w & (vd*b)) != 0) {
store_slice{rt, b}
slice{rt,0,b} = slice{rt,b,2*b}
}
if (b>1) wtail{b/2}
}
wtail{n/4}
if (vd>1 and (part_w & 1) != 0) store1of2{*[kh]T~~d, select{rt,0}}
}
} }
def kernel = kernel_part{0}
def for_mult{k}{vars,begin,end,iter} = { def for_mult{k}{vars,begin,end,iter} = {
@ -220,9 +247,15 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi
rp:*T = *T~~r0 rp:*T = *T~~r0
xp:*T = *T~~x0 xp:*T = *T~~x0
def vl = arch_defvw / width{T} def vl = arch_defvw / width{T}
if (has_simd and k>max{4,8/width{T}} and h>=kh and w>=k/2 and w<k) {
@for_mult_max{kh, h-kh} (i to h+(-h)%kh) {
kernel_part{w}{xp+i*ws, rp+i, k, kh, ws, hs}
}
return{}
}
if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and T==i8 and h>=vl and w==ws) { if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and T==i8 and h>=vl and w==ws) {
def tr = transpose_with_modular{rp, xp, ., h, hs} def tr = transpose_with_modular{rp, xp, ., h, hs}
each{{wk} => if (w==wk) { tr{wk}; return{} }, 3+2*iota{7}} # 3 to 15 each{{wk} => if (w==wk) { tr{wk}; return{} }, 3+2*iota{3}} # 3 to 7
} }
if (has_simd and k!=0 and w>=k and h>=k) { if (has_simd and k!=0 and w>=k and h>=k) {
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs} transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}