Kernel transpose with overreading and skipped writes, for short rows
This commit is contained in:
parent
b6e418ed5b
commit
780bfdfa0b
@ -47,19 +47,46 @@ def store2{a:*T, b:*T, v:T2 if 2*width{T} == width{T2}} = match (width{T}) {
|
||||
{ 64} => each{{p, v} => storeLow{*u64~~p, 64, [2]u64~~v}, tup{a,b}, tup{v, shuf{u64, v, 1,0}}}
|
||||
{128} => each{{p, i} => store{p, 0, T~~half{v,i}}, tup{a,b}, iota{2}}
|
||||
}
|
||||
def store1of2{a:*T, v:T2 if 2*width{T} == width{T2}} = match (width{T}) {
|
||||
{ 64} => storeLow{*u64~~a, 64, [2]u64~~v}
|
||||
{128} => store{a, 0, T~~half{v,0}}
|
||||
}
|
||||
def load_k {VT, src, l, w} = each{{i} =>load {*VT~~(src+i*w), 0 }, iota{l}}
|
||||
def store_k{VT, dst, x, l, h} = each{{i,v}=>store{*VT~~(dst+i*h), 0, VT~~v}, iota{l}, x}
|
||||
def load_k {VT, src, l, w if width{VT} < arch_defvw} = each{{i} =>{p:=src+ i*w; load2 {*VT~~p, *VT~~(p+l*w) }}, iota{l}}
|
||||
def store_k{VT, dst, x, l, h if width{VT} < arch_defvw} = each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+ h), v}}, iota{l}, x}
|
||||
|
||||
# Transpose kernel of size kw,kh in size w,h array
|
||||
def kernel{src:*T, dst:*T, kw, kh, w, h} = {
|
||||
def kernel_part{part_w}{src:*T, dst:*T, kw, kh, w, h} = {
|
||||
def n = (kw*kh*width{T}) / arch_defvw # Number of vectors
|
||||
def xvs = load_k{[kw]T, src, n, w}
|
||||
def xt = unpack_to{n==kh, n/2, xvs} # Transpose n by n
|
||||
def rvs = if (n==kw) xt else halved_pass{n,xt} # To kh by kh for packed square
|
||||
store_k{[kh]T, dst, rvs, n, h}
|
||||
def stores = store_k{[kh]T, ..., h}
|
||||
if (same{part_w, 0}) {
|
||||
stores{dst, rvs, n}
|
||||
} else {
|
||||
# Write w results, kw/2 <= n < kw
|
||||
d := dst
|
||||
def vd = kw / n # Number of writes for each output vector (1 or 2)
|
||||
def store_slice{rv, len} = {
|
||||
stores{d, slice{rv,0,len}, len}
|
||||
d += len*vd*h
|
||||
}
|
||||
store_slice{rvs, n/2} # Unconditionally store first half
|
||||
rt := slice{rvs,n/2} # Remaining tail
|
||||
def wtail{b} = {
|
||||
if ((part_w & (vd*b)) != 0) {
|
||||
store_slice{rt, b}
|
||||
slice{rt,0,b} = slice{rt,b,2*b}
|
||||
}
|
||||
if (b>1) wtail{b/2}
|
||||
}
|
||||
wtail{n/4}
|
||||
if (vd>1 and (part_w & 1) != 0) store1of2{*[kh]T~~d, select{rt,0}}
|
||||
}
|
||||
}
|
||||
def kernel = kernel_part{0}
|
||||
|
||||
|
||||
def for_mult{k}{vars,begin,end,iter} = {
|
||||
@ -220,9 +247,15 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi
|
||||
rp:*T = *T~~r0
|
||||
xp:*T = *T~~x0
|
||||
def vl = arch_defvw / width{T}
|
||||
if (has_simd and k>max{4,8/width{T}} and h>=kh and w>=k/2 and w<k) {
|
||||
@for_mult_max{kh, h-kh} (i to h+(-h)%kh) {
|
||||
kernel_part{w}{xp+i*ws, rp+i, k, kh, ws, hs}
|
||||
}
|
||||
return{}
|
||||
}
|
||||
if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and T==i8 and h>=vl and w==ws) {
|
||||
def tr = transpose_with_modular{rp, xp, ., h, hs}
|
||||
each{{wk} => if (w==wk) { tr{wk}; return{} }, 3+2*iota{7}} # 3 to 15
|
||||
each{{wk} => if (w==wk) { tr{wk}; return{} }, 3+2*iota{3}} # 3 to 7
|
||||
}
|
||||
if (has_simd and k!=0 and w>=k and h>=k) {
|
||||
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user