Make transpose kernel code type-generic

This commit is contained in:
Marshall Lochbaum 2023-03-21 19:17:00 -04:00
parent 923c485cc2
commit e40feaa81c

View File

@ -19,35 +19,31 @@ def pairs{o, x} = {
tupsel{tup{g, g+o}, x}
}
def unpack_pass{o, x} = merge{...each{unpackQ, ...pairs{o, x}}}
def unpack_to{l, x} = unpack_pass{l, if (l==1) x else unpack_to{l/2, x}}
def permute_pass{o, x} = {
def p = pairs{o, x}
def h{s} = each{{a,b}=>emit{[8]i32, '_mm256_permute2f128_si256', a,b,s}, ...p}
merge{h{16b20}, h{16b31}}
}
def ktest{a,l,T}{x} = {
if (hasarch{a} and tuplen{x}==l and type{tupsel{0,x}}==T) 1 else 0
# Square kernel where width is a full vector
def transpose_square{x & hasarch{'X86_64'}} = {
def k = tuplen{x}
def T = type{tupsel{0,x}}; assert{256==width{T}}; assert{k==vcount{T}}
def rvs = permute_pass{k/2, unpack_to{k/4, x}}
}
def vtranspose{x & ktest{'X86_64',8,[8]i32}{x}} = {
permute_pass{4, unpack_pass{2, unpack_pass{1, x}}}
# Square kernel where width is half a vector; top half next to bottom
def transpose_square_halves{x & hasarch{'X86_64'}} = {
def l = tuplen{x}
def T = type{tupsel{0,x}}; assert{256==width{T}}; assert{l==vcount{T}/4}
def r = unpack_to{l/2, x}
each{{v} => T~~shuf{[4]i64, v, 4b3120}, r}
}
def vtranspose{x & ktest{'X86_64',4,[4]i64}{x}} = {
permute_pass{2, unpack_pass{1, x}}
}
def vtranspose2{x & ktest{'X86_64',8,[16]i16}{x}} = {
def r = unpack_pass{4, unpack_pass{2, unpack_pass{1, x}}}
each{bind{~~,[16]i16}, r}
}
# Transpose square packed as halves
def vtranspose{x & ktest{'X86_64',8,[32]i8}{x}} = {
def r = unpack_pass{4, unpack_pass{2, unpack_pass{1, x}}}
each{{v}=>[32]i8~~shuf{[4]i64, v, 4b3120}, r}
}
def vtranspose{x & ktest{'X86_64',4,[16]i16}{x}} = {
def r = unpack_pass{2, unpack_pass{1, x}}
each{{v}=>[16]i16~~shuf{[4]i64, v, 4b3120}, r}
# Same with 2*k by k rectangle
def transpose_rect_halves{x & hasarch{'X86_64'}} = {
def l = tuplen{x}
def T = type{tupsel{0,x}}; assert{256==width{T}}; assert{l==vcount{T}/2}
each{bind{~~,T}, unpack_to{l/2, x}}
}
def load2{a:T, b:T & w128i{eltype{T}}} = {
@ -61,21 +57,21 @@ def store2{a:T, b:T, v:T2 & w128i{eltype{T}} & w256{T2}} = {
def kernel{src:P, dst:P, k, k, w, h & k*width{eltype{P}}==256} = {
def VT = [k](eltype{P})
def xvs = each{{i}=>load{*VT~~(src+i*w), 0}, iota{k}}
def rvs = vtranspose{xvs}
def rvs = transpose_square{xvs}
each{{i,v}=>store{*VT~~(dst+i*h), 0, v}, iota{k}, rvs}
}
def kernel{src:P, dst:P, k, k, w, h & k*width{eltype{P}}==128} = {
def VT = [k](eltype{P})
def s = k/2
def xvs = each{{i}=>{p:=src+i*w; load2{*VT~~p, *VT~~(p+s*w)}}, iota{s}}
def rvs = vtranspose{xvs}
def rvs = transpose_square_halves{xvs}
each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+h), v}}, iota{s}, rvs}
}
def kernel{src:P, dst:P, k, d, w, h & d==2*k & d*width{eltype{P}}==256} = {
def HT = [k](eltype{P})
def VT = [d](eltype{P})
def xvs = each{{i}=>{p:=src+i*w; load2{*HT~~p, *HT~~(p+k*w)}}, iota{k}}
def rvs = vtranspose2{xvs}
def rvs = transpose_rect_halves{xvs}
each{{i,v}=>store{*VT~~(dst+i*h), 0, v}, iota{k}, rvs}
}
@ -130,7 +126,7 @@ fn transpose{T, k}(r0:*void, x0:*void, w:u64, h:u64) : void = {
# at multiples of 256 or so, but it's faster whenever it applies
def store_line{p, vs} = each{bind{store,p}, iota{line_vecs}, vs}
def get_lines{loadx} = {
def vt{i} = vtranspose{each{loadx, k*i + iota{k}}}
def vt{i} = transpose_square{each{loadx, k*i + iota{k}}}
each{tup, ...each{vt, iota{line_vecs}}}
}
ro := tail{6, -u64~~r0} / (width{T}/8) # Offset to align within cache line; assume elt-aligned