From 8bc37098dcc71173b5ebddae18d5d11bd7a531cb Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Fri, 1 Nov 2024 21:31:50 -0400 Subject: [PATCH] Modular kernel transpose for odd*2 (so, just 6) --- src/singeli/src/transpose.singeli | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/singeli/src/transpose.singeli b/src/singeli/src/transpose.singeli index e1ae9452..fb4ca53f 100644 --- a/src/singeli/src/transpose.singeli +++ b/src/singeli/src/transpose.singeli @@ -200,8 +200,9 @@ fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = { } } -# Transpose a contiguous width-w (w odd) kernel from x to r with stride rst -def modular_kernel{w}{rp0:*T, xp:*T, rst:(u64)} = { +# Transpose a contiguous kernel of width w*p from x to r with stride rst +def modular_kernel{w,p}{rp0:*T, xp:*T, rst:(u64)} = { + assert{w%2 == 1}; lb{p} # Odd times power of two def h = arch_defvw / 8 def ih = iota{h}; def iw = iota{w} def I = [h]u8; def V = I @@ -220,14 +221,19 @@ def modular_kernel{w}{rp0:*T, xp:*T, rst:(u64)} = { } # Modular permutation of rows, and write to result rp := rp0 - mp := make{I, ih%e + (ih - ih%e)%16*w%h}; mi := I**e + mp := make{I, ih%e + (p*(ih - ih%e) + (ih//(h/p))*e)%16*w%h}; mi := I**e def perm = if (h==16) shuf else { c := I**16 def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c} - {x, i} => cross{shuf{16, x, i}, i} + def q = match (p) { {1}=>({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} } + if (p>1) mp = q{mp} + {x, i} => q{cross{shuf{16, x, i}, i}} } def perm_store{x} = { - store{*V~~rp, 0, perm{x, mp}} + match (p) { + {1} => store{*V~~rp, 0, perm{x, mp}} + {2} => { def U = [h/2]u8; store2{*U~~rp, *U~~(rp+w*rst), perm{x, mp}} } + } rp += rst; mp += mi if (hasarch{'AARCH64'}) mp &= I**15 # Implicit on x86, value stays below h+w } @@ -235,9 +241,11 @@ def modular_kernel{w}{rp0:*T, xp:*T, rst:(u64)} = { } def transpose_with_modular{rp:*T, xp:*T, wk, h, hs} = { - def vl = arch_defvw / width{T} + def odd_part{w} = if (w%2) w else odd_part{w/2} + def w = odd_part{wk}; def p = wk/w + def vl = arch_defvw / (p*width{T}) @for_mult_max{vl, h-vl} (i to h+(-h)%vl) { - modular_kernel{wk}{rp+i, xp+i*wk, hs} + modular_kernel{w,p}{rp+i, xp+i*wk, hs} } } @@ -258,7 +266,7 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi } if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and h>=vl and w==ws) { def tr = transpose_with_modular{rp, xp, ., h, hs} - def ws = replicate{{i} => (not use_kpart) or i < k/2, 3+2*iota{3}} # 3 to 7 + def ws = replicate{{i} => (not use_kpart) or i < k/2, tup{3,5,6,7}} each{{wk} => if (w==wk) { tr{wk}; return{} }, ws} } if (has_simd and k!=0 and w>=k and h>=k) {