Width-4 transpose with shuffles
This commit is contained in:
parent
8bc37098dc
commit
0ab2e485ca
@ -239,10 +239,28 @@ def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = {
|
||||
}
|
||||
each{perm_store, xs}
|
||||
}
|
||||
def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = {
|
||||
def h = arch_defvw / 8
|
||||
def V = [h]u8
|
||||
def ih = iota{h}%16; def e = width{T} / 8
|
||||
# Permutation to unzip by 4 within each lane
|
||||
uz := make{V, (4*(ih//e) + ih//(16/4))*e%16 + ih%e}
|
||||
# Unzipping code for the resulting 4-byte units
|
||||
def {st, proc, zipx} = match (h) {
|
||||
{16} => tup{2, {x} => [4]f32~~x, {xs,i} => V~~zip128{...xs,i}}
|
||||
{32} => tup{1, shuf{[4]u64, ., 0,2,1,3}, {xs,i} => shuf{[4]f32, xs, i + tup{0,2,0,2}}}
|
||||
}
|
||||
def xsp = each{load{*V~~xp, .}, iota{2}}
|
||||
xs := each{proc, each{shuf{16, ., uz}, xsp}}
|
||||
@unroll (i to 2) {
|
||||
def U = [h/2]u8
|
||||
rp := rp0 + st*i*rst
|
||||
store2{*U~~rp, *U~~(rp + (2/st)*rst), zipx{xs,i}}
|
||||
}
|
||||
}
|
||||
|
||||
def transpose_with_modular{rp:*T, xp:*T, wk, h, hs} = {
|
||||
def odd_part{w} = if (w%2) w else odd_part{w/2}
|
||||
def w = odd_part{wk}; def p = wk/w
|
||||
def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = {
|
||||
def p = if (wk%2) 1 else 2; def w = wk/p
|
||||
def vl = arch_defvw / (p*width{T})
|
||||
@for_mult_max{vl, h-vl} (i to h+(-h)%vl) {
|
||||
modular_kernel{w,p}{rp+i, xp+i*wk, hs}
|
||||
@ -257,19 +275,18 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi
|
||||
rp:*T = *T~~r0
|
||||
xp:*T = *T~~x0
|
||||
def vl = arch_defvw / width{T}
|
||||
def use_kpart = k>max{4,8/width{T}}
|
||||
def use_kpart = width{T}<=16 and k>max{4,8/width{T}}
|
||||
def has_blend = hasarch{'SSE4.1'} or hasarch{'AARCH64'}
|
||||
if (has_blend and h>=vl/2 and w==ws and (w%2==0 or h>=vl)) {
|
||||
def tr = transpose_fixed_width{rp, xp, ., h, hs}
|
||||
def ws = replicate{{i} => i!=(if (width{T}==64) 4 else k) and (i==4 or not (use_kpart and i>=k/2)), tup{3,4,5,6,7}}
|
||||
each{{wk} => if (w==wk) { tr{wk}; return{} }, ws}
|
||||
}
|
||||
if (has_simd and use_kpart and h>=kh and w>=k/2 and w<k) {
|
||||
@for_mult_max{kh, h-kh} (i to h+(-h)%kh) {
|
||||
kernel_part{w}{xp+i*ws, rp+i, k, kh, ws, hs}
|
||||
}
|
||||
return{}
|
||||
}
|
||||
if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and h>=vl and w==ws) {
|
||||
def tr = transpose_with_modular{rp, xp, ., h, hs}
|
||||
def ws = replicate{{i} => (not use_kpart) or i < k/2, tup{3,5,6,7}}
|
||||
each{{wk} => if (w==wk) { tr{wk}; return{} }, ws}
|
||||
}
|
||||
if (has_simd and k!=0 and w>=k and h>=k) {
|
||||
} else if (has_simd and k!=0 and w>=k and h>=k) {
|
||||
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
|
||||
} else {
|
||||
if (h==2 and h==hs) interleave{T}(r0, x0, *void~~(xp+ws), w)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user