From aefd2a4abdc4ad42b42d5b77e658f985824be1c8 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Sun, 27 Oct 2024 11:04:24 -0400 Subject: [PATCH] Fast 1-byte small odd width transpose with blends and modular permutation --- src/singeli/src/transpose.singeli | 44 +++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/singeli/src/transpose.singeli b/src/singeli/src/transpose.singeli index 4c8fa943..c48dc9c1 100644 --- a/src/singeli/src/transpose.singeli +++ b/src/singeli/src/transpose.singeli @@ -173,6 +173,45 @@ fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = { } } +# Transpose a contiguous width-w (w odd) kernel from x to r with stride rst +def modular_kernel{w,h}{rp0:*T==i8, xp:*T, rst:(u64)} = { + def ih = iota{h}; def iw = iota{w} + def I = [h]u8 + # Load a shape h,w slice of x, but consider as shape w,h + def xsp = each{load{*[h]T~~xp, .}, iw} + # Modular permutation of (reshaped argument) columns + xs := select{xsp, find_index{h*iw % w, iw}} + # Rotate each column by its index + @unroll (kl to ceil_log2{w}) { def k = 1< cross{shuf{16, x, i}, i} + } + def perm_store{x} = { + store{*[h]T~~rp, 0, perm{x, mp}} + rp += rst; mp += mi + if (hasarch{'AARCH64'}) mp &= I**15 # Implicit on x86, value stays below h+w + } + each{perm_store, xs} +} + +def transpose_with_modular{rp:*T, xp:*T, wk, h, hs} = { + def vl = arch_defvw / width{T} + @for_mult_max{vl, h-vl} (i to h+(-h)%vl) { + modular_kernel{wk,vl}{rp+i, xp+i*wk, hs} + } +} + fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = { # Scalar transpose defined in C def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64' @@ -180,6 +219,11 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi rp:*T = *T~~r0 xp:*T = *T~~x0 + def vl = arch_defvw / width{T} + if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and T==i8 and h>=vl and w==ws) { + def tr = transpose_with_modular{rp, xp, ., h, hs} + each{{wk} => if (w==wk) { tr{wk}; return{} }, 3+2*iota{7}} # 3 to 15 + } if (has_simd and k!=0 and w>=k and h>=k) { transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs} } else {