More transpose kernel refactoring
This commit is contained in:
parent
c0aaa6f615
commit
f6c6e72661
@ -19,20 +19,21 @@ def pairs{o, x} = {
|
||||
tupsel{tup{g, g+o}, x}
|
||||
}
|
||||
def unpack_pass{o, x} = merge{...each{unpackQ, ...pairs{o, x}}}
|
||||
|
||||
def vtranspose{x & tuplen{x}==8 & type{tupsel{0,x}}==[8]i32 & hasarch{'X86_64'}} = {
|
||||
def t1 = unpack_pass{1, x}
|
||||
def t2 = unpack_pass{2, t1}
|
||||
def t2pairs = pairs{4, t2}
|
||||
def h{p} = each{{a,b}=>emit{[8]i32, '_mm256_permute2f128_si256', a,b,p}, ...t2pairs}
|
||||
def permute_pass{o, x} = {
|
||||
def p = pairs{o, x}
|
||||
def h{s} = each{{a,b}=>emit{[8]i32, '_mm256_permute2f128_si256', a,b,s}, ...p}
|
||||
merge{h{16b20}, h{16b31}}
|
||||
}
|
||||
|
||||
def vtranspose{x & tuplen{x}==4 & type{tupsel{0,x}}==[4]i64 & hasarch{'X86_64'}} = {
|
||||
def t1 = unpack_pass{1, x}
|
||||
def t2pairs = pairs{2, t1}
|
||||
def h{p} = each{{a,b}=>emit{[8]i32, '_mm256_permute2f128_si256', a,b,p}, ...t2pairs}
|
||||
merge{h{16b20}, h{16b31}}
|
||||
def ktest{a,l,T}{x} = {
|
||||
if (hasarch{a} and tuplen{x}==l and type{tupsel{0,x}}==T) 1 else 0
|
||||
}
|
||||
|
||||
def vtranspose{x & ktest{'X86_64',8,[8]i32}{x}} = {
|
||||
permute_pass{4, unpack_pass{2, unpack_pass{1, x}}}
|
||||
}
|
||||
def vtranspose{x & ktest{'X86_64',4,[4]i64}{x}} = {
|
||||
permute_pass{2, unpack_pass{1, x}}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user