Fixed-height transposes, more or less reverse of fixed-width
This commit is contained in:
parent
0ab2e485ca
commit
abb96cb18a
@ -200,9 +200,25 @@ fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Utilities for kernels based on modular permutation
|
||||||
|
def rotcol{xs, mg:I} = {
|
||||||
|
def w = length{xs}
|
||||||
|
@unroll (kl to ceil_log2{w}) { def k = 1<<kl
|
||||||
|
def vk = I**k; def m = (mg & vk) == vk
|
||||||
|
def bl{x,y} = { x = homBlend{x,y,m} }
|
||||||
|
x0 := select{xs, 0}
|
||||||
|
def xord = select{xs, k*iota{w} % w}
|
||||||
|
each{bl, xord, shiftleft{xord, tup{x0}}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
def get_modperm_lane_shuf{c} = {
|
||||||
|
def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c}
|
||||||
|
{x, i} => cross{shuf{16, x, i}, i}
|
||||||
|
}
|
||||||
|
def tr_quads = match { {1}=>({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} }
|
||||||
|
|
||||||
# Transpose a contiguous kernel of width w*p from x to r with stride rst
|
# Transpose a contiguous kernel of width w*p from x to r with stride rst
|
||||||
def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = {
|
def modular_kernel{w,p if w%2==1 and 2%p==0}{rp0:*T, xp:*T, rst:(u64)} = {
|
||||||
assert{w%2 == 1}; lb{p} # Odd times power of two
|
|
||||||
def h = arch_defvw / 8
|
def h = arch_defvw / 8
|
||||||
def ih = iota{h}; def iw = iota{w}
|
def ih = iota{h}; def iw = iota{w}
|
||||||
def I = [h]u8; def V = I
|
def I = [h]u8; def V = I
|
||||||
@ -212,22 +228,15 @@ def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = {
|
|||||||
# Modular permutation of (reshaped argument) columns
|
# Modular permutation of (reshaped argument) columns
|
||||||
xs := select{xsp, find_index{h/e*iw % w, iw}}
|
xs := select{xsp, find_index{h/e*iw % w, iw}}
|
||||||
# Rotate each column by its index
|
# Rotate each column by its index
|
||||||
@unroll (kl to ceil_log2{w}) { def k = 1<<kl
|
rotcol{reverse{xs}, make{I, (ih // e) % w}}
|
||||||
def m = make{I, 0xff * (((ih // e) % w) & k != 0)}
|
|
||||||
def bl{x,y} = { x = homBlend{x,y,m} }
|
|
||||||
x0 := select{xs, 0}
|
|
||||||
def xord = select{xs, -k*iw % w}
|
|
||||||
each{bl, xord, shiftleft{xord, tup{x0}}}
|
|
||||||
}
|
|
||||||
# Modular permutation of rows, and write to result
|
# Modular permutation of rows, and write to result
|
||||||
rp := rp0
|
rp := rp0
|
||||||
mp := make{I, ih%e + (p*(ih - ih%e) + (ih//(h/p))*e)%16*w%h}; mi := I**e
|
mp := make{I, ih%e + (p*(ih - ih%e) + (ih//(h/p))*e)%16*w%h}
|
||||||
|
mi := I**e
|
||||||
def perm = if (h==16) shuf else {
|
def perm = if (h==16) shuf else {
|
||||||
c := I**16
|
def sh = get_modperm_lane_shuf{I**16}
|
||||||
def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c}
|
def q = tr_quads{p}; if (p>1) mp = q{mp}
|
||||||
def q = match (p) { {1}=>({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} }
|
{x, i} => q{sh{x, i}}
|
||||||
if (p>1) mp = q{mp}
|
|
||||||
{x, i} => q{cross{shuf{16, x, i}, i}}
|
|
||||||
}
|
}
|
||||||
def perm_store{x} = {
|
def perm_store{x} = {
|
||||||
match (p) {
|
match (p) {
|
||||||
@ -241,7 +250,7 @@ def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = {
|
|||||||
}
|
}
|
||||||
def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = {
|
def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = {
|
||||||
def h = arch_defvw / 8
|
def h = arch_defvw / 8
|
||||||
def V = [h]u8
|
def V = [h]u8; def U = n_h{V}
|
||||||
def ih = iota{h}%16; def e = width{T} / 8
|
def ih = iota{h}%16; def e = width{T} / 8
|
||||||
# Permutation to unzip by 4 within each lane
|
# Permutation to unzip by 4 within each lane
|
||||||
uz := make{V, (4*(ih//e) + ih//(16/4))*e%16 + ih%e}
|
uz := make{V, (4*(ih//e) + ih//(16/4))*e%16 + ih%e}
|
||||||
@ -253,12 +262,10 @@ def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = {
|
|||||||
def xsp = each{load{*V~~xp, .}, iota{2}}
|
def xsp = each{load{*V~~xp, .}, iota{2}}
|
||||||
xs := each{proc, each{shuf{16, ., uz}, xsp}}
|
xs := each{proc, each{shuf{16, ., uz}, xsp}}
|
||||||
@unroll (i to 2) {
|
@unroll (i to 2) {
|
||||||
def U = [h/2]u8
|
|
||||||
rp := rp0 + st*i*rst
|
rp := rp0 + st*i*rst
|
||||||
store2{*U~~rp, *U~~(rp + (2/st)*rst), zipx{xs,i}}
|
store2{*U~~rp, *U~~(rp + (2/st)*rst), zipx{xs,i}}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = {
|
def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = {
|
||||||
def p = if (wk%2) 1 else 2; def w = wk/p
|
def p = if (wk%2) 1 else 2; def w = wk/p
|
||||||
def vl = arch_defvw / (p*width{T})
|
def vl = arch_defvw / (p*width{T})
|
||||||
@ -267,6 +274,60 @@ def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Transpose a kernel of height w*p from x with stride xst to contiguous r
|
||||||
|
# w and h are named for the result, not argument, to match modular_kernel
|
||||||
|
def modular_kernel_rev{w,p if w%2==1 and 2%p==0}{rp:*T, xp0:*T, xst:(u64)} = {
|
||||||
|
def h = arch_defvw / 8
|
||||||
|
def ih = iota{h}; def iw = iota{w}
|
||||||
|
def I = [h]u8; def V = I
|
||||||
|
def e = width{T} / 8
|
||||||
|
# Read rows, modular permutation on each
|
||||||
|
def rotbit{x, l,m,h} = x%l + (x-x%l)*(h/m)%h + x//m*l
|
||||||
|
def wi = w + 2 * ((w-1) + (w&2)) # Inverse mod 32
|
||||||
|
def mpd = rotbit{ih%e + (ih - ih%e)%16*wi%h, e,e*p,h}
|
||||||
|
mp := make{I, if (h==16 or p==1) mpd else rotbit{mpd, 8,16,h}}
|
||||||
|
def rot_mp = {
|
||||||
|
def rot_lane = shuf{16, ., make{I, (ih-e)%16}}
|
||||||
|
def cross = if (h==16) ({x}=>x) else ^{make{I, 16*(ih%16<e)}, .}
|
||||||
|
{mp} => cross{rot_lane{mp}}
|
||||||
|
}
|
||||||
|
def perm = if (h==16) shuf else {
|
||||||
|
def sh = get_modperm_lane_shuf{I**16}
|
||||||
|
def q = tr_quads{p}
|
||||||
|
{x, i} => sh{q{x}, i}
|
||||||
|
}
|
||||||
|
xp := xp0
|
||||||
|
xs := @collect (w) {
|
||||||
|
x := match (p) {
|
||||||
|
{1} => perm{load{*V~~xp, 0}, mp}
|
||||||
|
{2} => { def U = [h/2]u8; perm{load2{*U~~xp, *U~~(xp+w*xst)}, mp} }
|
||||||
|
}
|
||||||
|
xp += xst; mp = rot_mp{mp}
|
||||||
|
x
|
||||||
|
}
|
||||||
|
# Rotate each column by its index
|
||||||
|
rotcol{xs, make{I, (ih // e) % w}}
|
||||||
|
# Permute vectors and store
|
||||||
|
each{store{*V~~rp, ., .}, iw, select{xs, h/e*iw % w}}
|
||||||
|
}
|
||||||
|
def modular_kernel_rev{2,2}{rp:*T, xp0:*T, rst:(u64)} = {
|
||||||
|
def V = [arch_defvw / width{T}]T; def U = n_h{V}
|
||||||
|
xl := @unroll (i to 2) {
|
||||||
|
xp := xp0 + i*rst
|
||||||
|
x := load2{*U~~xp, *U~~(xp + 2*rst)}
|
||||||
|
if (arch_defvw==128) x else shuf{[8]u32, x, tr_iota{1,2,0}}
|
||||||
|
}
|
||||||
|
xs := unpack_typed{...unpack_typed{...xl}}
|
||||||
|
each{store{*V~~rp, ., .}, iota{2}, each{~~{V,.},xs}}
|
||||||
|
}
|
||||||
|
def transpose_fixed_height{rp:*T, xp:*T, w, ws, hk} = {
|
||||||
|
def p = if (hk%2) 1 else 2; def h = hk/p
|
||||||
|
def vl = arch_defvw / (p*width{T})
|
||||||
|
@for_mult_max{vl, w-vl} (i to w+(-w)%vl) {
|
||||||
|
modular_kernel_rev{h,p}{rp+i*hk, xp+i, ws}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
|
fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
|
||||||
# Scalar transpose defined in C
|
# Scalar transpose defined in C
|
||||||
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
|
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
|
||||||
@ -282,6 +343,11 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi
|
|||||||
def ws = replicate{{i} => i!=(if (width{T}==64) 4 else k) and (i==4 or not (use_kpart and i>=k/2)), tup{3,4,5,6,7}}
|
def ws = replicate{{i} => i!=(if (width{T}==64) 4 else k) and (i==4 or not (use_kpart and i>=k/2)), tup{3,4,5,6,7}}
|
||||||
each{{wk} => if (w==wk) { tr{wk}; return{} }, ws}
|
each{{wk} => if (w==wk) { tr{wk}; return{} }, ws}
|
||||||
}
|
}
|
||||||
|
if (has_blend and w>=vl/2 and h==hs and (h%2==0 or w>=vl)) {
|
||||||
|
def tr = transpose_fixed_height{rp, xp, w, ws, .}
|
||||||
|
def hs = replicate{{i} => i!=(if (width{T}==64) 4 else kh), tup{3,4,5,6,7}}
|
||||||
|
each{{hk} => if (h==hk) { tr{hk}; return{} }, hs}
|
||||||
|
}
|
||||||
if (has_simd and use_kpart and h>=kh and w>=k/2 and w<k) {
|
if (has_simd and use_kpart and h>=kh and w>=k/2 and w<k) {
|
||||||
@for_mult_max{kh, h-kh} (i to h+(-h)%kh) {
|
@for_mult_max{kh, h-kh} (i to h+(-h)%kh) {
|
||||||
kernel_part{w}{xp+i*ws, rp+i, k, kh, ws, hs}
|
kernel_part{w}{xp+i*ws, rp+i, k, kh, ws, hs}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user