Refactor vtranspose with more general code

This commit is contained in:
Marshall Lochbaum 2023-03-20 18:14:10 -04:00
parent ca44d41daa
commit a0e85db702

View File

@ -11,10 +11,21 @@ if (hasarch{'X86_64'}) {
include './mask'
include './bitops'
# Group l (power of 2) elements into paired groups of length o
# e.g. pairs{2, iota{8}} = {{0,1,4,5}, {2,3,6,7}}
def pairs{o, x} = {
def i = iota{tuplen{x}/2}
def g = 2*i - i%o
tupsel{tup{g, g+o}, x}
}
def unpack_pass{o, x} = merge{...each{unpackQ, ...pairs{o, x}}}
def vtranspose{x & tuplen{x}==8 & type{tupsel{0,x}}==[8]i32 & hasarch{'X86_64'}} = {
def t1 = merge{...each{{i} => unpackQ{tupsel{i*2,x}, tupsel{i*2+1,x}}, iota{4}}}
def t2 = merge{...each{{i} => unpackQ{tupsel{i, t1}, tupsel{i+2, t1}}, tup{0,1,4,5}}}
each{{i} => emit{[8]i32, '_mm256_permute2f128_si256', tupsel{i%4,t2}, tupsel{i%4+4,t2}, tern{i>=4,16b31,16b20}}, iota{8}}
def t1 = unpack_pass{1, x}
def t2 = unpack_pass{2, t1}
def t2pairs = pairs{4, t2}
def h{p} = each{{a,b}=>emit{[8]i32, '_mm256_permute2f128_si256', a,b,p}, ...t2pairs}
merge{h{16b20}, h{16b31}}
}