Modular kernel transpose for odd*2 (so, just 6)
This commit is contained in:
parent
091f08c6cc
commit
8bc37098dc
@ -200,8 +200,9 @@ fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Transpose a contiguous width-w (w odd) kernel from x to r with stride rst
|
# Transpose a contiguous kernel of width w*p from x to r with stride rst
|
||||||
def modular_kernel{w}{rp0:*T, xp:*T, rst:(u64)} = {
|
def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = {
|
||||||
|
assert{w%2 == 1}; lb{p} # Odd times power of two
|
||||||
def h = arch_defvw / 8
|
def h = arch_defvw / 8
|
||||||
def ih = iota{h}; def iw = iota{w}
|
def ih = iota{h}; def iw = iota{w}
|
||||||
def I = [h]u8; def V = I
|
def I = [h]u8; def V = I
|
||||||
@ -220,14 +221,19 @@ def modular_kernel{w}{rp0:*T, xp:*T, rst:(u64)} = {
|
|||||||
}
|
}
|
||||||
# Modular permutation of rows, and write to result
|
# Modular permutation of rows, and write to result
|
||||||
rp := rp0
|
rp := rp0
|
||||||
mp := make{I, ih%e + (ih - ih%e)%16*w%h}; mi := I**e
|
mp := make{I, ih%e + (p*(ih - ih%e) + (ih//(h/p))*e)%16*w%h}; mi := I**e
|
||||||
def perm = if (h==16) shuf else {
|
def perm = if (h==16) shuf else {
|
||||||
c := I**16
|
c := I**16
|
||||||
def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c}
|
def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c}
|
||||||
{x, i} => cross{shuf{16, x, i}, i}
|
def q = match (p) { {1}=>({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} }
|
||||||
|
if (p>1) mp = q{mp}
|
||||||
|
{x, i} => q{cross{shuf{16, x, i}, i}}
|
||||||
}
|
}
|
||||||
def perm_store{x} = {
|
def perm_store{x} = {
|
||||||
store{*V~~rp, 0, perm{x, mp}}
|
match (p) {
|
||||||
|
{1} => store{*V~~rp, 0, perm{x, mp}}
|
||||||
|
{2} => { def U = [h/2]u8; store2{*U~~rp, *U~~(rp+w*rst), perm{x, mp}} }
|
||||||
|
}
|
||||||
rp += rst; mp += mi
|
rp += rst; mp += mi
|
||||||
if (hasarch{'AARCH64'}) mp &= I**15 # Implicit on x86, value stays below h+w
|
if (hasarch{'AARCH64'}) mp &= I**15 # Implicit on x86, value stays below h+w
|
||||||
}
|
}
|
||||||
@ -235,9 +241,11 @@ def modular_kernel{w}{rp0:*T, xp:*T, rst:(u64)} = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def transpose_with_modular{rp:*T, xp:*T, wk, h, hs} = {
|
def transpose_with_modular{rp:*T, xp:*T, wk, h, hs} = {
|
||||||
def vl = arch_defvw / width{T}
|
def odd_part{w} = if (w%2) w else odd_part{w/2}
|
||||||
|
def w = odd_part{wk}; def p = wk/w
|
||||||
|
def vl = arch_defvw / (p*width{T})
|
||||||
@for_mult_max{vl, h-vl} (i to h+(-h)%vl) {
|
@for_mult_max{vl, h-vl} (i to h+(-h)%vl) {
|
||||||
modular_kernel{wk}{rp+i, xp+i*wk, hs}
|
modular_kernel{w,p}{rp+i, xp+i*wk, hs}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -258,7 +266,7 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi
|
|||||||
}
|
}
|
||||||
if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and h>=vl and w==ws) {
|
if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and h>=vl and w==ws) {
|
||||||
def tr = transpose_with_modular{rp, xp, ., h, hs}
|
def tr = transpose_with_modular{rp, xp, ., h, hs}
|
||||||
def ws = replicate{{i} => (not use_kpart) or i < k/2, 3+2*iota{3}} # 3 to 7
|
def ws = replicate{{i} => (not use_kpart) or i < k/2, tup{3,5,6,7}}
|
||||||
each{{wk} => if (w==wk) { tr{wk}; return{} }, ws}
|
each{{wk} => if (w==wk) { tr{wk}; return{} }, ws}
|
||||||
}
|
}
|
||||||
if (has_simd and k!=0 and w>=k and h>=k) {
|
if (has_simd and k!=0 and w>=k and h>=k) {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user