Width-4 transpose with shuffles

This commit is contained in:
Marshall Lochbaum 2024-11-02 11:36:10 -04:00
parent 8bc37098dc
commit 0ab2e485ca

View File

@ -239,10 +239,28 @@ def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = {
}
each{perm_store, xs}
}
def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = {
def h = arch_defvw / 8
def V = [h]u8
def ih = iota{h}%16; def e = width{T} / 8
# Permutation to unzip by 4 within each lane
uz := make{V, (4*(ih//e) + ih//(16/4))*e%16 + ih%e}
# Unzipping code for the resulting 4-byte units
def {st, proc, zipx} = match (h) {
{16} => tup{2, {x} => [4]f32~~x, {xs,i} => V~~zip128{...xs,i}}
{32} => tup{1, shuf{[4]u64, ., 0,2,1,3}, {xs,i} => shuf{[4]f32, xs, i + tup{0,2,0,2}}}
}
def xsp = each{load{*V~~xp, .}, iota{2}}
xs := each{proc, each{shuf{16, ., uz}, xsp}}
@unroll (i to 2) {
def U = [h/2]u8
rp := rp0 + st*i*rst
store2{*U~~rp, *U~~(rp + (2/st)*rst), zipx{xs,i}}
}
}
def transpose_with_modular{rp:*T, xp:*T, wk, h, hs} = {
def odd_part{w} = if (w%2) w else odd_part{w/2}
def w = odd_part{wk}; def p = wk/w
def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = {
def p = if (wk%2) 1 else 2; def w = wk/p
def vl = arch_defvw / (p*width{T})
@for_mult_max{vl, h-vl} (i to h+(-h)%vl) {
modular_kernel{w,p}{rp+i, xp+i*wk, hs}
@ -257,19 +275,18 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi
rp:*T = *T~~r0
xp:*T = *T~~x0
def vl = arch_defvw / width{T}
def use_kpart = k>max{4,8/width{T}}
def use_kpart = width{T}<=16 and k>max{4,8/width{T}}
def has_blend = hasarch{'SSE4.1'} or hasarch{'AARCH64'}
if (has_blend and h>=vl/2 and w==ws and (w%2==0 or h>=vl)) {
def tr = transpose_fixed_width{rp, xp, ., h, hs}
def ws = replicate{{i} => i!=(if (width{T}==64) 4 else k) and (i==4 or not (use_kpart and i>=k/2)), tup{3,4,5,6,7}}
each{{wk} => if (w==wk) { tr{wk}; return{} }, ws}
}
if (has_simd and use_kpart and h>=kh and w>=k/2 and w<k) {
@for_mult_max{kh, h-kh} (i to h+(-h)%kh) {
kernel_part{w}{xp+i*ws, rp+i, k, kh, ws, hs}
}
return{}
}
if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and h>=vl and w==ws) {
def tr = transpose_with_modular{rp, xp, ., h, hs}
def ws = replicate{{i} => (not use_kpart) or i < k/2, tup{3,5,6,7}}
each{{wk} => if (w==wk) { tr{wk}; return{} }, ws}
}
if (has_simd and k!=0 and w>=k and h>=k) {
} else if (has_simd and k!=0 and w>=k and h>=k) {
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
} else {
if (h==2 and h==hs) interleave{T}(r0, x0, *void~~(xp+ws), w)