Clean up dispatching for the various transpose kernel methods
This commit is contained in:
parent
2d75c6c535
commit
e75b63831f
@ -285,8 +285,15 @@ def get_modperm_lane_shuf{c} = {
|
||||
}
|
||||
def tr_quads = match { {1}=>({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} }
|
||||
|
||||
def loop_fixed_height{xp, rp, w, k, st, kern} = {
|
||||
@for_mult_max{k, w-k} (i to w+(-w)%k) kern{xp+i, rp+i*st}
|
||||
}
|
||||
def loop_fixed_width{xp, rp, h, k, st, kern} = {
|
||||
@for_mult_max{k, h-k} (i to h+(-h)%k) kern{xp+i*st, rp+i}
|
||||
}
|
||||
|
||||
# Transpose a contiguous kernel of width w*p from x to r with stride rst
|
||||
def modular_kernel{w,p if w%2==1 and 2%p==0}{rp0:*T, xp:*T, rst:(u64)} = {
|
||||
def modular_kernel{w,p if w%2==1 and 2%p==0}{xp:*T, rp0:*T, rst:(u64)} = {
|
||||
def h = arch_defvw / 8
|
||||
def ih = iota{h}; def iw = iota{w}
|
||||
def I = [h]u8; def V = I
|
||||
@ -316,7 +323,7 @@ def modular_kernel{w,p if w%2==1 and 2%p==0}{rp0:*T, xp:*T, rst:(u64)} = {
|
||||
}
|
||||
each{perm_store, xs}
|
||||
}
|
||||
def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = {
|
||||
def modular_kernel{2,2}{xp:*T, rp0:*T, rst:(u64)} = {
|
||||
def h = arch_defvw / 8
|
||||
def V = [h]u8; def U = n_h{V}
|
||||
def ih = iota{h}%16; def e = width{T} / 8
|
||||
@ -337,14 +344,15 @@ def modular_kernel{2,2}{rp0:*T, xp:*T, rst:(u64)} = {
|
||||
def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = {
|
||||
def p = if (wk%2) 1 else 2; def w = wk/p
|
||||
def vl = arch_defvw / (p*width{T})
|
||||
@for_mult_max{vl, h-vl} (i to h+(-h)%vl) {
|
||||
modular_kernel{w,p}{rp+i, xp+i*wk, hs}
|
||||
}
|
||||
loop_fixed_width{xp, rp, h, vl, wk, modular_kernel{w,p}{., ., hs}}
|
||||
}
|
||||
def transpose_fixed_width{rp:*T, xp:*T, 2, h, hs} = {
|
||||
uninterleave{rp, rp+hs, xp, h}
|
||||
}
|
||||
|
||||
# Transpose a kernel of height w*p from x with stride xst to contiguous r
|
||||
# w and h are named for the result, not argument, to match modular_kernel
|
||||
def modular_kernel_rev{w,p if w%2==1 and 2%p==0}{rp:*T, xp0:*T, xst:(u64)} = {
|
||||
def modular_kernel_rev{w,p if w%2==1 and 2%p==0}{xp0:*T, rp:*T, xst:(u64)} = {
|
||||
def h = arch_defvw / 8
|
||||
def ih = iota{h}; def iw = iota{w}
|
||||
def I = [h]u8; def V = I
|
||||
@ -378,11 +386,11 @@ def modular_kernel_rev{w,p if w%2==1 and 2%p==0}{rp:*T, xp0:*T, xst:(u64)} = {
|
||||
# Permute vectors and store
|
||||
each{store{*V~~rp, ., .}, iw, select{xs, h/e*iw % w}}
|
||||
}
|
||||
def modular_kernel_rev{2,2}{rp:*T, xp0:*T, rst:(u64)} = {
|
||||
def modular_kernel_rev{2,2}{xp0:*T, rp:*T, xst:(u64)} = {
|
||||
def V = [arch_defvw / width{T}]T; def U = n_h{V}
|
||||
xl := @unroll (i to 2) {
|
||||
xp := xp0 + i*rst
|
||||
x := load2{*U~~xp, *U~~(xp + 2*rst)}
|
||||
xp := xp0 + i*xst
|
||||
x := load2{*U~~xp, *U~~(xp + 2*xst)}
|
||||
if (arch_defvw==128) x else shuf{[8]u32, x, tr_iota{1,2,0}}
|
||||
}
|
||||
xs := unpack_typed{...unpack_typed{...xl}}
|
||||
@ -391,45 +399,62 @@ def modular_kernel_rev{2,2}{rp:*T, xp0:*T, rst:(u64)} = {
|
||||
def transpose_fixed_height{rp:*T, xp:*T, w, ws, hk} = {
|
||||
def p = if (hk%2) 1 else 2; def h = hk/p
|
||||
def vl = arch_defvw / (p*width{T})
|
||||
@for_mult_max{vl, w-vl} (i to w+(-w)%vl) {
|
||||
modular_kernel_rev{h,p}{rp+i*hk, xp+i, ws}
|
||||
}
|
||||
loop_fixed_height{xp, rp, w, vl, hk, modular_kernel_rev{h,p}{., ., ws}}
|
||||
}
|
||||
def transpose_fixed_height{rp:*T, xp:*T, w, ws, 2} = {
|
||||
interleave{T}(*void~~rp, *void~~xp, *void~~(xp+ws), w)
|
||||
}
|
||||
|
||||
fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
|
||||
rp:*T = *T~~r0
|
||||
xp:*T = *T~~x0
|
||||
def wT = width{T}
|
||||
def vl = arch_defvw / wT
|
||||
# Transposes with code dedicated to a particular width or height
|
||||
def try_fixed_dim{tr, l, lst, nl, l_max} = {
|
||||
def incl{l} = if (k>4) 1 else l!=4
|
||||
if (l<l_max and incl{l} and l==lst) {
|
||||
if (l == 2) { tr{2}; return{} }
|
||||
def has_blend = hasarch{'SSE4.1'} or hasarch{'AARCH64'}
|
||||
if (has_blend and nl>=vl/2 and (l%2==0 or nl>=vl)) {
|
||||
def try{ls} = {
|
||||
def i = length{ls}>>1
|
||||
if (l < select{ls,i}) try{slice{ls, 0,i}} else try{slice{ls, i}}
|
||||
}
|
||||
def try{{lk}} = tr{lk}
|
||||
try{replicate{{i} => i<l_max and incl{i}, slice{iota{8},3}}}
|
||||
return{}
|
||||
}
|
||||
}
|
||||
}
|
||||
# Small width: fixed, or over-reading partial kernel
|
||||
def use_part_w = wT<=16 and k>max{4,8/wT}
|
||||
def w_max = 1 + (if (use_part_w) max{4, k/2-1} else 7)
|
||||
if (has_simd and w < max{w_max, k}) {
|
||||
try_fixed_dim{transpose_fixed_width {rp, xp, ., h, hs}, w, ws, h, w_max}
|
||||
if (use_part_w and h>=kh and w>=k/2 and w<k) {
|
||||
loop_fixed_width{xp, rp, h, kh, ws, kernel_part{w}{., ., k, kh, ws, hs}}
|
||||
return{}
|
||||
}
|
||||
}
|
||||
# Small height: fixed, or kernel with overlapping writes
|
||||
# Overlapping is slower than over-reading, so it's only used when needed
|
||||
if (has_simd and h < max{8, k}) {
|
||||
try_fixed_dim{transpose_fixed_height{rp, xp, w, ws, .}, h, hs, w, 8}
|
||||
if (k>8 and w>=k and h>=kh/2) {
|
||||
loop_fixed_height{xp, rp, w, k, hs, kernel_part_h{h}{., ., k, kh, ws, hs}}
|
||||
return{}
|
||||
}
|
||||
}
|
||||
# Scalar transpose defined in C
|
||||
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
|
||||
def call_base{...a} = emit{void, merge{'transpose_',ts}, ...a, ws, hs}
|
||||
|
||||
rp:*T = *T~~r0
|
||||
xp:*T = *T~~x0
|
||||
def vl = arch_defvw / width{T}
|
||||
def use_kpart = width{T}<=16 and k>max{4,8/width{T}}
|
||||
def has_blend = hasarch{'SSE4.1'} or hasarch{'AARCH64'}
|
||||
if (has_blend and h>=vl/2 and w==ws and (w%2==0 or h>=vl)) {
|
||||
def tr = transpose_fixed_width{rp, xp, ., h, hs}
|
||||
def ws = replicate{{i} => i!=(if (width{T}==64) 4 else k) and (i==4 or not (use_kpart and i>=k/2)), tup{3,4,5,6,7}}
|
||||
each{{wk} => if (w==wk) { tr{wk}; return{} }, ws}
|
||||
}
|
||||
if (has_blend and w>=vl/2 and h==hs and (h%2==0 or w>=vl)) {
|
||||
def tr = transpose_fixed_height{rp, xp, w, ws, .}
|
||||
def hs = replicate{{i} => i!=(if (width{T}==64) 4 else kh), tup{3,4,5,6,7}}
|
||||
each{{hk} => if (h==hk) { tr{hk}; return{} }, hs}
|
||||
}
|
||||
if (has_simd and use_kpart and h>=kh and w>=k/2 and w<k) {
|
||||
@for_mult_max{kh, h-kh} (i to h+(-h)%kh) {
|
||||
kernel_part{w}{xp+i*ws, rp+i, k, kh, ws, hs}
|
||||
}
|
||||
} else if (hasarch{'AVX2'} and width{T}==8 and w>=k and h>=kh/2 and h<kh) {
|
||||
@for_mult_max{k, w-k} (i to w+(-w)%k) {
|
||||
kernel_part_h{h}{xp+i, rp+i*hs, k, kh, ws, hs}
|
||||
}
|
||||
} else if (has_simd and k!=0 and w>=k and h>=k) {
|
||||
# Full kernels
|
||||
# May have w<k or h<k if not has_blend, or >2D transpose with w!=ws or h!=hs
|
||||
if (has_simd and k!=0 and w>=k and h>=k) {
|
||||
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
|
||||
} else {
|
||||
if (h==2 and h==hs) interleave{T}(r0, x0, *void~~(xp+ws), w)
|
||||
else if (w==2 and w==ws) uninterleave{rp, rp+hs, xp, h}
|
||||
else call_base{rp, xp, w, h}
|
||||
call_base{rp, xp, w, h}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user