Extend kernel transpose from AVX2 only to SSE2, NEON
This commit is contained in:
parent
6e4859e17f
commit
5bdee886cc
@ -9,10 +9,10 @@
|
||||
// SHOULD have bit matrix transpose kernel
|
||||
// CPU sizes: native or SIMD code
|
||||
// Large SIMD kernels used when they fit, overlapping for odd sizes
|
||||
// i8: 16×16; i16: 16×8; i32: 8×8; f64: 4×4
|
||||
// SSE, NEON i8: 8×8 ; i16: 8×8; i32: 4×4; f64: scalar
|
||||
// AVX i8: 16×16; i16: 16×8; i32: 8×8; f64: 4×4
|
||||
// COULD use half-width or smaller kernels to improve odd sizes
|
||||
// Scalar transpose or loop used for overhang of 1
|
||||
// SHOULD add NEON
|
||||
|
||||
// Reorder Axes
|
||||
// If 𝕨 indicates the identity permutation, return 𝕩
|
||||
|
||||
@ -4,6 +4,8 @@ include './f64'
|
||||
include './mask'
|
||||
include './bitops'
|
||||
|
||||
def avx2 = hasarch{'AVX2'}
|
||||
|
||||
# Group l (power of 2) elements into paired groups of length o
|
||||
# e.g. pairs{2, iota{8}} = {{0,1,4,5}, {2,3,6,7}}
|
||||
def pairs{o, x} = {
|
||||
@ -22,30 +24,40 @@ def permute_pass{o, x} = {
|
||||
merge{h{0,2}, h{1,3}}
|
||||
}
|
||||
def unpack_to{f, l, x} = {
|
||||
def pass = if (f) permute_pass else unpack_pass
|
||||
def pass = if (avx2 and f) permute_pass else unpack_pass
|
||||
pass{l, if (l==1) x else unpack_to{0, l/2, x}}
|
||||
}
|
||||
# Last pass for square kernel packed in halves
|
||||
def shuf_pass{x} = each{{v} => shuf{[4]i64, v, 0,2,1,3}, x}
|
||||
def halved_pass{n, x} = {
|
||||
if (not avx2) unpack_pass{n/2, x}
|
||||
else each{{v} => shuf{[4]i64, v, 0,2,1,3}, x}
|
||||
}
|
||||
|
||||
# Square kernel where width is a full vector
|
||||
def transpose_square{VT, l, x if hasarch{'AVX2'}} = unpack_to{1, l/2, x}
|
||||
def transpose_square{VT, l, x if avx2} = unpack_to{1, l/2, x}
|
||||
|
||||
def load2{a:*T, b:*T} = pair{load{a}, load{b}}
|
||||
def store2{a:*T, b:*T, v:T2 if w128i{T} and w256{T2}} = {
|
||||
each{{p, i} => store{p, 0, T~~half{v,i}}, tup{a,b}, iota{2}}
|
||||
def load2{a:*T, b:*T} = match (width{T}) {
|
||||
{64} => {
|
||||
def v = each{{p}=>loadLow{*[2]u64~~p, 64}, tup{a,b}}
|
||||
n_d{T}~~zip{...v, 0}
|
||||
}
|
||||
{128} => pair{load{a}, load{b}}
|
||||
}
|
||||
def load_k {VT, src, l, w if w256{VT}} = each{{i} =>load {*VT~~(src+i*w), 0 }, iota{l}}
|
||||
def store_k{VT, dst, x, l, h if w256{VT}} = each{{i,v}=>store{*VT~~(dst+i*h), 0, VT~~v}, iota{l}, x}
|
||||
def load_k {VT, src, l, w if w128{VT}} = each{{i} =>{p:=src+ i*w; load2 {*VT~~p, *VT~~(p+l*w) }}, iota{l}}
|
||||
def store_k{VT, dst, x, l, h if w128{VT}} = each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+ h), v}}, iota{l}, x}
|
||||
def store2{a:*T, b:*T, v:T2 if 2*width{T} == width{T2}} = match (width{T}) {
|
||||
{ 64} => each{{p, v} => storeLow{*u64~~p, 64, [2]u64~~v}, tup{a,b}, tup{v, shuf{u64, v, 1,0}}}
|
||||
{128} => each{{p, i} => store{p, 0, T~~half{v,i}}, tup{a,b}, iota{2}}
|
||||
}
|
||||
def load_k {VT, src, l, w} = each{{i} =>load {*VT~~(src+i*w), 0 }, iota{l}}
|
||||
def store_k{VT, dst, x, l, h} = each{{i,v}=>store{*VT~~(dst+i*h), 0, VT~~v}, iota{l}, x}
|
||||
def load_k {VT, src, l, w if width{VT} < arch_defvw} = each{{i} =>{p:=src+ i*w; load2 {*VT~~p, *VT~~(p+l*w) }}, iota{l}}
|
||||
def store_k{VT, dst, x, l, h if width{VT} < arch_defvw} = each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+ h), v}}, iota{l}, x}
|
||||
|
||||
# Transpose kernel of size kw,kh in size w,h array
|
||||
def kernel{src:*T, dst:*T, kw, kh, w, h} = {
|
||||
def n = (kw*kh*width{T}) / 256 # Number of vectors
|
||||
def n = (kw*kh*width{T}) / arch_defvw # Number of vectors
|
||||
def xvs = load_k{[kw]T, src, n, w}
|
||||
def xt = unpack_to{n==kh, n/2, xvs} # Transpose n by n
|
||||
def rvs = if (n==kw) xt else shuf_pass{xt} # To kh by kh for packed square
|
||||
def rvs = if (n==kw) xt else halved_pass{n,xt} # To kh by kh for packed square
|
||||
store_k{[kh]T, dst, rvs, n, h}
|
||||
}
|
||||
|
||||
@ -161,14 +173,14 @@ fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = {
|
||||
}
|
||||
}
|
||||
|
||||
fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
|
||||
fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
|
||||
# Scalar transpose defined in C
|
||||
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
|
||||
def call_base{...a} = emit{void, merge{'transpose_',ts}, ...a, ws, hs}
|
||||
|
||||
rp:*T = *T~~r0
|
||||
xp:*T = *T~~x0
|
||||
if (hasarch{'AVX2'} and w>=k and h>=k) {
|
||||
if (has_simd and k!=0 and w>=k and h>=k) {
|
||||
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
|
||||
} else {
|
||||
if (h==2 and h==hs) interleave{T}(r0, x0, *void~~(xp+ws), w)
|
||||
@ -177,13 +189,12 @@ fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void
|
||||
}
|
||||
}
|
||||
|
||||
def transpose{T, k} = transpose{T, k, k}
|
||||
def transpose{T, k if knum{k}} = transpose{T, tup{k, k}}
|
||||
|
||||
exportT{'simd_transpose', tup{
|
||||
transpose{i8 , 16},
|
||||
transpose{i16, 8, 16},
|
||||
transpose{i32, 8},
|
||||
transpose{i64, 4}
|
||||
}}
|
||||
def tr_types = tup{i8, i16, i32, i64}
|
||||
def tr_kernels = if (not avx2) tup{ 8, 8, 4, 0 }
|
||||
else tup{16, tup{8, 16}, 8, 4 }
|
||||
|
||||
exportT{'interleave_fns', each{interleave, tup{i8, i16, i32, i64}}}
|
||||
exportT{'simd_transpose', each{transpose, tr_types, tr_kernels}}
|
||||
|
||||
exportT{'interleave_fns', each{interleave, tr_types}}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user