From e75b63831fa8fe94cd379086e10369fc700fd401 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Tue, 5 Nov 2024 21:39:15 -0500 Subject: [PATCH] Clean up dispatching for the various transpose kernel methods --- src/singeli/src/transpose.singeli | 105 ++++++++++++++++++------------ 1 file changed, 65 insertions(+), 40 deletions(-) diff --git a/src/singeli/src/transpose.singeli b/src/singeli/src/transpose.singeli index aeba7e60..fc970eb6 100644 --- a/src/singeli/src/transpose.singeli +++ b/src/singeli/src/transpose.singeli @@ -285,8 +285,15 @@ def get_modperm_lane_shuf{c} = { } def tr_quads = match { {1}=>({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} } +def loop_fixed_height{xp, rp, w, k, st, kern} = { + @for_mult_max{k, w-k} (i to w+(-w)%k) kern{xp+i, rp+i*st} +} +def loop_fixed_width{xp, rp, h, k, st, kern} = { + @for_mult_max{k, h-k} (i to h+(-h)%k) kern{xp+i*st, rp+i} +} + # Transpose a contiguous kernel of width w*p from x to r with stride rst -def modular_kernel{w,p if w%2==1 and 2%p==0}{rp0:*T, xp:*T, rst:(u64)} = { +def modular_kernel{w,p if w%2==1 and 2%p==0}{xp:*T, rp0:*T, rst:(u64)} = { def h = arch_defvw / 8 def ih = iota{h}; def iw = iota{w} def I = [h]u8; def V = I @@ -316,7 +323,7 @@ def modular_kernel{w,p if w%2==1 and 2%p==0}{rp0:*T, xp:*T, rst:(u64)} = { } each{perm_store, xs} } -def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = { +def modular_kernel{2,2}{xp:*T, rp0:*T, rst:(u64)} = { def h = arch_defvw / 8 def V = [h]u8; def U = n_h{V} def ih = iota{h}%16; def e = width{T} / 8 @@ -337,14 +344,15 @@ def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = { def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = { def p = if (wk%2) 1 else 2; def w = wk/p def vl = arch_defvw / (p*width{T}) - @for_mult_max{vl, h-vl} (i to h+(-h)%vl) { - modular_kernel{w,p}{rp+i, xp+i*wk, hs} - } + loop_fixed_width{xp, rp, h, vl, wk, modular_kernel{w,p}{., ., hs}} +} +def transpose_fixed_width{rp:*T, xp:*T, 2, h, hs} = { + uninterleave{rp, rp+hs, xp, h} } # Transpose a kernel of height w*p from x with stride xst to contiguous r # w and h are named for the result, not argument, to match modular_kernel -def modular_kernel_rev{w,p if w%2==1 and 2%p==0}{rp:*T, xp0:*T, xst:(u64)} = { +def modular_kernel_rev{w,p if w%2==1 and 2%p==0}{xp0:*T, rp:*T, xst:(u64)} = { def h = arch_defvw / 8 def ih = iota{h}; def iw = iota{w} def I = [h]u8; def V = I @@ -378,11 +386,11 @@ def modular_kernel_rev{w,p if w%2==1 and 2%p==0}{rp:*T, xp0:*T, xst:(u64)} = { # Permute vectors and store each{store{*V~~rp, ., .}, iw, select{xs, h/e*iw % w}} } -def modular_kernel_rev{2,2}{rp:*T, xp0:*T, rst:(u64)} = { +def modular_kernel_rev{2,2}{xp0:*T, rp:*T, xst:(u64)} = { def V = [arch_defvw / width{T}]T; def U = n_h{V} xl := @unroll (i to 2) { - xp := xp0 + i*rst - x := load2{*U~~xp, *U~~(xp + 2*rst)} + xp := xp0 + i*xst + x := load2{*U~~xp, *U~~(xp + 2*xst)} if (arch_defvw==128) x else shuf{[8]u32, x, tr_iota{1,2,0}} } xs := unpack_typed{...unpack_typed{...xl}} @@ -391,45 +399,62 @@ def modular_kernel_rev{2,2}{rp:*T, xp0:*T, rst:(u64)} = { def transpose_fixed_height{rp:*T, xp:*T, w, ws, hk} = { def p = if (hk%2) 1 else 2; def h = hk/p def vl = arch_defvw / (p*width{T}) - @for_mult_max{vl, w-vl} (i to w+(-w)%vl) { - modular_kernel_rev{h,p}{rp+i*hk, xp+i, ws} - } + loop_fixed_height{xp, rp, w, vl, hk, modular_kernel_rev{h,p}{., ., ws}} +} +def transpose_fixed_height{rp:*T, xp:*T, w, ws, 2} = { + interleave{T}(*void~~rp, *void~~xp, *void~~(xp+ws), w) } fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = { + rp:*T = *T~~r0 + xp:*T = *T~~x0 + def wT = width{T} + def vl = arch_defvw / wT + # Transposes with code dedicated to a particular width or height + def try_fixed_dim{tr, l, lst, nl, l_max} = { + def incl{l} = if (k>4) 1 else l!=4 + if (l=vl/2 and (l%2==0 or nl>=vl)) { + def try{ls} = { + def i = length{ls}>>1 + if (l < select{ls,i}) try{slice{ls, 0,i}} else try{slice{ls, i}} + } + def try{{lk}} = tr{lk} + try{replicate{{i} => imax{4,8/wT} + def w_max = 1 + (if (use_part_w) max{4, k/2-1} else 7) + if (has_simd and w < max{w_max, k}) { + try_fixed_dim{transpose_fixed_width {rp, xp, ., h, hs}, w, ws, h, w_max} + if (use_part_w and h>=kh and w>=k/2 and w8 and w>=k and h>=kh/2) { + loop_fixed_height{xp, rp, w, k, hs, kernel_part_h{h}{., ., k, kh, ws, hs}} + return{} + } + } # Scalar transpose defined in C def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64' def call_base{...a} = emit{void, merge{'transpose_',ts}, ...a, ws, hs} - - rp:*T = *T~~r0 - xp:*T = *T~~x0 - def vl = arch_defvw / width{T} - def use_kpart = width{T}<=16 and k>max{4,8/width{T}} - def has_blend = hasarch{'SSE4.1'} or hasarch{'AARCH64'} - if (has_blend and h>=vl/2 and w==ws and (w%2==0 or h>=vl)) { - def tr = transpose_fixed_width{rp, xp, ., h, hs} - def ws = replicate{{i} => i!=(if (width{T}==64) 4 else k) and (i==4 or not (use_kpart and i>=k/2)), tup{3,4,5,6,7}} - each{{wk} => if (w==wk) { tr{wk}; return{} }, ws} - } - if (has_blend and w>=vl/2 and h==hs and (h%2==0 or w>=vl)) { - def tr = transpose_fixed_height{rp, xp, w, ws, .} - def hs = replicate{{i} => i!=(if (width{T}==64) 4 else kh), tup{3,4,5,6,7}} - each{{hk} => if (h==hk) { tr{hk}; return{} }, hs} - } - if (has_simd and use_kpart and h>=kh and w>=k/2 and w=k and h>=kh/2 and h=k and h>=k) { + # Full kernels + # May have w2D transpose with w!=ws or h!=hs + if (has_simd and k!=0 and w>=k and h>=k) { transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs} } else { - if (h==2 and h==hs) interleave{T}(r0, x0, *void~~(xp+ws), w) - else if (w==2 and w==ws) uninterleave{rp, rp+hs, xp, h} - else call_base{rp, xp, w, h} + call_base{rp, xp, w, h} } }