Extend modular permutation transpose to any element size
This commit is contained in:
parent
780bfdfa0b
commit
091f08c6cc
@ -201,16 +201,18 @@ fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Transpose a contiguous width-w (w odd) kernel from x to r with stride rst
|
# Transpose a contiguous width-w (w odd) kernel from x to r with stride rst
|
||||||
def modular_kernel{w,h}{rp0:*T==i8, xp:*T, rst:(u64)} = {
|
def modular_kernel{w}{rp0:*T, xp:*T, rst:(u64)} = {
|
||||||
|
def h = arch_defvw / 8
|
||||||
def ih = iota{h}; def iw = iota{w}
|
def ih = iota{h}; def iw = iota{w}
|
||||||
def I = [h]u8
|
def I = [h]u8; def V = I
|
||||||
|
def e = width{T} / 8
|
||||||
# Load a shape h,w slice of x, but consider as shape w,h
|
# Load a shape h,w slice of x, but consider as shape w,h
|
||||||
def xsp = each{load{*[h]T~~xp, .}, iw}
|
def xsp = each{load{*V~~xp, .}, iw}
|
||||||
# Modular permutation of (reshaped argument) columns
|
# Modular permutation of (reshaped argument) columns
|
||||||
xs := select{xsp, find_index{h*iw % w, iw}}
|
xs := select{xsp, find_index{h/e*iw % w, iw}}
|
||||||
# Rotate each column by its index
|
# Rotate each column by its index
|
||||||
@unroll (kl to ceil_log2{w}) { def k = 1<<kl
|
@unroll (kl to ceil_log2{w}) { def k = 1<<kl
|
||||||
def m = make{I, 0xff * ((ih % w) & k != 0)}
|
def m = make{I, 0xff * (((ih // e) % w) & k != 0)}
|
||||||
def bl{x,y} = { x = homBlend{x,y,m} }
|
def bl{x,y} = { x = homBlend{x,y,m} }
|
||||||
x0 := select{xs, 0}
|
x0 := select{xs, 0}
|
||||||
def xord = select{xs, -k*iw % w}
|
def xord = select{xs, -k*iw % w}
|
||||||
@ -218,14 +220,14 @@ def modular_kernel{w,h}{rp0:*T==i8, xp:*T, rst:(u64)} = {
|
|||||||
}
|
}
|
||||||
# Modular permutation of rows, and write to result
|
# Modular permutation of rows, and write to result
|
||||||
rp := rp0
|
rp := rp0
|
||||||
mp := make{I, w*(ih%16) % h}; mi := I**1
|
mp := make{I, ih%e + (ih - ih%e)%16*w%h}; mi := I**e
|
||||||
def perm = if (h==16) shuf else {
|
def perm = if (h==16) shuf else {
|
||||||
c := I**16
|
c := I**16
|
||||||
def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c}
|
def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c}
|
||||||
{x, i} => cross{shuf{16, x, i}, i}
|
{x, i} => cross{shuf{16, x, i}, i}
|
||||||
}
|
}
|
||||||
def perm_store{x} = {
|
def perm_store{x} = {
|
||||||
store{*[h]T~~rp, 0, perm{x, mp}}
|
store{*V~~rp, 0, perm{x, mp}}
|
||||||
rp += rst; mp += mi
|
rp += rst; mp += mi
|
||||||
if (hasarch{'AARCH64'}) mp &= I**15 # Implicit on x86, value stays below h+w
|
if (hasarch{'AARCH64'}) mp &= I**15 # Implicit on x86, value stays below h+w
|
||||||
}
|
}
|
||||||
@ -235,7 +237,7 @@ def modular_kernel{w,h}{rp0:*T==i8, xp:*T, rst:(u64)} = {
|
|||||||
def transpose_with_modular{rp:*T, xp:*T, wk, h, hs} = {
|
def transpose_with_modular{rp:*T, xp:*T, wk, h, hs} = {
|
||||||
def vl = arch_defvw / width{T}
|
def vl = arch_defvw / width{T}
|
||||||
@for_mult_max{vl, h-vl} (i to h+(-h)%vl) {
|
@for_mult_max{vl, h-vl} (i to h+(-h)%vl) {
|
||||||
modular_kernel{wk,vl}{rp+i, xp+i*wk, hs}
|
modular_kernel{wk}{rp+i, xp+i*wk, hs}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -247,15 +249,17 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi
|
|||||||
rp:*T = *T~~r0
|
rp:*T = *T~~r0
|
||||||
xp:*T = *T~~x0
|
xp:*T = *T~~x0
|
||||||
def vl = arch_defvw / width{T}
|
def vl = arch_defvw / width{T}
|
||||||
if (has_simd and k>max{4,8/width{T}} and h>=kh and w>=k/2 and w<k) {
|
def use_kpart = k>max{4,8/width{T}}
|
||||||
|
if (has_simd and use_kpart and h>=kh and w>=k/2 and w<k) {
|
||||||
@for_mult_max{kh, h-kh} (i to h+(-h)%kh) {
|
@for_mult_max{kh, h-kh} (i to h+(-h)%kh) {
|
||||||
kernel_part{w}{xp+i*ws, rp+i, k, kh, ws, hs}
|
kernel_part{w}{xp+i*ws, rp+i, k, kh, ws, hs}
|
||||||
}
|
}
|
||||||
return{}
|
return{}
|
||||||
}
|
}
|
||||||
if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and T==i8 and h>=vl and w==ws) {
|
if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and h>=vl and w==ws) {
|
||||||
def tr = transpose_with_modular{rp, xp, ., h, hs}
|
def tr = transpose_with_modular{rp, xp, ., h, hs}
|
||||||
each{{wk} => if (w==wk) { tr{wk}; return{} }, 3+2*iota{3}} # 3 to 7
|
def ws = replicate{{i} => (not use_kpart) or i < k/2, 3+2*iota{3}} # 3 to 7
|
||||||
|
each{{wk} => if (w==wk) { tr{wk}; return{} }, ws}
|
||||||
}
|
}
|
||||||
if (has_simd and k!=0 and w>=k and h>=k) {
|
if (has_simd and k!=0 and w>=k and h>=k) {
|
||||||
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
|
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user