SIMD transpose on 8-byte elements

This commit is contained in:
Marshall Lochbaum 2023-03-20 19:47:02 -04:00
parent a0e85db702
commit c0aaa6f615
2 changed files with 22 additions and 11 deletions

View File

@ -1245,7 +1245,8 @@ B reverse_c2(B t, B w, B x) {
#endif
#if SINGELI_X86_64
static NOINLINE void base_transpose_u32(u32* rp, u32* xp, u64 w, u64 h, u64 xo, u64 ro) { PLAINLOOP for(usz y=0;y<h;y++) NOVECTORIZE for(usz x=0;x<w;x++) rp[x*ro+y] = xp[y*xo+x]; }
static NOINLINE void base_transpose_i32(i32* rp, i32* xp, u64 w, u64 h, u64 xo, u64 ro) { PLAINLOOP for(usz y=0;y<h;y++) NOVECTORIZE for(usz x=0;x<w;x++) rp[x*ro+y] = xp[y*xo+x]; }
static NOINLINE void base_transpose_i64(i64* rp, i64* xp, u64 w, u64 h, u64 xo, u64 ro) { PLAINLOOP for(usz y=0;y<h;y++) NOVECTORIZE for(usz x=0;x<w;x++) rp[x*ro+y] = xp[y*xo+x]; }
#define SINGELI_FILE transpose
#include "../utils/includeSingeli.h"
#endif
@ -1342,7 +1343,11 @@ B transp_c1(B t, B x) {
if (w>=8 && h>=8) { u32* xp=tyany_ptr(x); u32* rp = m_tyarrp(&r,4,ia,el2t(xe)); simd_transpose_i32(rp, xp, w, h); break; }
#endif
{ u32* xp=tyany_ptr(x); u32* rp = m_tyarrp(&r,4,ia,el2t(xe)); PLAINLOOP for(usz y=0;y<h;y++) NOVECTORIZE for(usz x=0;x<w;x++) rp[x*h+y] = xp[xi++]; break; }
case el_f64: { f64* xp=f64any_ptr(x); f64* rp; r=m_f64arrp(&rp,ia); PLAINLOOP for(usz y=0;y<h;y++) NOVECTORIZE for(usz x=0;x<w;x++) rp[x*h+y] = xp[xi++]; break; }
case el_f64:
#if SINGELI_X86_64
if (w>=4 && h>=4) { f64* xp=f64any_ptr(x); f64* rp; r=m_f64arrp(&rp,ia); simd_transpose_i64(rp, xp, w, h); break; }
#endif
{ f64* xp=f64any_ptr(x); f64* rp; r=m_f64arrp(&rp,ia); PLAINLOOP for(usz y=0;y<h;y++) NOVECTORIZE for(usz x=0;x<w;x++) rp[x*h+y] = xp[xi++]; break; }
case el_B: { // can't be bothered to implement a bitarr transpose
B xf = getFillR(x);
B* xp = TO_BPTR(x);

View File

@ -28,21 +28,25 @@ def vtranspose{x & tuplen{x}==8 & type{tupsel{0,x}}==[8]i32 & hasarch{'X86_64'}}
merge{h{16b20}, h{16b31}}
}
def vtranspose{x & tuplen{x}==4 & type{tupsel{0,x}}==[4]i64 & hasarch{'X86_64'}} = {
def t1 = unpack_pass{1, x}
def t2pairs = pairs{2, t1}
def h{p} = each{{a,b}=>emit{[8]i32, '_mm256_permute2f128_si256', a,b,p}, ...t2pairs}
merge{h{16b20}, h{16b31}}
}
fn transpose{T}(r0:*void, x0:*void, w:u64, h:u64) : void = {
fn transpose{T, k}(r0:*void, x0:*void, w:u64, h:u64) : void = {
rp:*T = *T~~r0
xp:*T = *T~~x0
def VT = [k]T
def for_mult{k}{vars,begin,end,block} = {
assert{begin == 0}
@for (i to end/k) exec{k*i, vars, block}
}
# Kernel size
def k = 8
def VT = [k]i32
# Cache line info
def line_bytes = 64
def line_elts = line_bytes / (width{T}/8)
@ -68,7 +72,7 @@ fn transpose{T}(r0:*void, x0:*void, w:u64, h:u64) : void = {
def vt{i} = vtranspose{each{loadx, k*i + iota{k}}}
each{tup, ...each{vt, iota{line_vecs}}}
}
ro := tail{6, -u64~~r0} / 4 # Offset to align within cache line; assume elt-aligned
ro := tail{6, -u64~~r0} / (width{T}/8) # Offset to align within cache line; assume elt-aligned
wh := w*h
yn := h
if (ro != 0) {
@ -103,8 +107,10 @@ fn transpose{T}(r0:*void, x0:*void, w:u64, h:u64) : void = {
}
}
if (w%8) emit{void, 'base_transpose_u32', rp+h*(w-w%8), xp+ (w-w%8), w%8, h, w, h}
if (h%8) emit{void, 'base_transpose_u32', rp+ (h-h%8), xp+w*(h-h%8), w-w%8, h%8, w, h}
def base = if (T==i32) 'base_transpose_i32' else 'base_transpose_i64'
if (w%k) emit{void, base, rp+h*(w-w%k), xp+ (w-w%k), w%k, h, w, h}
if (h%k) emit{void, base, rp+ (h-h%k), xp+w*(h-h%k), w-w%k, h%k, w, h}
}
export{'simd_transpose_i32', transpose{u32}}
export{'simd_transpose_i32', transpose{i32, 8}}
export{'simd_transpose_i64', transpose{i64, 4}}