From 780bfdfa0bdc8660b7affa026129c58b975d8734 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Thu, 31 Oct 2024 19:16:50 -0400 Subject: [PATCH] Kernel transpose with overreading and skipped writes, for short rows --- src/singeli/src/transpose.singeli | 39 ++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/src/singeli/src/transpose.singeli b/src/singeli/src/transpose.singeli index c48dc9c1..016746bf 100644 --- a/src/singeli/src/transpose.singeli +++ b/src/singeli/src/transpose.singeli @@ -47,19 +47,46 @@ def store2{a:*T, b:*T, v:T2 if 2*width{T} == width{T2}} = match (width{T}) { { 64} => each{{p, v} => storeLow{*u64~~p, 64, [2]u64~~v}, tup{a,b}, tup{v, shuf{u64, v, 1,0}}} {128} => each{{p, i} => store{p, 0, T~~half{v,i}}, tup{a,b}, iota{2}} } +def store1of2{a:*T, v:T2 if 2*width{T} == width{T2}} = match (width{T}) { + { 64} => storeLow{*u64~~a, 64, [2]u64~~v} + {128} => store{a, 0, T~~half{v,0}} +} def load_k {VT, src, l, w} = each{{i} =>load {*VT~~(src+i*w), 0 }, iota{l}} def store_k{VT, dst, x, l, h} = each{{i,v}=>store{*VT~~(dst+i*h), 0, VT~~v}, iota{l}, x} def load_k {VT, src, l, w if width{VT} < arch_defvw} = each{{i} =>{p:=src+ i*w; load2 {*VT~~p, *VT~~(p+l*w) }}, iota{l}} def store_k{VT, dst, x, l, h if width{VT} < arch_defvw} = each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+ h), v}}, iota{l}, x} # Transpose kernel of size kw,kh in size w,h array -def kernel{src:*T, dst:*T, kw, kh, w, h} = { +def kernel_part{part_w}{src:*T, dst:*T, kw, kh, w, h} = { def n = (kw*kh*width{T}) / arch_defvw # Number of vectors def xvs = load_k{[kw]T, src, n, w} def xt = unpack_to{n==kh, n/2, xvs} # Transpose n by n def rvs = if (n==kw) xt else halved_pass{n,xt} # To kh by kh for packed square - store_k{[kh]T, dst, rvs, n, h} + def stores = store_k{[kh]T, ..., h} + if (same{part_w, 0}) { + stores{dst, rvs, n} + } else { + # Write w results, kw/2 <= n < kw + d := dst + def vd = kw / n # Number of writes for each output vector (1 or 2) + def store_slice{rv, len} = { + stores{d, slice{rv,0,len}, len} + d += len*vd*h + } + store_slice{rvs, n/2} # Unconditionally store first half + rt := slice{rvs,n/2} # Remaining tail + def wtail{b} = { + if ((part_w & (vd*b)) != 0) { + store_slice{rt, b} + slice{rt,0,b} = slice{rt,b,2*b} + } + if (b>1) wtail{b/2} + } + wtail{n/4} + if (vd>1 and (part_w & 1) != 0) store1of2{*[kh]T~~d, select{rt,0}} + } } +def kernel = kernel_part{0} def for_mult{k}{vars,begin,end,iter} = { @@ -220,9 +247,15 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi rp:*T = *T~~r0 xp:*T = *T~~x0 def vl = arch_defvw / width{T} + if (has_simd and k>max{4,8/width{T}} and h>=kh and w>=k/2 and w=vl and w==ws) { def tr = transpose_with_modular{rp, xp, ., h, hs} - each{{wk} => if (w==wk) { tr{wk}; return{} }, 3+2*iota{7}} # 3 to 15 + each{{wk} => if (w==wk) { tr{wk}; return{} }, 3+2*iota{3}} # 3 to 7 } if (has_simd and k!=0 and w>=k and h>=k) { transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}