Modular kernel transpose for odd*2 (so, just 6)

This commit is contained in:
Marshall Lochbaum 2024-11-01 21:31:50 -04:00
parent 091f08c6cc
commit 8bc37098dc

View File

@ -200,8 +200,9 @@ fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = {
} }
} }
# Transpose a contiguous width-w (w odd) kernel from x to r with stride rst # Transpose a contiguous kernel of width w*p from x to r with stride rst
def modular_kernel{w}{rp0:*T, xp:*T, rst:(u64)} = { def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = {
assert{w%2 == 1}; lb{p} # Odd times power of two
def h = arch_defvw / 8 def h = arch_defvw / 8
def ih = iota{h}; def iw = iota{w} def ih = iota{h}; def iw = iota{w}
def I = [h]u8; def V = I def I = [h]u8; def V = I
@ -220,14 +221,19 @@ def modular_kernel{w}{rp0:*T, xp:*T, rst:(u64)} = {
} }
# Modular permutation of rows, and write to result # Modular permutation of rows, and write to result
rp := rp0 rp := rp0
mp := make{I, ih%e + (ih - ih%e)%16*w%h}; mi := I**e mp := make{I, ih%e + (p*(ih - ih%e) + (ih//(h/p))*e)%16*w%h}; mi := I**e
def perm = if (h==16) shuf else { def perm = if (h==16) shuf else {
c := I**16 c := I**16
def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c} def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c}
{x, i} => cross{shuf{16, x, i}, i} def q = match (p) { {1}=>({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} }
if (p>1) mp = q{mp}
{x, i} => q{cross{shuf{16, x, i}, i}}
} }
def perm_store{x} = { def perm_store{x} = {
store{*V~~rp, 0, perm{x, mp}} match (p) {
{1} => store{*V~~rp, 0, perm{x, mp}}
{2} => { def U = [h/2]u8; store2{*U~~rp, *U~~(rp+w*rst), perm{x, mp}} }
}
rp += rst; mp += mi rp += rst; mp += mi
if (hasarch{'AARCH64'}) mp &= I**15 # Implicit on x86, value stays below h+w if (hasarch{'AARCH64'}) mp &= I**15 # Implicit on x86, value stays below h+w
} }
@ -235,9 +241,11 @@ def modular_kernel{w}{rp0:*T, xp:*T, rst:(u64)} = {
} }
def transpose_with_modular{rp:*T, xp:*T, wk, h, hs} = { def transpose_with_modular{rp:*T, xp:*T, wk, h, hs} = {
def vl = arch_defvw / width{T} def odd_part{w} = if (w%2) w else odd_part{w/2}
def w = odd_part{wk}; def p = wk/w
def vl = arch_defvw / (p*width{T})
@for_mult_max{vl, h-vl} (i to h+(-h)%vl) { @for_mult_max{vl, h-vl} (i to h+(-h)%vl) {
modular_kernel{wk}{rp+i, xp+i*wk, hs} modular_kernel{w,p}{rp+i, xp+i*wk, hs}
} }
} }
@ -258,7 +266,7 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi
} }
if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and h>=vl and w==ws) { if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and h>=vl and w==ws) {
def tr = transpose_with_modular{rp, xp, ., h, hs} def tr = transpose_with_modular{rp, xp, ., h, hs}
def ws = replicate{{i} => (not use_kpart) or i < k/2, 3+2*iota{3}} # 3 to 7 def ws = replicate{{i} => (not use_kpart) or i < k/2, tup{3,5,6,7}}
each{{wk} => if (w==wk) { tr{wk}; return{} }, ws} each{{wk} => if (w==wk) { tr{wk}; return{} }, ws}
} }
if (has_simd and k!=0 and w>=k and h>=k) { if (has_simd and k!=0 and w>=k and h>=k) {