Fixed-height transposes, more or less reverse of fixed-width

This commit is contained in:
Marshall Lochbaum 2024-11-03 20:20:31 -05:00
parent 0ab2e485ca
commit abb96cb18a

View File

@ -200,9 +200,25 @@ fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = {
} }
} }
# Utilities for kernels based on modular permutation
def rotcol{xs, mg:I} = {
def w = length{xs}
@unroll (kl to ceil_log2{w}) { def k = 1<<kl
def vk = I**k; def m = (mg & vk) == vk
def bl{x,y} = { x = homBlend{x,y,m} }
x0 := select{xs, 0}
def xord = select{xs, k*iota{w} % w}
each{bl, xord, shiftleft{xord, tup{x0}}}
}
}
def get_modperm_lane_shuf{c} = {
def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c}
{x, i} => cross{shuf{16, x, i}, i}
}
def tr_quads = match { {1}=>({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} }
# Transpose a contiguous kernel of width w*p from x to r with stride rst # Transpose a contiguous kernel of width w*p from x to r with stride rst
def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = { def modular_kernel{w,p if w%2==1 and 2%p==0}{rp0:*T, xp:*T, rst:(u64)} = {
assert{w%2 == 1}; lb{p} # Odd times power of two
def h = arch_defvw / 8 def h = arch_defvw / 8
def ih = iota{h}; def iw = iota{w} def ih = iota{h}; def iw = iota{w}
def I = [h]u8; def V = I def I = [h]u8; def V = I
@ -212,22 +228,15 @@ def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = {
# Modular permutation of (reshaped argument) columns # Modular permutation of (reshaped argument) columns
xs := select{xsp, find_index{h/e*iw % w, iw}} xs := select{xsp, find_index{h/e*iw % w, iw}}
# Rotate each column by its index # Rotate each column by its index
@unroll (kl to ceil_log2{w}) { def k = 1<<kl rotcol{reverse{xs}, make{I, (ih // e) % w}}
def m = make{I, 0xff * (((ih // e) % w) & k != 0)}
def bl{x,y} = { x = homBlend{x,y,m} }
x0 := select{xs, 0}
def xord = select{xs, -k*iw % w}
each{bl, xord, shiftleft{xord, tup{x0}}}
}
# Modular permutation of rows, and write to result # Modular permutation of rows, and write to result
rp := rp0 rp := rp0
mp := make{I, ih%e + (p*(ih - ih%e) + (ih//(h/p))*e)%16*w%h}; mi := I**e mp := make{I, ih%e + (p*(ih - ih%e) + (ih//(h/p))*e)%16*w%h}
mi := I**e
def perm = if (h==16) shuf else { def perm = if (h==16) shuf else {
c := I**16 def sh = get_modperm_lane_shuf{I**16}
def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c} def q = tr_quads{p}; if (p>1) mp = q{mp}
def q = match (p) { {1}=>({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} } {x, i} => q{sh{x, i}}
if (p>1) mp = q{mp}
{x, i} => q{cross{shuf{16, x, i}, i}}
} }
def perm_store{x} = { def perm_store{x} = {
match (p) { match (p) {
@ -241,7 +250,7 @@ def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = {
} }
def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = { def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = {
def h = arch_defvw / 8 def h = arch_defvw / 8
def V = [h]u8 def V = [h]u8; def U = n_h{V}
def ih = iota{h}%16; def e = width{T} / 8 def ih = iota{h}%16; def e = width{T} / 8
# Permutation to unzip by 4 within each lane # Permutation to unzip by 4 within each lane
uz := make{V, (4*(ih//e) + ih//(16/4))*e%16 + ih%e} uz := make{V, (4*(ih//e) + ih//(16/4))*e%16 + ih%e}
@ -253,12 +262,10 @@ def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = {
def xsp = each{load{*V~~xp, .}, iota{2}} def xsp = each{load{*V~~xp, .}, iota{2}}
xs := each{proc, each{shuf{16, ., uz}, xsp}} xs := each{proc, each{shuf{16, ., uz}, xsp}}
@unroll (i to 2) { @unroll (i to 2) {
def U = [h/2]u8
rp := rp0 + st*i*rst rp := rp0 + st*i*rst
store2{*U~~rp, *U~~(rp + (2/st)*rst), zipx{xs,i}} store2{*U~~rp, *U~~(rp + (2/st)*rst), zipx{xs,i}}
} }
} }
def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = { def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = {
def p = if (wk%2) 1 else 2; def w = wk/p def p = if (wk%2) 1 else 2; def w = wk/p
def vl = arch_defvw / (p*width{T}) def vl = arch_defvw / (p*width{T})
@ -267,6 +274,60 @@ def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = {
} }
} }
# Transpose a kernel of height w*p from x with stride xst to contiguous r
# w and h are named for the result, not argument, to match modular_kernel
def modular_kernel_rev{w,p if w%2==1 and 2%p==0}{rp:*T, xp0:*T, xst:(u64)} = {
def h = arch_defvw / 8
def ih = iota{h}; def iw = iota{w}
def I = [h]u8; def V = I
def e = width{T} / 8
# Read rows, modular permutation on each
def rotbit{x, l,m,h} = x%l + (x-x%l)*(h/m)%h + x//m*l
def wi = w + 2 * ((w-1) + (w&2)) # Inverse mod 32
def mpd = rotbit{ih%e + (ih - ih%e)%16*wi%h, e,e*p,h}
mp := make{I, if (h==16 or p==1) mpd else rotbit{mpd, 8,16,h}}
def rot_mp = {
def rot_lane = shuf{16, ., make{I, (ih-e)%16}}
def cross = if (h==16) ({x}=>x) else ^{make{I, 16*(ih%16<e)}, .}
{mp} => cross{rot_lane{mp}}
}
def perm = if (h==16) shuf else {
def sh = get_modperm_lane_shuf{I**16}
def q = tr_quads{p}
{x, i} => sh{q{x}, i}
}
xp := xp0
xs := @collect (w) {
x := match (p) {
{1} => perm{load{*V~~xp, 0}, mp}
{2} => { def U = [h/2]u8; perm{load2{*U~~xp, *U~~(xp+w*xst)}, mp} }
}
xp += xst; mp = rot_mp{mp}
x
}
# Rotate each column by its index
rotcol{xs, make{I, (ih // e) % w}}
# Permute vectors and store
each{store{*V~~rp, ., .}, iw, select{xs, h/e*iw % w}}
}
def modular_kernel_rev{2,2}{rp:*T, xp0:*T, rst:(u64)} = {
def V = [arch_defvw / width{T}]T; def U = n_h{V}
xl := @unroll (i to 2) {
xp := xp0 + i*rst
x := load2{*U~~xp, *U~~(xp + 2*rst)}
if (arch_defvw==128) x else shuf{[8]u32, x, tr_iota{1,2,0}}
}
xs := unpack_typed{...unpack_typed{...xl}}
each{store{*V~~rp, ., .}, iota{2}, each{~~{V,.},xs}}
}
def transpose_fixed_height{rp:*T, xp:*T, w, ws, hk} = {
def p = if (hk%2) 1 else 2; def h = hk/p
def vl = arch_defvw / (p*width{T})
@for_mult_max{vl, w-vl} (i to w+(-w)%vl) {
modular_kernel_rev{h,p}{rp+i*hk, xp+i, ws}
}
}
fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = { fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
# Scalar transpose defined in C # Scalar transpose defined in C
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64' def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
@ -282,6 +343,11 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi
def ws = replicate{{i} => i!=(if (width{T}==64) 4 else k) and (i==4 or not (use_kpart and i>=k/2)), tup{3,4,5,6,7}} def ws = replicate{{i} => i!=(if (width{T}==64) 4 else k) and (i==4 or not (use_kpart and i>=k/2)), tup{3,4,5,6,7}}
each{{wk} => if (w==wk) { tr{wk}; return{} }, ws} each{{wk} => if (w==wk) { tr{wk}; return{} }, ws}
} }
if (has_blend and w>=vl/2 and h==hs and (h%2==0 or w>=vl)) {
def tr = transpose_fixed_height{rp, xp, w, ws, .}
def hs = replicate{{i} => i!=(if (width{T}==64) 4 else kh), tup{3,4,5,6,7}}
each{{hk} => if (h==hk) { tr{hk}; return{} }, hs}
}
if (has_simd and use_kpart and h>=kh and w>=k/2 and w<k) { if (has_simd and use_kpart and h>=kh and w>=k/2 and w<k) {
@for_mult_max{kh, h-kh} (i to h+(-h)%kh) { @for_mult_max{kh, h-kh} (i to h+(-h)%kh) {
kernel_part{w}{xp+i*ws, rp+i, k, kh, ws, hs} kernel_part{w}{xp+i*ws, rp+i, k, kh, ws, hs}