Fast 1-byte small odd width transpose with blends and modular permutation

This commit is contained in:
Marshall Lochbaum 2024-10-27 11:04:24 -04:00
parent 5bdee886cc
commit aefd2a4abd

View File

@ -173,6 +173,45 @@ fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = {
}
}
# Transpose a contiguous width-w (w odd) kernel from x to r with stride rst
def modular_kernel{w,h}{rp0:*T==i8, xp:*T, rst:(u64)} = {
def ih = iota{h}; def iw = iota{w}
def I = [h]u8
# Load a shape h,w slice of x, but consider as shape w,h
def xsp = each{load{*[h]T~~xp, .}, iw}
# Modular permutation of (reshaped argument) columns
xs := select{xsp, find_index{h*iw % w, iw}}
# Rotate each column by its index
@unroll (kl to ceil_log2{w}) { def k = 1<<kl
def m = make{I, 0xff * ((ih % w) & k != 0)}
def bl{x,y} = { x = homBlend{x,y,m} }
x0 := select{xs, 0}
def xord = select{xs, -k*iw % w}
each{bl, xord, shiftleft{xord, tup{x0}}}
}
# Modular permutation of rows, and write to result
rp := rp0
mp := make{I, w*(ih%16) % h}; mi := I**1
def perm = if (h==16) shuf else {
c := I**16
def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c}
{x, i} => cross{shuf{16, x, i}, i}
}
def perm_store{x} = {
store{*[h]T~~rp, 0, perm{x, mp}}
rp += rst; mp += mi
if (hasarch{'AARCH64'}) mp &= I**15 # Implicit on x86, value stays below h+w
}
each{perm_store, xs}
}
def transpose_with_modular{rp:*T, xp:*T, wk, h, hs} = {
def vl = arch_defvw / width{T}
@for_mult_max{vl, h-vl} (i to h+(-h)%vl) {
modular_kernel{wk,vl}{rp+i, xp+i*wk, hs}
}
}
fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
# Scalar transpose defined in C
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
@ -180,6 +219,11 @@ fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : voi
rp:*T = *T~~r0
xp:*T = *T~~x0
def vl = arch_defvw / width{T}
if ((hasarch{'SSE4.1'} or hasarch{'AARCH64'}) and T==i8 and h>=vl and w==ws) {
def tr = transpose_with_modular{rp, xp, ., h, hs}
each{{wk} => if (w==wk) { tr{wk}; return{} }, 3+2*iota{7}} # 3 to 15
}
if (has_simd and k!=0 and w>=k and h>=k) {
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
} else {