Merge pull request #123 from mlochbaum/modperm
CPU-sized select single column and transpose with modular permutations
This commit is contained in:
commit
163853439e
@ -1 +1 @@
|
||||
Subproject commit d432cb710911457169da5342d27ce8adffd5dd1a
|
||||
Subproject commit 17c512727dbcb6d58a2adadac4661bc9c43920d2
|
||||
@ -219,8 +219,8 @@ NOINLINE B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr) { //
|
||||
|
||||
|
||||
// fast special-case implementations
|
||||
extern void (*const si_select_cells_bit_lt64)(u64*,u64*,usz,usz,usz); // from fold.c (fold.singeli)
|
||||
static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind {leaf? <∘⊑; ⊏}⎉¯k x; TODO probably can share some parts with takedrop_highrank and/or call ⊏?
|
||||
B select_cells_single(usz ind, B x, usz cam, usz l, usz csz, bool leaf); // from select.c
|
||||
static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind {leaf? <∘⊑; ⊏}⎉¯k x
|
||||
ur xr = RNK(x);
|
||||
assert(xr>1 && k<xr);
|
||||
usz* xsh = SH(x);
|
||||
@ -228,58 +228,15 @@ static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind
|
||||
usz l = xsh[k];
|
||||
assert(0<=ind && ind<l);
|
||||
assert(cam*l*csz == IA(x));
|
||||
Arr* ra;
|
||||
usz take = leaf? 1 : csz;
|
||||
if (l==1 && take==csz) {
|
||||
ra = cpyWithShape(incG(x));
|
||||
arr_shErase(ra, 1);
|
||||
} else {
|
||||
u8 xe = TI(x,elType);
|
||||
u8 ewl= elwBitLog(xe);
|
||||
u8 xl = leaf? ewl : multWidthLog(csz, ewl);
|
||||
usz ria = cam*take;
|
||||
if (xl>=7 || (xl<3 && xl>0)) { // generic case
|
||||
MAKE_MUT_INIT(rm, ria, TI(x,elType)); MUTG_INIT(rm);
|
||||
usz jump = l * csz;
|
||||
usz xi = take*ind;
|
||||
usz ri = 0;
|
||||
for (usz i = 0; i < cam; i++) {
|
||||
mut_copyG(rm, ri, x, xi, take);
|
||||
xi+= jump;
|
||||
ri+= take;
|
||||
}
|
||||
ra = mut_fp(rm);
|
||||
} else if (xe==el_B) {
|
||||
assert(take == 1);
|
||||
SGet(x)
|
||||
HArr_p rp = m_harrUv(ria);
|
||||
for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*l+ind);
|
||||
NOGC_E; ra = (Arr*)rp.c;
|
||||
} else {
|
||||
void* rp = m_tyarrlbp(&ra, ewl, ria, el2t(xe));
|
||||
void* xp = tyany_ptr(x);
|
||||
switch(xl) {
|
||||
case 0:
|
||||
#if SINGELI
|
||||
if (l < 64) si_select_cells_bit_lt64(xp, rp, cam, l, ind);
|
||||
else
|
||||
#endif
|
||||
for (usz i=0; i<cam; i++) bitp_set(rp, i, bitp_get(xp, i*l+ind));
|
||||
break;
|
||||
case 3: PLAINLOOP for (usz i=0; i<cam; i++) ((u8* )rp)[i] = ((u8* )xp)[i*l+ind]; break;
|
||||
case 4: PLAINLOOP for (usz i=0; i<cam; i++) ((u16*)rp)[i] = ((u16*)xp)[i*l+ind]; break;
|
||||
case 5: PLAINLOOP for (usz i=0; i<cam; i++) ((u32*)rp)[i] = ((u32*)xp)[i*l+ind]; break;
|
||||
case 6: PLAINLOOP for (usz i=0; i<cam; i++) ((f64*)rp)[i] = ((f64*)xp)[i*l+ind]; break;
|
||||
}
|
||||
}
|
||||
}
|
||||
B r = select_cells_single(ind, x, cam, l, csz, leaf);
|
||||
Arr* ra = a(r);
|
||||
usz* rsh = arr_shAlloc(ra, leaf? k : xr-1);
|
||||
if (rsh) {
|
||||
shcpy(rsh, xsh, k);
|
||||
if (!leaf) shcpy(rsh+k, xsh+k+1, xr-1-k);
|
||||
}
|
||||
decG(x);
|
||||
return taga(ra);
|
||||
return r;
|
||||
}
|
||||
|
||||
static void set_column_typed(void* rp, B v, u8 e, ux p, ux stride, ux n) { // may write to all elements 0 ≤ i < stride×n, and after that too for masked stores
|
||||
|
||||
@ -40,8 +40,15 @@
|
||||
// Sparse initialization if 𝕨 is much smaller than 𝕩
|
||||
// COULD call Mark Firsts (∊) for very short 𝕨 to avoid allocation
|
||||
|
||||
// Select Cells - inds⊸⊏⎉1 x
|
||||
// Squeeze indices if too wide for given x
|
||||
// Select Cells - inds⊸⊏⎉1 𝕩
|
||||
// Squeeze indices if too wide for given 𝕩
|
||||
// Single index: (also used for monadic ⊏˘ ⊣˝˘ ⊢˝˘)
|
||||
// Selecting a column of bits:
|
||||
// Row size <64: extract as with fold-cells
|
||||
// Selecting a column of 1, 2, 4, or 8-byte elements:
|
||||
// Short cells: pack vectors from 𝕩, or blend and permute
|
||||
// Long cells: dedicated scalar loop for each type
|
||||
// Otherwise, loop with mutable copy
|
||||
// Boolean indices:
|
||||
// Short inds and short cells: Widen to i8
|
||||
// Otherwise: bitsel call per cell
|
||||
@ -57,7 +64,7 @@
|
||||
// COULD generate full list of indices via arith
|
||||
// 1-element cells: use (≠inds)/⥊x after checking ∧´inds∊0‿¯1
|
||||
// Used for ⌽⎉1
|
||||
// SHOULD use for atom⊸⊏⎉k, /⎉k, ⌽⎉k, ↑⎉k, ↓⎉k, ↕⎉k, ⍉⎉k, probably more
|
||||
// SHOULD use for /⎉k, ⌽⎉k, ↑⎉k, ↓⎉k, ↕⎉k, ⍉⎉k, probably more
|
||||
|
||||
#include "../core.h"
|
||||
#include "../utils/talloc.h"
|
||||
@ -575,6 +582,62 @@ static void* m_arrv_same(B* r, usz ia, B src) { // makes a new array with same e
|
||||
|
||||
B slash_c2(B, B, B);
|
||||
B select_cells_base(B inds, B x0, ux csz, ux cam);
|
||||
extern void (*const si_select_cells_bit_lt64)(u64*,u64*,usz,usz,usz); // from fold.c (fold.singeli)
|
||||
extern usz (*const si_select_cells_byte)(void*,void*,usz,usz,u8);
|
||||
|
||||
B select_cells_single(usz ind, B x, usz cam, usz l, usz csz, bool leaf) { // ⥊ ind {leaf? <∘⊑; ⊏}˘ cam‿l‿csz ⥊ x
|
||||
usz take = leaf? 1 : csz;
|
||||
Arr* ra;
|
||||
if (l==1 && take==csz) {
|
||||
ra = cpyWithShape(incG(x));
|
||||
arr_shErase(ra, 1);
|
||||
} else {
|
||||
u8 xe = TI(x,elType);
|
||||
u8 ewl= elwBitLog(xe);
|
||||
u8 xl = leaf? ewl : multWidthLog(csz, ewl);
|
||||
usz ria = cam*take;
|
||||
if (xl>=7 || (xl<3 && xl>0)) { // generic case
|
||||
MAKE_MUT_INIT(rm, ria, TI(x,elType)); MUTG_INIT(rm);
|
||||
usz jump = l * csz;
|
||||
usz xi = take*ind;
|
||||
usz ri = 0;
|
||||
for (usz i = 0; i < cam; i++) {
|
||||
mut_copyG(rm, ri, x, xi, take);
|
||||
xi+= jump;
|
||||
ri+= take;
|
||||
}
|
||||
ra = mut_fp(rm);
|
||||
} else if (xe==el_B) {
|
||||
assert(take == 1);
|
||||
SGet(x)
|
||||
HArr_p rp = m_harrUv(ria);
|
||||
for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*l+ind);
|
||||
NOGC_E; ra = (Arr*)rp.c;
|
||||
} else {
|
||||
void* rp = m_tyarrlbp(&ra, ewl, ria, el2t(xe));
|
||||
void* xp = tyany_ptr(x);
|
||||
if (xl == 0) {
|
||||
#if SINGELI
|
||||
if (l < 64) si_select_cells_bit_lt64(xp, rp, cam, l, ind);
|
||||
else
|
||||
#endif
|
||||
for (usz i=0; i<cam; i++) bitp_set(rp, i, bitp_get(xp, i*l+ind));
|
||||
} else {
|
||||
usz i0 = 0;
|
||||
#if SINGELI
|
||||
i0 = si_select_cells_byte((u8*)xp + (ind<<(xl-3)), rp, cam, l, xl-3);
|
||||
#endif
|
||||
switch(xl) { default: UD;
|
||||
case 3: PLAINLOOP for (usz i=i0; i<cam; i++) ((u8* )rp)[i] = ((u8* )xp)[i*l+ind]; break;
|
||||
case 4: PLAINLOOP for (usz i=i0; i<cam; i++) ((u16*)rp)[i] = ((u16*)xp)[i*l+ind]; break;
|
||||
case 5: PLAINLOOP for (usz i=i0; i<cam; i++) ((u32*)rp)[i] = ((u32*)xp)[i*l+ind]; break;
|
||||
case 6: PLAINLOOP for (usz i=i0; i<cam; i++) ((f64*)rp)[i] = ((f64*)xp)[i*l+ind]; break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return taga(ra);
|
||||
}
|
||||
|
||||
#define CLZC(X) (64-(CLZ((u64)(X))))
|
||||
|
||||
@ -851,6 +914,11 @@ B select_rows_B(B x, ux csz, ux cam, B inds) { // consumes inds,x; ⥊ inds⊸
|
||||
ux in = IA(inds);
|
||||
if (in == 0) return taga(emptyArr(x, 1));
|
||||
u8 ie = TI(inds,elType);
|
||||
if (in == 1) {
|
||||
B w = IGetU(inds,0); if (!isF64(w)) goto generic;
|
||||
B r = select_cells_single(WRAP(o2i64(w), csz, thrF("⊏: Indexing out-of-bounds (%R∊𝕨, %s≡≠𝕩)", w, csz)), x, cam, csz, 1, false);
|
||||
decG(x); decG(inds); return r;
|
||||
}
|
||||
if (csz<=2? ie!=el_bit : csz<=128? ie>el_i8 : !elInt(ie)) {
|
||||
inds = num_squeeze(inds);
|
||||
ie = TI(inds,elType);
|
||||
|
||||
@ -1,18 +1,22 @@
|
||||
// Transpose and Reorder Axes (⍉)
|
||||
|
||||
// Transpose
|
||||
// One length-2 axis: dedicated code
|
||||
// Boolean: pdep or emulation for height 2; pext for width 2
|
||||
// SHOULD use a generic implementation if BMI2 not present
|
||||
// SHOULD optimize other short lengths with pdep/pext and shuffles
|
||||
// Boolean 𝕩: convert to integer
|
||||
// pdep or emulation for height 2; pext for width 2
|
||||
// SHOULD switch to shuffles and modular permutation
|
||||
// SHOULD have bit matrix transpose kernel
|
||||
// CPU sizes: native or SIMD code
|
||||
// Large SIMD kernels used when they fit, overlapping for odd sizes
|
||||
// i8: 16×16; i16: 16×8; i32: 8×8; f64: 4×4
|
||||
// COULD use half-width or smaller kernels to improve odd sizes
|
||||
// Scalar transpose or loop used for overhang of 1
|
||||
// SHOULD add NEON
|
||||
// SSE, NEON i8: 8×8 ; i16: 8×8; i32: 4×4; f64: scalar
|
||||
// AVX2 i8: 16×16; i16: 16×8; i32: 8×8; f64: 4×4
|
||||
// COULD use half-width or smaller kernels to improve odd sizes
|
||||
// Scalar transpose or loop used for overhang of 1
|
||||
// Partial kernels for at least half-kernel height/width
|
||||
// Short width: over-read, then skip some writes
|
||||
// Short height, i8 AVX2 only: overlap input rows and output words
|
||||
// Smaller heights and widths: dedicated kernels
|
||||
// Zipping, unzipping, shuffling for powers of 2
|
||||
// Modular permutation for odd numbers, possibly times 2
|
||||
|
||||
// Reorder Axes
|
||||
// If 𝕨 indicates the identity permutation, return 𝕩
|
||||
|
||||
@ -74,6 +74,189 @@ fn fold_assoc_0{T==f64, op if has_simd}(x:*T, len:u64) : T = {
|
||||
export{'si_sum_f64', fold_assoc_0{f64,+}}
|
||||
|
||||
|
||||
def extract_column_pow2{T, x0, r0, nv, k} = {
|
||||
def V = [arch_defvw / width{T}]T
|
||||
xv := *V~~x0
|
||||
@for (r in *V~~r0 over i to nv) {
|
||||
xs := each{load{xv, .}, iota{k}}
|
||||
def unzip0{w} = if (not hasarch{'X86_64'}) {
|
||||
unzip{..., 0} # Sane instruction set
|
||||
} else {
|
||||
if (w <= 16) {
|
||||
# Pack instructions
|
||||
m := make{V, - (iota{vcount{V}}%k == 0)}
|
||||
xs = each{&{m, .}, xs} # Mask off high bits
|
||||
def D = el_m{V}
|
||||
{a, b} => packQ{D~~a, D~~b}
|
||||
} else {
|
||||
# Two-vector shuffles
|
||||
# Could also be used for 1/2-byte with ending gap >= 4 bytes,
|
||||
# less instructions but it doesn't seem faster
|
||||
def c = 128/w
|
||||
def sh = shuf{[c]ty_f{w}, ., 2*iota{c} % c}
|
||||
{...ab} => sh{ab}
|
||||
}
|
||||
}
|
||||
if (not hasarch{'X86_64'} or T != u16 or hasarch{'SSE4.1'}) {
|
||||
r = tree_fold{unzip0{width{T}}, xs}
|
||||
} else {
|
||||
# No unsigned saturation: sign-extend then use unsigned
|
||||
def D = [4]i32
|
||||
def f = tree_fold{unzip0{32}, .}
|
||||
def proc{hx} = {
|
||||
ri := D~~f{hx}
|
||||
top := D**(1<<15); m := D**(1<<16 - 1)
|
||||
(ri & m) | (D~~(ri&top == top) &~ m)
|
||||
}
|
||||
r = V~~packQ{...each{proc, split{k/2, xs}}}
|
||||
}
|
||||
if (width{V} > 128) { # Lane axis wasn't packed, need to shuffle to bottom
|
||||
def tr{E,a, r} = shuf{[1<<a]E, r, tr_iota{shiftright{a-1, iota{a}}}}
|
||||
def lc = k > 4
|
||||
if (lc) r = tr{u64,2, r}
|
||||
r = tr{ty_u{128/k}, lb{k} + (not lc), r}
|
||||
}
|
||||
xv += k
|
||||
}
|
||||
}
|
||||
|
||||
def extract_column_modperm{x0, r0, nv, l, el, vl} = {
|
||||
# Build modular permutations
|
||||
def V = [vl]u8
|
||||
def H = [vl/2]u16
|
||||
p2 := ctz{l}; l >>= p2 # Decompose into l<<p2 with odd l
|
||||
e := p2 + promote{ux,el} # Absorb into element size for most computation
|
||||
l8 := cast_i{u8, l}
|
||||
li := cast_i{u8, l + 2 * ((l-1) + (l&2))} # Inverse mod vl
|
||||
elo:= V**(u8~~1<<e - 1)
|
||||
ie := iota{V} & elo
|
||||
kmul := make{H, 2*iota{vl/2}} &~ H~~elo
|
||||
def mu16{k} = {
|
||||
k16 := H ** k
|
||||
prd := shuf{V~~(kmul * k16), 0,0}
|
||||
if (e == 0) prd += V~~(k16 << 8)
|
||||
(prd & V**(vl-1)) + ie
|
||||
}
|
||||
si := mu16{l8}
|
||||
sii := mu16{li}
|
||||
def swap_ms = if (vl == 16) ({x}=>x) else {
|
||||
ms := (V**16 & sii) == (V**16 &~ iota{V})
|
||||
{x} => homBlend{x, shuf{[4]u64, x, 2,3,0,1}, ms}
|
||||
}
|
||||
|
||||
# Blend masks
|
||||
def mg = { # Iteration i should select where mg == V**i
|
||||
ss := (si < V**(l8<<e)) & (ie == V**0)
|
||||
vs := V**0xff - scan_assoc_id0{+}{ss}
|
||||
swap_ms{shuf{[16]u8, vs, sii}}
|
||||
}
|
||||
mgo := mg - V**(l8 & 3)
|
||||
mgm := (mgo - V**1) & V**3
|
||||
m4s := @collect (i to 3) mgm == V**i
|
||||
|
||||
# Main loop
|
||||
def loop{output} = {
|
||||
xv := *V~~x0
|
||||
rv := *V~~r0
|
||||
# Each iteration handles l vectors from x
|
||||
@for (i to nv<<p2) {
|
||||
# 1 or 3 initial vectors
|
||||
r := load{xv,0}; ++xv
|
||||
if ((l & 2) != 0) {
|
||||
def {m0, _, m2} = m4s
|
||||
re := homBlend{load{xv,0}, load{xv,1}, m2}
|
||||
r = homBlend{re, r, m0}
|
||||
xv += 2
|
||||
}
|
||||
# Then the rest in groups of 4
|
||||
mh := mgo
|
||||
@for (l/4) {
|
||||
{l0, ...ls} := each{load{xv,.}, iota{4}}
|
||||
r4 := fold{homBlend, l0, ls, m4s}
|
||||
r = homBlend{r, r4, mh < V**4}
|
||||
mh -= V**4; xv += 4
|
||||
}
|
||||
def write{r} = {
|
||||
store{rv, 0, shuf{[16]u8, r, si}}
|
||||
++rv
|
||||
}
|
||||
output{r, i, write}
|
||||
}
|
||||
}
|
||||
|
||||
# Handle odd and even strides separately
|
||||
if (p2 == 0) {
|
||||
loop{{r, i, write} => write{swap_ms{r}}}
|
||||
} else {
|
||||
# Store results
|
||||
def add_res = {
|
||||
ra := V**0 # Accumulator
|
||||
i := make{V, iota{vl}%16}
|
||||
o := V**(u8~~1<<el)
|
||||
bl := i & elo < o
|
||||
sh := i - o
|
||||
{r} => { ra = homBlend{shuf{[16]u8, ra, sh}, r, bl} }
|
||||
}
|
||||
il := make{V, iota{vl} % 16}
|
||||
# Shuffle to undo interleaving of add_res
|
||||
def __shr{x:(V), sh if hasarch{'X86_64'}} = V~~(H~~x >> sh)
|
||||
def __shl{x:(V), sh if hasarch{'X86_64'}} = V~~(H~~x << sh)
|
||||
def __shr{x:(V), sh:T if hasarch{'AARCH64'} and not isvec{T}} = x << V**cast_i{u8,-sh}
|
||||
def __shl{x:(V), sh:T if hasarch{'AARCH64'} and not isvec{T}} = x << V**cast_i{u8, sh}
|
||||
def uz_lane = {
|
||||
l := V**(u8~~1<<el - 1)
|
||||
h := V**(16 - u8~~16>>p2)
|
||||
dz := (il & l) | (h &~ il)>>(4-e) # low, high->middle
|
||||
dz |= (il &~ (l | h))<<p2 # middle->high
|
||||
shuf{[16]u8, ., dz}
|
||||
}
|
||||
# Adjust modular permutation to apply after unzipping
|
||||
si = uz_lane{(si & V**(vl - u8~~1<<e)) >> p2}
|
||||
si ^= il &~ V**((16 - u8~~1<<e) >> p2)
|
||||
# Cross-lane follow-up
|
||||
def cross = if (vl == 16) { {x}=>x } else {
|
||||
si = shuf{[4]u64, si, 0,1,0,1}
|
||||
assert{p2 <= 2}
|
||||
cr := make{[8]u32, tr_iota{0,2,1}}
|
||||
if (p2 > 1) cr = make{[8]u32, tr_iota{2,0,1}}
|
||||
shuf{[8]u32, ., cr}
|
||||
}
|
||||
# Run, writing every 1<<p2 steps
|
||||
plo := usz~~1<<p2 - 1
|
||||
loop{{r, i, write} => {
|
||||
def ra = add_res{r}
|
||||
if ((plo &~ i) == 0) write{cross{uz_lane{ra}}}
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
||||
# Select one element out of every l, element width 1<<el bytes
|
||||
# Maximum of n result values, return actual number written
|
||||
fn extract_column(x0:*void, r0:*void, n:usz, l:usz, el:u8) : usz = {
|
||||
n <<= el
|
||||
def vl = arch_defvw / 8
|
||||
def thr = min{vl+2, 20}
|
||||
if ((not has_simd) or n < vl or l > usz~~thr>>el or l<<el >= thr) return{0}
|
||||
nv := n / vl
|
||||
if (has_simd and (l & (l-1)) == 0) {
|
||||
def try_unzip{T, k} = if (k < thr and l == k) {
|
||||
extract_column_pow2{T, x0, r0, nv, k}
|
||||
goto{'ret'}
|
||||
}
|
||||
# 10 loops: i8 2,4,8,16; i16 2,4,8; i32 2,4; i64 2
|
||||
@unroll (ek to 4) if (el == ek) {
|
||||
def T = ty_u{8<<ek}
|
||||
@unroll (p from 1 to 5-ek) try_unzip{T, 1<<p}
|
||||
}
|
||||
return{0}
|
||||
setlabel{'ret'}
|
||||
} else if (hasarch{'SSE4.1'} or hasarch{'AARCH64'}) {
|
||||
extract_column_modperm{x0, r0, nv, l, el, vl}
|
||||
} else return{0}
|
||||
(usz~~vl>>el) * nv
|
||||
}
|
||||
|
||||
|
||||
# Short-row boolean folds: main challenge is bit packing
|
||||
def fold_rows_bit_lt64{
|
||||
op, run_loop2, run_loop4, pext_res, mult_in,
|
||||
@ -184,7 +367,7 @@ def fold_rows_bit_lt64{
|
||||
}
|
||||
}
|
||||
|
||||
fn select_rows_bit_lt64(xp:*u64, rp:*u64, n:usz, l:usz, o:usz) : void = {
|
||||
fn extract_column_bit_lt64(xp:*u64, rp:*u64, n:usz, l:usz, o:usz) : void = {
|
||||
assert{l < 64}; assert{o < l} # Row length, and offset within row
|
||||
def run_loop2{loop} = loop{{a,b} => a>>o}
|
||||
def run_loop4{m, t, loop} = loop{{x} => x<<(l-1-o)}
|
||||
@ -329,4 +512,5 @@ fn or_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, op_and:u1) : void = {
|
||||
}
|
||||
export{'si_xor_rows_bit', xor_rows_bit}
|
||||
export{'si_or_rows_bit', or_rows_bit}
|
||||
export{'si_select_cells_bit_lt64', select_rows_bit_lt64}
|
||||
export{'si_select_cells_bit_lt64', extract_column_bit_lt64}
|
||||
export{'si_select_cells_byte', extract_column}
|
||||
|
||||
@ -693,7 +693,7 @@ def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = {
|
||||
def fixed_loop{k} = {
|
||||
assert{wv == k}
|
||||
while (1) {
|
||||
# e.g. 01234567 to 05316427 on each byte for k==3, ew==8
|
||||
# e.g. 01234567 to 03614725 on each byte for k==3, ew==8
|
||||
xv := get_perm_x{}
|
||||
# Overhang from previous 64-bit elements
|
||||
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
|
||||
|
||||
@ -74,16 +74,7 @@ export{'si_scan_min_i16', scan_idem_id{i16, min}}; export{'si_scan_max_i16', sca
|
||||
export{'si_scan_min_i32', scan_idem_id{i32, min}}; export{'si_scan_max_i32', scan_idem_id{i32, max}}
|
||||
|
||||
# Assumes identity is 0
|
||||
def scan_assoc{op} = {
|
||||
def shl0{v:[_]T, k} = vec_shift_right_128{v, k/width{T}} # Lanewise
|
||||
def shl0{v:V, k==128 if hasarch{'AVX2'}} = {
|
||||
# Broadcast end of lane 0 to entire lane 1
|
||||
l:= V~~make{[8]i32,0,0,0,-1,0,0,0,0} & spread{v}
|
||||
sel{[8]i32, l, make{[8]i32, 3*(3<iota{8})}}
|
||||
}
|
||||
prefix_byshift{op, shl0}
|
||||
}
|
||||
def scan_plus = scan_assoc{+}
|
||||
def scan_plus = scan_assoc_id0{+}
|
||||
|
||||
# Associative scan
|
||||
def scan_assoc_0 = scan_scal
|
||||
|
||||
@ -66,3 +66,13 @@ def make_scan_idem{(f64), op, up} = {
|
||||
sc
|
||||
}
|
||||
def make_scan_idem{T, op} = make_scan_idem{T, op, 1}
|
||||
|
||||
def scan_assoc_id0{op} = {
|
||||
def shl0{v:[_]T, k} = vec_shift_right_128{v, k/width{T}} # Lanewise
|
||||
def shl0{v:V, k==128 if hasarch{'AVX2'}} = {
|
||||
# Broadcast end of lane 0 to entire lane 1
|
||||
l:= V~~make{[8]i32,0,0,0,-1,0,0,0,0} & spread{v}
|
||||
sel{[8]i32, l, make{[8]i32, 3*(3<iota{8})}}
|
||||
}
|
||||
prefix_byshift{op, shl0}
|
||||
}
|
||||
|
||||
@ -4,6 +4,8 @@ include './f64'
|
||||
include './mask'
|
||||
include './bitops'
|
||||
|
||||
def avx2 = hasarch{'AVX2'}
|
||||
|
||||
# Group l (power of 2) elements into paired groups of length o
|
||||
# e.g. pairs{2, iota{8}} = {{0,1,4,5}, {2,3,6,7}}
|
||||
def pairs{o, x} = {
|
||||
@ -22,31 +24,82 @@ def permute_pass{o, x} = {
|
||||
merge{h{0,2}, h{1,3}}
|
||||
}
|
||||
def unpack_to{f, l, x} = {
|
||||
def pass = if (f) permute_pass else unpack_pass
|
||||
def pass = if (avx2 and f) permute_pass else unpack_pass
|
||||
pass{l, if (l==1) x else unpack_to{0, l/2, x}}
|
||||
}
|
||||
# Last pass for square kernel packed in halves
|
||||
def shuf_pass{x} = each{{v} => shuf{[4]i64, v, 0,2,1,3}, x}
|
||||
def halved_pass{n, x} = {
|
||||
if (not avx2) unpack_pass{n/2, x}
|
||||
else each{{v} => shuf{[4]i64, v, 0,2,1,3}, x}
|
||||
}
|
||||
|
||||
# Square kernel where width is a full vector
|
||||
def transpose_square{VT, l, x if hasarch{'AVX2'}} = unpack_to{1, l/2, x}
|
||||
def transpose_square{VT, l, x if avx2} = unpack_to{1, l/2, x}
|
||||
|
||||
def load2{a:*T, b:*T} = pair{load{a}, load{b}}
|
||||
def store2{a:*T, b:*T, v:T2 if w128i{T} and w256{T2}} = {
|
||||
each{{p, i} => store{p, 0, T~~half{v,i}}, tup{a,b}, iota{2}}
|
||||
def load2{a:*T, b:*T} = match (width{T}) {
|
||||
{64} => {
|
||||
def v = each{{p}=>loadLow{*[2]u64~~p, 64}, tup{a,b}}
|
||||
n_d{T}~~zip{...v, 0}
|
||||
}
|
||||
{128} => pair{load{a}, load{b}}
|
||||
}
|
||||
def load_k {VT, src, l, w if w256{VT}} = each{{i} =>load {*VT~~(src+i*w), 0 }, iota{l}}
|
||||
def store_k{VT, dst, x, l, h if w256{VT}} = each{{i,v}=>store{*VT~~(dst+i*h), 0, VT~~v}, iota{l}, x}
|
||||
def load_k {VT, src, l, w if w128{VT}} = each{{i} =>{p:=src+ i*w; load2 {*VT~~p, *VT~~(p+l*w) }}, iota{l}}
|
||||
def store_k{VT, dst, x, l, h if w128{VT}} = each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+ h), v}}, iota{l}, x}
|
||||
def store2{a:*T, b:*T, v:T2 if 2*width{T} == width{T2}} = match (width{T}) {
|
||||
{ 64} => each{{p, v} => storeLow{*u64~~p, 64, [2]u64~~v}, tup{a,b}, tup{v, shuf{u64, v, 1,0}}}
|
||||
{128} => each{{p, i} => store{p, 0, T~~half{v,i}}, tup{a,b}, iota{2}}
|
||||
}
|
||||
def store1of2{a:*T, v:T2 if 2*width{T} == width{T2}} = match (width{T}) {
|
||||
{ 64} => storeLow{*u64~~a, 64, [2]u64~~v}
|
||||
{128} => store{a, 0, T~~half{v,0}}
|
||||
}
|
||||
def load_k {VT, src, l, w} = each{{i} =>load {*VT~~(src+i*w), 0 }, iota{l}}
|
||||
def store_k{VT, dst, x, l, h} = each{{i,v}=>store{*VT~~(dst+i*h), 0, VT~~v}, iota{l}, x}
|
||||
def load_k {VT, src, l, w if width{VT} < arch_defvw} = each{{i} =>{p:=src+ i*w; load2 {*VT~~p, *VT~~(p+l*w) }}, iota{l}}
|
||||
def store_k{VT, dst, x, l, h if width{VT} < arch_defvw} = each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+ h), v}}, iota{l}, x}
|
||||
|
||||
# Transpose kernel of size kw,kh in size w,h array
|
||||
def kernel{src:*T, dst:*T, kw, kh, w, h} = {
|
||||
def n = (kw*kh*width{T}) / 256 # Number of vectors
|
||||
def kernel_part{part_w}{src:*T, dst:*T, kw, kh, w, h} = {
|
||||
def n = (kw*kh*width{T}) / arch_defvw # Number of vectors
|
||||
def xvs = load_k{[kw]T, src, n, w}
|
||||
def xt = unpack_to{n==kh, n/2, xvs} # Transpose n by n
|
||||
def rvs = if (n==kw) xt else shuf_pass{xt} # To kh by kh for packed square
|
||||
store_k{[kh]T, dst, rvs, n, h}
|
||||
def rvs = if (n==kw) xt else halved_pass{n,xt} # To kh by kh for packed square
|
||||
def stores = store_k{[kh]T, ..., h}
|
||||
if (same{part_w, 0}) {
|
||||
stores{dst, rvs, n}
|
||||
} else {
|
||||
# Write w results, kw/2 <= n < kw
|
||||
d := dst
|
||||
def vd = kw / n # Number of writes for each output vector (1 or 2)
|
||||
def store_slice{rv, len} = {
|
||||
stores{d, slice{rv,0,len}, len}
|
||||
d += len*vd*h
|
||||
}
|
||||
store_slice{rvs, n/2} # Unconditionally store first half
|
||||
rt := slice{rvs,n/2} # Remaining tail
|
||||
def wtail{b} = {
|
||||
if ((part_w & (vd*b)) != 0) {
|
||||
store_slice{rt, b}
|
||||
slice{rt,0,b} = slice{rt,b,2*b}
|
||||
}
|
||||
if (b>1) wtail{b/2}
|
||||
}
|
||||
wtail{n/4}
|
||||
if (vd>1 and (part_w & 1) != 0) store1of2{*[kh]T~~d, select{rt,0}}
|
||||
}
|
||||
}
|
||||
def kernel = kernel_part{0}
|
||||
|
||||
def kernel_part_h{part_h}{src:*T==i8, dst:*T, kw==16, kh==16, w, h} = {
|
||||
def n = (kw*kh*width{T}) / arch_defvw
|
||||
def VT = [kw]T
|
||||
off := part_h - kh/2
|
||||
def xvs = @unroll (i to n) { s := src + i*w; load2{*VT~~s, *VT~~(s+off*w)} }
|
||||
def rvs = halved_pass{n, unpack_to{0, n/2, xvs}}
|
||||
@unroll (j to 2) {
|
||||
def is = 2*iota{2} + j
|
||||
d := dst + j*off
|
||||
def store_q{v, i} = { store{*u64~~d, 0, extract{v,i}}; d += h }
|
||||
each{{r} => each{store_q{[4]u64~~r,.}, is}, rvs}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -153,37 +206,264 @@ def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h, ws, hs} = {
|
||||
}
|
||||
}
|
||||
|
||||
# Interleave n values of type T from x0 and x1 into r
|
||||
# Unzip 2*n values
|
||||
def uninterleave{r0:*T, r1:*T, xp:*T, n} = {
|
||||
@for (r0, r1 over i to n) {
|
||||
r0 = load{xp, i*2}; r1 = load{xp, i*2+1}
|
||||
}
|
||||
}
|
||||
# Zip n values of type T from each of x0 and x1 into r
|
||||
fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = {
|
||||
rp := *T~~r0
|
||||
@for (x0 in *T~~x0, x1 in *T~~x1 over i to n) {
|
||||
store{rp, i*2, x0}; store{rp, i*2+1, x1}
|
||||
}
|
||||
}
|
||||
|
||||
fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
|
||||
# Scalar transpose defined in C
|
||||
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
|
||||
def call_base{...a} = emit{void, merge{'transpose_',ts}, ...a, ws, hs}
|
||||
|
||||
rp:*T = *T~~r0
|
||||
xp:*T = *T~~x0
|
||||
if (hasarch{'AVX2'} and w>=k and h>=k) {
|
||||
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
|
||||
} else {
|
||||
if (h==2 and h==hs) interleave{T}(r0, x0, *void~~(xp+ws), w)
|
||||
else if (w==2 and w==ws) @for (r0 in rp, r1 in rp+hs over i to h) { r0 = load{xp, i*2}; r1 = load{xp, i*2+1} }
|
||||
else call_base{rp, xp, w, h}
|
||||
# SIMD implementations
|
||||
def uninterleave{r0:*T, r1:*T, xp:*T, n if has_simd and (not hasarch{'X86_64'} or width{T}>=32 or hasarch{'SSSE3'})} = {
|
||||
def l = arch_defvw / width{T}
|
||||
def V = [l]T
|
||||
rv0 := *V~~r0; rv1 := *V~~r1; xv := *V~~xp
|
||||
nv := n / l
|
||||
def uz = if (not hasarch{'X86_64'}) unzip else ({...xs} => {
|
||||
def reinterpret{V, xs if ktup{xs}} = each{~~{V,.}, xs}
|
||||
def q = tr_quads{arch_defvw/128}
|
||||
def k = flat_table{+, iota{2}, 2 * iota{64 / width{T}}}
|
||||
def px = each{shuf{., k}, xs}
|
||||
V~~each{q, zip128{...re_el{u64,V}~~px}}
|
||||
})
|
||||
@for (r0 in rv0, r1 in rv1 over i to nv) {
|
||||
tup{r0, r1} = uz{...each{load{xv+2*i, .}, iota{2}}}
|
||||
}
|
||||
if (n % l > 0) {
|
||||
xb := xv + 2*nv
|
||||
x0 := load{xb}
|
||||
x1 := V**0; if (n&(l/2) != 0) x1 = load{xb, 1}
|
||||
mask := maskOf{V, n%l}
|
||||
each{homMaskStoreF{., mask, .}, tup{rv0+nv,rv1+nv}, uz{x0,x1}}
|
||||
}
|
||||
}
|
||||
fn interleave{T if has_simd}(r:*void, x0:*void, x1:*void, n:u64) : void = {
|
||||
def l = arch_defvw / width{T}
|
||||
def V = [l]T
|
||||
xv0 := *V~~x0; xv1 := *V~~x1; rv := *V~~r
|
||||
nv := n / l
|
||||
def q = tr_quads{arch_defvw/128}
|
||||
@for (x0 in xv0, x1 in xv1 over i to nv) {
|
||||
each{store{rv+2*i, ., .}, iota{2}, zip128{q{x0},q{x1}}}
|
||||
}
|
||||
if (n % l > 0) {
|
||||
def xs = each{load{.,nv}, tup{xv0,xv1}}
|
||||
def get_r = zip128{...each{q, xs}, .}
|
||||
r0 := get_r{0}
|
||||
rb := rv + 2*nv
|
||||
nr := 2*n; m := nr%l
|
||||
mask := maskOf{V, m}
|
||||
if (nr&l == 0) {
|
||||
homMaskStoreF{rb, mask, r0}
|
||||
} else {
|
||||
store{rb, 0, r0}
|
||||
if (m > 0) homMaskStoreF{rb+1, mask, get_r{1}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def transpose{T, k} = transpose{T, k, k}
|
||||
# Utilities for kernels based on modular permutation
|
||||
def rotcol{xs, mg:I} = {
|
||||
def w = length{xs}
|
||||
@unroll (kl to ceil_log2{w}) { def k = 1<<kl
|
||||
def vk = I**k; def m = (mg & vk) == vk
|
||||
def bl{x,y} = { x = homBlend{x,y,m} }
|
||||
x0 := select{xs, 0}
|
||||
def xord = select{xs, k*iota{w} % w}
|
||||
each{bl, xord, shiftleft{xord, tup{x0}}}
|
||||
}
|
||||
}
|
||||
def get_modperm_lane_shuf{c} = {
|
||||
def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c}
|
||||
{x, i} => cross{shuf{16, x, i}, i}
|
||||
}
|
||||
def tr_quads = match { {1}=>({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} }
|
||||
|
||||
exportT{'simd_transpose', tup{
|
||||
transpose{i8 , 16},
|
||||
transpose{i16, 8, 16},
|
||||
transpose{i32, 8},
|
||||
transpose{i64, 4}
|
||||
}}
|
||||
def loop_fixed_height{xp, rp, w, k, st, kern} = {
|
||||
@for_mult_max{k, w-k} (i to w+(-w)%k) kern{xp+i, rp+i*st}
|
||||
}
|
||||
def loop_fixed_width{xp, rp, h, k, st, kern} = {
|
||||
@for_mult_max{k, h-k} (i to h+(-h)%k) kern{xp+i*st, rp+i}
|
||||
}
|
||||
|
||||
exportT{'interleave_fns', each{interleave, tup{i8, i16, i32, i64}}}
|
||||
# Transpose a contiguous kernel of width w*p from x to r with stride rst
|
||||
def modular_kernel{w,p if w%2==1 and 2%p==0}{xp:*T, rp0:*T, rst:(u64)} = {
|
||||
def h = arch_defvw / 8
|
||||
def ih = iota{h}; def iw = iota{w}
|
||||
def I = [h]u8; def V = I
|
||||
def e = width{T} / 8
|
||||
# Load a shape h,w slice of x, but consider as shape w,h
|
||||
def xsp = each{load{*V~~xp, .}, iw}
|
||||
# Modular permutation of (reshaped argument) columns
|
||||
xs := select{xsp, find_index{h/e*iw % w, iw}}
|
||||
# Rotate each column by its index
|
||||
rotcol{reverse{xs}, make{I, (ih // e) % w}}
|
||||
# Modular permutation of rows, and write to result
|
||||
rp := rp0
|
||||
mp := make{I, ih%e + (p*(ih - ih%e) + (ih//(h/p))*e)%16*w%h}
|
||||
mi := I**e
|
||||
def perm = if (h==16) shuf else {
|
||||
def sh = get_modperm_lane_shuf{I**16}
|
||||
def q = tr_quads{p}; if (p>1) mp = q{mp}
|
||||
{x, i} => q{sh{x, i}}
|
||||
}
|
||||
def perm_store{x} = {
|
||||
match (p) {
|
||||
{1} => store{*V~~rp, 0, perm{x, mp}}
|
||||
{2} => { def U = [h/2]u8; store2{*U~~rp, *U~~(rp+w*rst), perm{x, mp}} }
|
||||
}
|
||||
rp += rst; mp += mi
|
||||
if (hasarch{'AARCH64'}) mp &= I**15 # Implicit on x86, value stays below h+w
|
||||
}
|
||||
each{perm_store, xs}
|
||||
}
|
||||
def modular_kernel{2,2}{xp:*T, rp0:*T, rst:(u64)} = {
|
||||
def h = arch_defvw / 8
|
||||
def V = [h]u8; def U = n_h{V}
|
||||
def ih = iota{h}%16; def e = width{T} / 8
|
||||
# Permutation to unzip by 4 within each lane
|
||||
uz := make{V, (4*(ih//e) + ih//(16/4))*e%16 + ih%e}
|
||||
# Unzipping code for the resulting 4-byte units
|
||||
def {st, proc, zipx} = match (h) {
|
||||
{16} => tup{2, {x} => [4]f32~~x, {xs,i} => V~~zip128{...xs,i}}
|
||||
{32} => tup{1, shuf{[4]u64, ., 0,2,1,3}, {xs,i} => shuf{[4]f32, xs, i + tup{0,2,0,2}}}
|
||||
}
|
||||
def xsp = each{load{*V~~xp, .}, iota{2}}
|
||||
xs := each{proc, each{shuf{16, ., uz}, xsp}}
|
||||
@unroll (i to 2) {
|
||||
rp := rp0 + st*i*rst
|
||||
store2{*U~~rp, *U~~(rp + (2/st)*rst), zipx{xs,i}}
|
||||
}
|
||||
}
|
||||
def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = {
|
||||
def p = if (wk%2) 1 else 2; def w = wk/p
|
||||
def vl = arch_defvw / (p*width{T})
|
||||
loop_fixed_width{xp, rp, h, vl, wk, modular_kernel{w,p}{., ., hs}}
|
||||
}
|
||||
def transpose_fixed_width{rp:*T, xp:*T, 2, h, hs} = {
|
||||
uninterleave{rp, rp+hs, xp, h}
|
||||
}
|
||||
|
||||
# Transpose a kernel of height w*p from x with stride xst to contiguous r
|
||||
# w and h are named for the result, not argument, to match modular_kernel
|
||||
def modular_kernel_rev{w,p if w%2==1 and 2%p==0}{xp0:*T, rp:*T, xst:(u64)} = {
|
||||
def h = arch_defvw / 8
|
||||
def ih = iota{h}; def iw = iota{w}
|
||||
def I = [h]u8; def V = I
|
||||
def e = width{T} / 8
|
||||
# Read rows, modular permutation on each
|
||||
def rotbit{x, l,m,h} = x%l + (x-x%l)*(h/m)%h + x//m*l
|
||||
def wi = w + 2 * ((w-1) + (w&2)) # Inverse mod 32
|
||||
def mpd = rotbit{ih%e + (ih - ih%e)%16*wi%h, e,e*p,h}
|
||||
mp := make{I, if (h==16 or p==1) mpd else rotbit{mpd, 8,16,h}}
|
||||
def rot_mp = {
|
||||
def rot_lane = shuf{16, ., make{I, (ih-e)%16}}
|
||||
def cross = if (h==16) ({x}=>x) else ^{make{I, 16*(ih%16<e)}, .}
|
||||
{mp} => cross{rot_lane{mp}}
|
||||
}
|
||||
def perm = if (h==16) shuf else {
|
||||
def sh = get_modperm_lane_shuf{I**16}
|
||||
def q = tr_quads{p}
|
||||
{x, i} => sh{q{x}, i}
|
||||
}
|
||||
xp := xp0
|
||||
xs := @collect (w) {
|
||||
x := match (p) {
|
||||
{1} => perm{load{*V~~xp, 0}, mp}
|
||||
{2} => { def U = [h/2]u8; perm{load2{*U~~xp, *U~~(xp+w*xst)}, mp} }
|
||||
}
|
||||
xp += xst; mp = rot_mp{mp}
|
||||
x
|
||||
}
|
||||
# Rotate each column by its index
|
||||
rotcol{xs, make{I, (ih // e) % w}}
|
||||
# Permute vectors and store
|
||||
each{store{*V~~rp, ., .}, iw, select{xs, h/e*iw % w}}
|
||||
}
|
||||
def modular_kernel_rev{2,2}{xp0:*T, rp:*T, xst:(u64)} = {
|
||||
def V = [arch_defvw / width{T}]T; def U = n_h{V}
|
||||
xl := @unroll (i to 2) {
|
||||
xp := xp0 + i*xst
|
||||
x := load2{*U~~xp, *U~~(xp + 2*xst)}
|
||||
if (arch_defvw==128) x else shuf{[8]u32, x, tr_iota{1,2,0}}
|
||||
}
|
||||
xs := unpack_typed{...unpack_typed{...xl}}
|
||||
each{store{*V~~rp, ., .}, iota{2}, each{~~{V,.},xs}}
|
||||
}
|
||||
def transpose_fixed_height{rp:*T, xp:*T, w, ws, hk} = {
|
||||
def p = if (hk%2) 1 else 2; def h = hk/p
|
||||
def vl = arch_defvw / (p*width{T})
|
||||
loop_fixed_height{xp, rp, w, vl, hk, modular_kernel_rev{h,p}{., ., ws}}
|
||||
}
|
||||
def transpose_fixed_height{rp:*T, xp:*T, w, ws, 2} = {
|
||||
interleave{T}(*void~~rp, *void~~xp, *void~~(xp+ws), w)
|
||||
}
|
||||
|
||||
fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
|
||||
rp:*T = *T~~r0
|
||||
xp:*T = *T~~x0
|
||||
def wT = width{T}
|
||||
def vl = arch_defvw / wT
|
||||
# Transposes with code dedicated to a particular width or height
|
||||
def try_fixed_dim{tr, l, lst, nl, l_max} = {
|
||||
def incl{l} = if (k>4) 1 else l!=4
|
||||
if (l<l_max and incl{l} and l==lst) {
|
||||
if (l == 2) { tr{2}; return{} }
|
||||
def has_blend = hasarch{'SSE4.1'} or hasarch{'AARCH64'}
|
||||
if (has_blend and nl>=vl/2 and (l%2==0 or nl>=vl)) {
|
||||
def try{ls} = {
|
||||
def i = length{ls}>>1
|
||||
if (l < select{ls,i}) try{slice{ls, 0,i}} else try{slice{ls, i}}
|
||||
}
|
||||
def try{{lk}} = tr{lk}
|
||||
try{replicate{{i} => i<l_max and incl{i}, slice{iota{8},3}}}
|
||||
return{}
|
||||
}
|
||||
}
|
||||
}
|
||||
# Small width: fixed, or over-reading partial kernel
|
||||
def use_part_w = wT<=16 and k>max{4,8/wT}
|
||||
def w_max = 1 + (if (use_part_w) max{4, k/2-1} else 7)
|
||||
if (has_simd and w < max{w_max, k}) {
|
||||
try_fixed_dim{transpose_fixed_width {rp, xp, ., h, hs}, w, ws, h, w_max}
|
||||
if (use_part_w and h>=kh and w>=k/2 and w<k) {
|
||||
loop_fixed_width{xp, rp, h, kh, ws, kernel_part{w}{., ., k, kh, ws, hs}}
|
||||
return{}
|
||||
}
|
||||
}
|
||||
# Small height: fixed, or kernel with overlapping writes
|
||||
# Overlapping is slower than over-reading, so it's only used when needed
|
||||
if (has_simd and h < max{8, k}) {
|
||||
try_fixed_dim{transpose_fixed_height{rp, xp, w, ws, .}, h, hs, w, 8}
|
||||
if (k>8 and w>=k and h>=kh/2) {
|
||||
loop_fixed_height{xp, rp, w, k, hs, kernel_part_h{h}{., ., k, kh, ws, hs}}
|
||||
return{}
|
||||
}
|
||||
}
|
||||
# Scalar transpose defined in C
|
||||
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
|
||||
def call_base{...a} = emit{void, merge{'transpose_',ts}, ...a, ws, hs}
|
||||
# Full kernels
|
||||
# May have w<k or h<k if not has_blend, or >2D transpose with w!=ws or h!=hs
|
||||
if (has_simd and k!=0 and w>=k and h>=k) {
|
||||
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
|
||||
} else {
|
||||
call_base{rp, xp, w, h}
|
||||
}
|
||||
}
|
||||
|
||||
def transpose{T, k if knum{k}} = transpose{T, tup{k, k}}
|
||||
|
||||
def tr_types = tup{i8, i16, i32, i64}
|
||||
def tr_kernels = if (not avx2) tup{ 8, 8, 4, 0 }
|
||||
else tup{16, tup{8, 16}, 8, 4 }
|
||||
|
||||
exportT{'simd_transpose', each{transpose, tr_types, tr_kernels}}
|
||||
|
||||
exportT{'interleave_fns', each{interleave, tr_types}}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user