diff --git a/src/singeli/src/transpose.singeli b/src/singeli/src/transpose.singeli index 9c13eff2..dbd0478a 100644 --- a/src/singeli/src/transpose.singeli +++ b/src/singeli/src/transpose.singeli @@ -22,16 +22,74 @@ def vtranspose{x & tuplen{x}==8 & type{tupsel{0,x}}==[8]i32 & hasarch{'X86_64'}} fn transpose{T}(r0:*void, x0:*void, w:u64, h:u64) : void = { rp:*T = *T~~r0 xp:*T = *T~~x0 + + def for_mult{k}{vars,begin,end,block} = { + assert{begin == 0} + @for (i to end/k) exec{k*i, vars, block} + } + + # Kernel size + def k = 8 + def VT = [k]i32 + + # Cache line info + def line_bytes = 64 + def line_elts = line_bytes / (width{T}/8) + def line_vecs = line_bytes / (width{VT}/8) - @for (y to h/8) { - @for (x to w/8) { - def VT = [8]i32 - xpo:= xp + y*8*w + x*8 - rpo:= rp + x*8*h + y*8 - def xvs = each{{i}=>load{*VT~~(xpo+i*w), 0}, iota{vcount{VT}}} - def rvs = vtranspose{xvs} - each{{i,v}=>store{*VT~~(rpo+i*h), 0, v}, iota{vcount{VT}}, rvs} - } + if (h&(line_elts-1) != 0) { + @for_mult{k} (y to h) { + @for_mult{k} (x to w) { + xpo:= xp + y*w + x + rpo:= rp + x*h + y + def xvs = each{{i}=>load{*VT~~(xpo+i*w), 0}, iota{vcount{VT}}} + def rvs = vtranspose{xvs} + each{{i,v}=>store{*VT~~(rpo+i*h), 0, v}, iota{vcount{VT}}, rvs} + } + } + } else { + # Result rows are aligned with each other so it's possible to + # write a full cache line at a time + # This case is here to mitigate cache associativity problems at + # at multiples of 256 or so, but it's faster whenever it applies + def store_line{p, vs} = each{bind{store,p}, iota{line_vecs}, vs} + def get_lines{loadx} = { + def vt{i} = vtranspose{each{loadx, k*i + iota{k}}} + each{tup, ...each{vt, iota{line_vecs}}} + } + ro := tail{6, -u64~~r0} / 4 # Offset to align within cache line; assume elt-aligned + wh := w*h + yn := h + if (ro != 0) { + ra := line_elts - ro + y := h - ra + rpo := rp + y # Cache aligned + rpe := rpo + (w-1)*h + # Part of first and last result row aren't covered by the split loop + def trtail{dst, src, len} = @for (i to len) store{dst, i, load{src, w*i}} + trtail{rp, xp, ro} + trtail{rpe, xp + y*w + w-1, ra} + # Transpose first few rows and last few rows together + @for_mult{k} (x to w) { + o := w*y + x + def loadx{_} = { + l:=load{*VT~~(xp+o)} + o+=w; if (o>wh-k) o -= wh-1 # Jump from last source row to first, shifting right 1 + l + } + def rls = get_lines{loadx} # 4 rows of 2 vectors each + each{{i,v} => {if (i<3 or rpo load{*VT~~(xpo+i*w), 0}} + each{{i,v} => store_line{*VT~~(rpo+i*h), v}, iota{k}, rls} + } + } } if (w%8) emit{void, 'base_transpose_u32', rp+h*(w-w%8), xp+ (w-w%8), w%8, h, w, h}