diff --git a/src/singeli/src/transpose.singeli b/src/singeli/src/transpose.singeli index 016746bf..e1ae9452 100644 --- a/src/singeli/src/transpose.singeli +++ b/src/singeli/src/transpose.singeli @@ -201,16 +201,18 @@ fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = { } # Transpose a contiguous width-w (w odd) kernel from x to r with stride rst -def modular_kernel{w,h}{rp0:*T==i8, xp:*T, rst:(u64)} = { +def modular_kernel{w}{rp0:*T, xp:*T, rst:(u64)} = { + def h = arch_defvw / 8 def ih = iota{h}; def iw = iota{w} - def I = [h]u8 + def I = [h]u8; def V = I + def e = width{T} / 8 # Load a shape h,w slice of x, but consider as shape w,h - def xsp = each{load{*[h]T~~xp, .}, iw} + def xsp = each{load{*V~~xp, .}, iw} # Modular permutation of (reshaped argument) columns - xs := select{xsp, find_index{h*iw % w, iw}} + xs := select{xsp, find_index{h/e*iw % w, iw}} # Rotate each column by its index @unroll (kl to ceil_log2{w}) { def k = 1< cross{shuf{16, x, i}, i} } def perm_store{x} = { - store{*[h]T~~rp, 0, perm{x, mp}} + store{*V~~rp, 0, perm{x, mp}} rp += rst; mp += mi if (hasarch{'AARCH64'}) mp &= I**15 # Implicit on x86, value stays below h+w } @@ -235,7 +237,7 @@ def modular_kernel{w,h}{rp0:*T==i8, xp:*T, rst:(u64)} = { def transpose_with_modular{rp:*T, xp:*T, wk, h, hs} = { def vl = arch_defvw / width{T} @for_mult_max{vl, h-vl} (i to h+(-h)%vl) { - modular_kernel{wk,vl}{rp+i, xp+i*wk, hs} + modular_kernel{wk}{rp+i, xp+i*wk, hs} } } @@ -247,15 +249,17 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi rp:*T = *T~~r0 xp:*T = *T~~x0 def vl = arch_defvw / width{T} - if (has_simd and k>max{4,8/width{T}} and h>=kh and w>=k/2 and wmax{4,8/width{T}} + if (has_simd and use_kpart and h>=kh and w>=k/2 and w=vl and w==ws) { + if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and h>=vl and w==ws) { def tr = transpose_with_modular{rp, xp, ., h, hs} - each{{wk} => if (w==wk) { tr{wk}; return{} }, 3+2*iota{3}} # 3 to 7 + def ws = replicate{{i} => (not use_kpart) or i < k/2, 3+2*iota{3}} # 3 to 7 + each{{wk} => if (w==wk) { tr{wk}; return{} }, ws} } if (has_simd and k!=0 and w>=k and h>=k) { transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}