Combine code for the different transpose kernel geometries
This commit is contained in:
parent
e40feaa81c
commit
9673771187
@ -19,32 +19,20 @@ def pairs{o, x} = {
|
||||
tupsel{tup{g, g+o}, x}
|
||||
}
|
||||
def unpack_pass{o, x} = merge{...each{unpackQ, ...pairs{o, x}}}
|
||||
def unpack_to{l, x} = unpack_pass{l, if (l==1) x else unpack_to{l/2, x}}
|
||||
def permute_pass{o, x} = {
|
||||
def p = pairs{o, x}
|
||||
def h{s} = each{{a,b}=>emit{[8]i32, '_mm256_permute2f128_si256', a,b,s}, ...p}
|
||||
merge{h{16b20}, h{16b31}}
|
||||
}
|
||||
def unpack_to{f, l, x} = {
|
||||
def pass = if (f) permute_pass else unpack_pass
|
||||
pass{l, if (l==1) x else unpack_to{0, l/2, x}}
|
||||
}
|
||||
# Last pass for square kernel packed in halves
|
||||
def shuf_pass{x} = each{{v} => shuf{[4]i64, v, 4b3120}, x}
|
||||
|
||||
# Square kernel where width is a full vector
|
||||
def transpose_square{x & hasarch{'X86_64'}} = {
|
||||
def k = tuplen{x}
|
||||
def T = type{tupsel{0,x}}; assert{256==width{T}}; assert{k==vcount{T}}
|
||||
def rvs = permute_pass{k/2, unpack_to{k/4, x}}
|
||||
}
|
||||
# Square kernel where width is half a vector; top half next to bottom
|
||||
def transpose_square_halves{x & hasarch{'X86_64'}} = {
|
||||
def l = tuplen{x}
|
||||
def T = type{tupsel{0,x}}; assert{256==width{T}}; assert{l==vcount{T}/4}
|
||||
def r = unpack_to{l/2, x}
|
||||
each{{v} => T~~shuf{[4]i64, v, 4b3120}, r}
|
||||
}
|
||||
# Same with 2*k by k rectangle
|
||||
def transpose_rect_halves{x & hasarch{'X86_64'}} = {
|
||||
def l = tuplen{x}
|
||||
def T = type{tupsel{0,x}}; assert{256==width{T}}; assert{l==vcount{T}/2}
|
||||
each{bind{~~,T}, unpack_to{l/2, x}}
|
||||
}
|
||||
def transpose_square{VT, l, x & hasarch{'X86_64'}} = unpack_to{1, l/2, x}
|
||||
|
||||
def load2{a:T, b:T & w128i{eltype{T}}} = {
|
||||
def V = eltype{T}
|
||||
@ -53,26 +41,19 @@ def load2{a:T, b:T & w128i{eltype{T}}} = {
|
||||
def store2{a:T, b:T, v:T2 & w128i{eltype{T}} & w256{T2}} = {
|
||||
each{{p, i} => store{p, 0, half{v, i}}, tup{a,b}, iota{2}}
|
||||
}
|
||||
def load_k {VT, src, l, w & w256{VT}} = each{{i} =>load {*VT~~(src+i*w), 0 }, iota{l}}
|
||||
def store_k{VT, dst, x, l, h & w256{VT}} = each{{i,v}=>store{*VT~~(dst+i*h), 0, v}, iota{l}, x}
|
||||
def load_k {VT, src, l, w & w128{VT}} = each{{i} =>{p:=src+ i*w; load2 {*VT~~p, *VT~~(p+l*w) }}, iota{l}}
|
||||
def store_k{VT, dst, x, l, h & w128{VT}} = each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+ h), v}}, iota{l}, x}
|
||||
|
||||
def kernel{src:P, dst:P, k, k, w, h & k*width{eltype{P}}==256} = {
|
||||
def VT = [k](eltype{P})
|
||||
def xvs = each{{i}=>load{*VT~~(src+i*w), 0}, iota{k}}
|
||||
def rvs = transpose_square{xvs}
|
||||
each{{i,v}=>store{*VT~~(dst+i*h), 0, v}, iota{k}, rvs}
|
||||
}
|
||||
def kernel{src:P, dst:P, k, k, w, h & k*width{eltype{P}}==128} = {
|
||||
def VT = [k](eltype{P})
|
||||
def s = k/2
|
||||
def xvs = each{{i}=>{p:=src+i*w; load2{*VT~~p, *VT~~(p+s*w)}}, iota{s}}
|
||||
def rvs = transpose_square_halves{xvs}
|
||||
each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+h), v}}, iota{s}, rvs}
|
||||
}
|
||||
def kernel{src:P, dst:P, k, d, w, h & d==2*k & d*width{eltype{P}}==256} = {
|
||||
def HT = [k](eltype{P})
|
||||
def VT = [d](eltype{P})
|
||||
def xvs = each{{i}=>{p:=src+i*w; load2{*HT~~p, *HT~~(p+k*w)}}, iota{k}}
|
||||
def rvs = transpose_rect_halves{xvs}
|
||||
each{{i,v}=>store{*VT~~(dst+i*h), 0, v}, iota{k}, rvs}
|
||||
# Transpose kernel of size kw,kh in size w,h array
|
||||
def kernel{src:P, dst:P, kw, kh, w, h} = {
|
||||
def T = eltype{P}
|
||||
def n = (kw*kh*width{T}) / 256 # Number of vectors
|
||||
def xvs = load_k{[kw]T, src, n, w}
|
||||
def xt = unpack_to{n==kh, n/2, xvs} # Transpose n by n
|
||||
def rvs = if (n==kw) xt else shuf_pass{xt} # To kh by kh for packed square
|
||||
store_k{[kh]T, dst, rvs, n, h}
|
||||
}
|
||||
|
||||
|
||||
@ -126,7 +107,7 @@ fn transpose{T, k}(r0:*void, x0:*void, w:u64, h:u64) : void = {
|
||||
# at multiples of 256 or so, but it's faster whenever it applies
|
||||
def store_line{p, vs} = each{bind{store,p}, iota{line_vecs}, vs}
|
||||
def get_lines{loadx} = {
|
||||
def vt{i} = transpose_square{each{loadx, k*i + iota{k}}}
|
||||
def vt{i} = transpose_square{VT, k, each{loadx, k*i + iota{k}}}
|
||||
each{tup, ...each{vt, iota{line_vecs}}}
|
||||
}
|
||||
ro := tail{6, -u64~~r0} / (width{T}/8) # Offset to align within cache line; assume elt-aligned
|
||||
|
||||
Loading…
Reference in New Issue
Block a user