From abb96cb18a6db8dc36609be14505a161e3d06cd8 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Sun, 3 Nov 2024 20:20:31 -0500 Subject: [PATCH] Fixed-height transposes, more or less reverse of fixed-width --- src/singeli/src/transpose.singeli | 102 ++++++++++++++++++++++++------ 1 file changed, 84 insertions(+), 18 deletions(-) diff --git a/src/singeli/src/transpose.singeli b/src/singeli/src/transpose.singeli index eca762d0..ed724fda 100644 --- a/src/singeli/src/transpose.singeli +++ b/src/singeli/src/transpose.singeli @@ -200,9 +200,25 @@ fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = { } } +# Utilities for kernels based on modular permutation +def rotcol{xs, mg:I} = { + def w = length{xs} + @unroll (kl to ceil_log2{w}) { def k = 1< cross{shuf{16, x, i}, i} +} +def tr_quads = match { {1}=>({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} } + # Transpose a contiguous kernel of width w*p from x to r with stride rst -def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = { - assert{w%2 == 1}; lb{p} # Odd times power of two +def modular_kernel{w,p if w%2==1 and 2%p==0}{rp0:*T, xp:*T, rst:(u64)} = { def h = arch_defvw / 8 def ih = iota{h}; def iw = iota{w} def I = [h]u8; def V = I @@ -212,22 +228,15 @@ def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = { # Modular permutation of (reshaped argument) columns xs := select{xsp, find_index{h/e*iw % w, iw}} # Rotate each column by its index - @unroll (kl to ceil_log2{w}) { def k = 1<({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} } - if (p>1) mp = q{mp} - {x, i} => q{cross{shuf{16, x, i}, i}} + def sh = get_modperm_lane_shuf{I**16} + def q = tr_quads{p}; if (p>1) mp = q{mp} + {x, i} => q{sh{x, i}} } def perm_store{x} = { match (p) { @@ -241,7 +250,7 @@ def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = { } def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = { def h = arch_defvw / 8 - def V = [h]u8 + def V = [h]u8; def U = n_h{V} def ih = iota{h}%16; def e = width{T} / 8 # Permutation to unzip by 4 within each lane uz := make{V, (4*(ih//e) + ih//(16/4))*e%16 + ih%e} @@ -253,12 +262,10 @@ def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = { def xsp = each{load{*V~~xp, .}, iota{2}} xs := each{proc, each{shuf{16, ., uz}, xsp}} @unroll (i to 2) { - def U = [h/2]u8 rp := rp0 + st*i*rst store2{*U~~rp, *U~~(rp + (2/st)*rst), zipx{xs,i}} } } - def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = { def p = if (wk%2) 1 else 2; def w = wk/p def vl = arch_defvw / (p*width{T}) @@ -267,6 +274,60 @@ def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = { } } +# Transpose a kernel of height w*p from x with stride xst to contiguous r +# w and h are named for the result, not argument, to match modular_kernel +def modular_kernel_rev{w,p if w%2==1 and 2%p==0}{rp:*T, xp0:*T, xst:(u64)} = { + def h = arch_defvw / 8 + def ih = iota{h}; def iw = iota{w} + def I = [h]u8; def V = I + def e = width{T} / 8 + # Read rows, modular permutation on each + def rotbit{x, l,m,h} = x%l + (x-x%l)*(h/m)%h + x//m*l + def wi = w + 2 * ((w-1) + (w&2)) # Inverse mod 32 + def mpd = rotbit{ih%e + (ih - ih%e)%16*wi%h, e,e*p,h} + mp := make{I, if (h==16 or p==1) mpd else rotbit{mpd, 8,16,h}} + def rot_mp = { + def rot_lane = shuf{16, ., make{I, (ih-e)%16}} + def cross = if (h==16) ({x}=>x) else ^{make{I, 16*(ih%16 cross{rot_lane{mp}} + } + def perm = if (h==16) shuf else { + def sh = get_modperm_lane_shuf{I**16} + def q = tr_quads{p} + {x, i} => sh{q{x}, i} + } + xp := xp0 + xs := @collect (w) { + x := match (p) { + {1} => perm{load{*V~~xp, 0}, mp} + {2} => { def U = [h/2]u8; perm{load2{*U~~xp, *U~~(xp+w*xst)}, mp} } + } + xp += xst; mp = rot_mp{mp} + x + } + # Rotate each column by its index + rotcol{xs, make{I, (ih // e) % w}} + # Permute vectors and store + each{store{*V~~rp, ., .}, iw, select{xs, h/e*iw % w}} +} +def modular_kernel_rev{2,2}{rp:*T, xp0:*T, rst:(u64)} = { + def V = [arch_defvw / width{T}]T; def U = n_h{V} + xl := @unroll (i to 2) { + xp := xp0 + i*rst + x := load2{*U~~xp, *U~~(xp + 2*rst)} + if (arch_defvw==128) x else shuf{[8]u32, x, tr_iota{1,2,0}} + } + xs := unpack_typed{...unpack_typed{...xl}} + each{store{*V~~rp, ., .}, iota{2}, each{~~{V,.},xs}} +} +def transpose_fixed_height{rp:*T, xp:*T, w, ws, hk} = { + def p = if (hk%2) 1 else 2; def h = hk/p + def vl = arch_defvw / (p*width{T}) + @for_mult_max{vl, w-vl} (i to w+(-w)%vl) { + modular_kernel_rev{h,p}{rp+i*hk, xp+i, ws} + } +} + fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = { # Scalar transpose defined in C def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64' @@ -282,6 +343,11 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi def ws = replicate{{i} => i!=(if (width{T}==64) 4 else k) and (i==4 or not (use_kpart and i>=k/2)), tup{3,4,5,6,7}} each{{wk} => if (w==wk) { tr{wk}; return{} }, ws} } + if (has_blend and w>=vl/2 and h==hs and (h%2==0 or w>=vl)) { + def tr = transpose_fixed_height{rp, xp, w, ws, .} + def hs = replicate{{i} => i!=(if (width{T}==64) 4 else kh), tup{3,4,5,6,7}} + each{{hk} => if (h==hk) { tr{hk}; return{} }, hs} + } if (has_simd and use_kpart and h>=kh and w>=k/2 and w