Merge pull request #123 from mlochbaum/modperm
CPU-sized select single column and transpose with modular permutations
This commit is contained in:
commit
163853439e
@ -1 +1 @@
|
|||||||
Subproject commit d432cb710911457169da5342d27ce8adffd5dd1a
|
Subproject commit 17c512727dbcb6d58a2adadac4661bc9c43920d2
|
||||||
@ -219,8 +219,8 @@ NOINLINE B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr) { //
|
|||||||
|
|
||||||
|
|
||||||
// fast special-case implementations
|
// fast special-case implementations
|
||||||
extern void (*const si_select_cells_bit_lt64)(u64*,u64*,usz,usz,usz); // from fold.c (fold.singeli)
|
B select_cells_single(usz ind, B x, usz cam, usz l, usz csz, bool leaf); // from select.c
|
||||||
static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind {leaf? <∘⊑; ⊏}⎉¯k x; TODO probably can share some parts with takedrop_highrank and/or call ⊏?
|
static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind {leaf? <∘⊑; ⊏}⎉¯k x
|
||||||
ur xr = RNK(x);
|
ur xr = RNK(x);
|
||||||
assert(xr>1 && k<xr);
|
assert(xr>1 && k<xr);
|
||||||
usz* xsh = SH(x);
|
usz* xsh = SH(x);
|
||||||
@ -228,58 +228,15 @@ static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind
|
|||||||
usz l = xsh[k];
|
usz l = xsh[k];
|
||||||
assert(0<=ind && ind<l);
|
assert(0<=ind && ind<l);
|
||||||
assert(cam*l*csz == IA(x));
|
assert(cam*l*csz == IA(x));
|
||||||
Arr* ra;
|
B r = select_cells_single(ind, x, cam, l, csz, leaf);
|
||||||
usz take = leaf? 1 : csz;
|
Arr* ra = a(r);
|
||||||
if (l==1 && take==csz) {
|
|
||||||
ra = cpyWithShape(incG(x));
|
|
||||||
arr_shErase(ra, 1);
|
|
||||||
} else {
|
|
||||||
u8 xe = TI(x,elType);
|
|
||||||
u8 ewl= elwBitLog(xe);
|
|
||||||
u8 xl = leaf? ewl : multWidthLog(csz, ewl);
|
|
||||||
usz ria = cam*take;
|
|
||||||
if (xl>=7 || (xl<3 && xl>0)) { // generic case
|
|
||||||
MAKE_MUT_INIT(rm, ria, TI(x,elType)); MUTG_INIT(rm);
|
|
||||||
usz jump = l * csz;
|
|
||||||
usz xi = take*ind;
|
|
||||||
usz ri = 0;
|
|
||||||
for (usz i = 0; i < cam; i++) {
|
|
||||||
mut_copyG(rm, ri, x, xi, take);
|
|
||||||
xi+= jump;
|
|
||||||
ri+= take;
|
|
||||||
}
|
|
||||||
ra = mut_fp(rm);
|
|
||||||
} else if (xe==el_B) {
|
|
||||||
assert(take == 1);
|
|
||||||
SGet(x)
|
|
||||||
HArr_p rp = m_harrUv(ria);
|
|
||||||
for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*l+ind);
|
|
||||||
NOGC_E; ra = (Arr*)rp.c;
|
|
||||||
} else {
|
|
||||||
void* rp = m_tyarrlbp(&ra, ewl, ria, el2t(xe));
|
|
||||||
void* xp = tyany_ptr(x);
|
|
||||||
switch(xl) {
|
|
||||||
case 0:
|
|
||||||
#if SINGELI
|
|
||||||
if (l < 64) si_select_cells_bit_lt64(xp, rp, cam, l, ind);
|
|
||||||
else
|
|
||||||
#endif
|
|
||||||
for (usz i=0; i<cam; i++) bitp_set(rp, i, bitp_get(xp, i*l+ind));
|
|
||||||
break;
|
|
||||||
case 3: PLAINLOOP for (usz i=0; i<cam; i++) ((u8* )rp)[i] = ((u8* )xp)[i*l+ind]; break;
|
|
||||||
case 4: PLAINLOOP for (usz i=0; i<cam; i++) ((u16*)rp)[i] = ((u16*)xp)[i*l+ind]; break;
|
|
||||||
case 5: PLAINLOOP for (usz i=0; i<cam; i++) ((u32*)rp)[i] = ((u32*)xp)[i*l+ind]; break;
|
|
||||||
case 6: PLAINLOOP for (usz i=0; i<cam; i++) ((f64*)rp)[i] = ((f64*)xp)[i*l+ind]; break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
usz* rsh = arr_shAlloc(ra, leaf? k : xr-1);
|
usz* rsh = arr_shAlloc(ra, leaf? k : xr-1);
|
||||||
if (rsh) {
|
if (rsh) {
|
||||||
shcpy(rsh, xsh, k);
|
shcpy(rsh, xsh, k);
|
||||||
if (!leaf) shcpy(rsh+k, xsh+k+1, xr-1-k);
|
if (!leaf) shcpy(rsh+k, xsh+k+1, xr-1-k);
|
||||||
}
|
}
|
||||||
decG(x);
|
decG(x);
|
||||||
return taga(ra);
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void set_column_typed(void* rp, B v, u8 e, ux p, ux stride, ux n) { // may write to all elements 0 ≤ i < stride×n, and after that too for masked stores
|
static void set_column_typed(void* rp, B v, u8 e, ux p, ux stride, ux n) { // may write to all elements 0 ≤ i < stride×n, and after that too for masked stores
|
||||||
|
|||||||
@ -40,8 +40,15 @@
|
|||||||
// Sparse initialization if 𝕨 is much smaller than 𝕩
|
// Sparse initialization if 𝕨 is much smaller than 𝕩
|
||||||
// COULD call Mark Firsts (∊) for very short 𝕨 to avoid allocation
|
// COULD call Mark Firsts (∊) for very short 𝕨 to avoid allocation
|
||||||
|
|
||||||
// Select Cells - inds⊸⊏⎉1 x
|
// Select Cells - inds⊸⊏⎉1 𝕩
|
||||||
// Squeeze indices if too wide for given x
|
// Squeeze indices if too wide for given 𝕩
|
||||||
|
// Single index: (also used for monadic ⊏˘ ⊣˝˘ ⊢˝˘)
|
||||||
|
// Selecting a column of bits:
|
||||||
|
// Row size <64: extract as with fold-cells
|
||||||
|
// Selecting a column of 1, 2, 4, or 8-byte elements:
|
||||||
|
// Short cells: pack vectors from 𝕩, or blend and permute
|
||||||
|
// Long cells: dedicated scalar loop for each type
|
||||||
|
// Otherwise, loop with mutable copy
|
||||||
// Boolean indices:
|
// Boolean indices:
|
||||||
// Short inds and short cells: Widen to i8
|
// Short inds and short cells: Widen to i8
|
||||||
// Otherwise: bitsel call per cell
|
// Otherwise: bitsel call per cell
|
||||||
@ -57,7 +64,7 @@
|
|||||||
// COULD generate full list of indices via arith
|
// COULD generate full list of indices via arith
|
||||||
// 1-element cells: use (≠inds)/⥊x after checking ∧´inds∊0‿¯1
|
// 1-element cells: use (≠inds)/⥊x after checking ∧´inds∊0‿¯1
|
||||||
// Used for ⌽⎉1
|
// Used for ⌽⎉1
|
||||||
// SHOULD use for atom⊸⊏⎉k, /⎉k, ⌽⎉k, ↑⎉k, ↓⎉k, ↕⎉k, ⍉⎉k, probably more
|
// SHOULD use for /⎉k, ⌽⎉k, ↑⎉k, ↓⎉k, ↕⎉k, ⍉⎉k, probably more
|
||||||
|
|
||||||
#include "../core.h"
|
#include "../core.h"
|
||||||
#include "../utils/talloc.h"
|
#include "../utils/talloc.h"
|
||||||
@ -575,6 +582,62 @@ static void* m_arrv_same(B* r, usz ia, B src) { // makes a new array with same e
|
|||||||
|
|
||||||
B slash_c2(B, B, B);
|
B slash_c2(B, B, B);
|
||||||
B select_cells_base(B inds, B x0, ux csz, ux cam);
|
B select_cells_base(B inds, B x0, ux csz, ux cam);
|
||||||
|
extern void (*const si_select_cells_bit_lt64)(u64*,u64*,usz,usz,usz); // from fold.c (fold.singeli)
|
||||||
|
extern usz (*const si_select_cells_byte)(void*,void*,usz,usz,u8);
|
||||||
|
|
||||||
|
B select_cells_single(usz ind, B x, usz cam, usz l, usz csz, bool leaf) { // ⥊ ind {leaf? <∘⊑; ⊏}˘ cam‿l‿csz ⥊ x
|
||||||
|
usz take = leaf? 1 : csz;
|
||||||
|
Arr* ra;
|
||||||
|
if (l==1 && take==csz) {
|
||||||
|
ra = cpyWithShape(incG(x));
|
||||||
|
arr_shErase(ra, 1);
|
||||||
|
} else {
|
||||||
|
u8 xe = TI(x,elType);
|
||||||
|
u8 ewl= elwBitLog(xe);
|
||||||
|
u8 xl = leaf? ewl : multWidthLog(csz, ewl);
|
||||||
|
usz ria = cam*take;
|
||||||
|
if (xl>=7 || (xl<3 && xl>0)) { // generic case
|
||||||
|
MAKE_MUT_INIT(rm, ria, TI(x,elType)); MUTG_INIT(rm);
|
||||||
|
usz jump = l * csz;
|
||||||
|
usz xi = take*ind;
|
||||||
|
usz ri = 0;
|
||||||
|
for (usz i = 0; i < cam; i++) {
|
||||||
|
mut_copyG(rm, ri, x, xi, take);
|
||||||
|
xi+= jump;
|
||||||
|
ri+= take;
|
||||||
|
}
|
||||||
|
ra = mut_fp(rm);
|
||||||
|
} else if (xe==el_B) {
|
||||||
|
assert(take == 1);
|
||||||
|
SGet(x)
|
||||||
|
HArr_p rp = m_harrUv(ria);
|
||||||
|
for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*l+ind);
|
||||||
|
NOGC_E; ra = (Arr*)rp.c;
|
||||||
|
} else {
|
||||||
|
void* rp = m_tyarrlbp(&ra, ewl, ria, el2t(xe));
|
||||||
|
void* xp = tyany_ptr(x);
|
||||||
|
if (xl == 0) {
|
||||||
|
#if SINGELI
|
||||||
|
if (l < 64) si_select_cells_bit_lt64(xp, rp, cam, l, ind);
|
||||||
|
else
|
||||||
|
#endif
|
||||||
|
for (usz i=0; i<cam; i++) bitp_set(rp, i, bitp_get(xp, i*l+ind));
|
||||||
|
} else {
|
||||||
|
usz i0 = 0;
|
||||||
|
#if SINGELI
|
||||||
|
i0 = si_select_cells_byte((u8*)xp + (ind<<(xl-3)), rp, cam, l, xl-3);
|
||||||
|
#endif
|
||||||
|
switch(xl) { default: UD;
|
||||||
|
case 3: PLAINLOOP for (usz i=i0; i<cam; i++) ((u8* )rp)[i] = ((u8* )xp)[i*l+ind]; break;
|
||||||
|
case 4: PLAINLOOP for (usz i=i0; i<cam; i++) ((u16*)rp)[i] = ((u16*)xp)[i*l+ind]; break;
|
||||||
|
case 5: PLAINLOOP for (usz i=i0; i<cam; i++) ((u32*)rp)[i] = ((u32*)xp)[i*l+ind]; break;
|
||||||
|
case 6: PLAINLOOP for (usz i=i0; i<cam; i++) ((f64*)rp)[i] = ((f64*)xp)[i*l+ind]; break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return taga(ra);
|
||||||
|
}
|
||||||
|
|
||||||
#define CLZC(X) (64-(CLZ((u64)(X))))
|
#define CLZC(X) (64-(CLZ((u64)(X))))
|
||||||
|
|
||||||
@ -851,6 +914,11 @@ B select_rows_B(B x, ux csz, ux cam, B inds) { // consumes inds,x; ⥊ inds⊸
|
|||||||
ux in = IA(inds);
|
ux in = IA(inds);
|
||||||
if (in == 0) return taga(emptyArr(x, 1));
|
if (in == 0) return taga(emptyArr(x, 1));
|
||||||
u8 ie = TI(inds,elType);
|
u8 ie = TI(inds,elType);
|
||||||
|
if (in == 1) {
|
||||||
|
B w = IGetU(inds,0); if (!isF64(w)) goto generic;
|
||||||
|
B r = select_cells_single(WRAP(o2i64(w), csz, thrF("⊏: Indexing out-of-bounds (%R∊𝕨, %s≡≠𝕩)", w, csz)), x, cam, csz, 1, false);
|
||||||
|
decG(x); decG(inds); return r;
|
||||||
|
}
|
||||||
if (csz<=2? ie!=el_bit : csz<=128? ie>el_i8 : !elInt(ie)) {
|
if (csz<=2? ie!=el_bit : csz<=128? ie>el_i8 : !elInt(ie)) {
|
||||||
inds = num_squeeze(inds);
|
inds = num_squeeze(inds);
|
||||||
ie = TI(inds,elType);
|
ie = TI(inds,elType);
|
||||||
|
|||||||
@ -1,18 +1,22 @@
|
|||||||
// Transpose and Reorder Axes (⍉)
|
// Transpose and Reorder Axes (⍉)
|
||||||
|
|
||||||
// Transpose
|
// Transpose
|
||||||
// One length-2 axis: dedicated code
|
|
||||||
// Boolean: pdep or emulation for height 2; pext for width 2
|
|
||||||
// SHOULD use a generic implementation if BMI2 not present
|
|
||||||
// SHOULD optimize other short lengths with pdep/pext and shuffles
|
|
||||||
// Boolean 𝕩: convert to integer
|
// Boolean 𝕩: convert to integer
|
||||||
|
// pdep or emulation for height 2; pext for width 2
|
||||||
|
// SHOULD switch to shuffles and modular permutation
|
||||||
// SHOULD have bit matrix transpose kernel
|
// SHOULD have bit matrix transpose kernel
|
||||||
// CPU sizes: native or SIMD code
|
// CPU sizes: native or SIMD code
|
||||||
// Large SIMD kernels used when they fit, overlapping for odd sizes
|
// Large SIMD kernels used when they fit, overlapping for odd sizes
|
||||||
// i8: 16×16; i16: 16×8; i32: 8×8; f64: 4×4
|
// SSE, NEON i8: 8×8 ; i16: 8×8; i32: 4×4; f64: scalar
|
||||||
// COULD use half-width or smaller kernels to improve odd sizes
|
// AVX2 i8: 16×16; i16: 16×8; i32: 8×8; f64: 4×4
|
||||||
// Scalar transpose or loop used for overhang of 1
|
// COULD use half-width or smaller kernels to improve odd sizes
|
||||||
// SHOULD add NEON
|
// Scalar transpose or loop used for overhang of 1
|
||||||
|
// Partial kernels for at least half-kernel height/width
|
||||||
|
// Short width: over-read, then skip some writes
|
||||||
|
// Short height, i8 AVX2 only: overlap input rows and output words
|
||||||
|
// Smaller heights and widths: dedicated kernels
|
||||||
|
// Zipping, unzipping, shuffling for powers of 2
|
||||||
|
// Modular permutation for odd numbers, possibly times 2
|
||||||
|
|
||||||
// Reorder Axes
|
// Reorder Axes
|
||||||
// If 𝕨 indicates the identity permutation, return 𝕩
|
// If 𝕨 indicates the identity permutation, return 𝕩
|
||||||
|
|||||||
@ -74,6 +74,189 @@ fn fold_assoc_0{T==f64, op if has_simd}(x:*T, len:u64) : T = {
|
|||||||
export{'si_sum_f64', fold_assoc_0{f64,+}}
|
export{'si_sum_f64', fold_assoc_0{f64,+}}
|
||||||
|
|
||||||
|
|
||||||
|
def extract_column_pow2{T, x0, r0, nv, k} = {
|
||||||
|
def V = [arch_defvw / width{T}]T
|
||||||
|
xv := *V~~x0
|
||||||
|
@for (r in *V~~r0 over i to nv) {
|
||||||
|
xs := each{load{xv, .}, iota{k}}
|
||||||
|
def unzip0{w} = if (not hasarch{'X86_64'}) {
|
||||||
|
unzip{..., 0} # Sane instruction set
|
||||||
|
} else {
|
||||||
|
if (w <= 16) {
|
||||||
|
# Pack instructions
|
||||||
|
m := make{V, - (iota{vcount{V}}%k == 0)}
|
||||||
|
xs = each{&{m, .}, xs} # Mask off high bits
|
||||||
|
def D = el_m{V}
|
||||||
|
{a, b} => packQ{D~~a, D~~b}
|
||||||
|
} else {
|
||||||
|
# Two-vector shuffles
|
||||||
|
# Could also be used for 1/2-byte with ending gap >= 4 bytes,
|
||||||
|
# less instructions but it doesn't seem faster
|
||||||
|
def c = 128/w
|
||||||
|
def sh = shuf{[c]ty_f{w}, ., 2*iota{c} % c}
|
||||||
|
{...ab} => sh{ab}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (not hasarch{'X86_64'} or T != u16 or hasarch{'SSE4.1'}) {
|
||||||
|
r = tree_fold{unzip0{width{T}}, xs}
|
||||||
|
} else {
|
||||||
|
# No unsigned saturation: sign-extend then use unsigned
|
||||||
|
def D = [4]i32
|
||||||
|
def f = tree_fold{unzip0{32}, .}
|
||||||
|
def proc{hx} = {
|
||||||
|
ri := D~~f{hx}
|
||||||
|
top := D**(1<<15); m := D**(1<<16 - 1)
|
||||||
|
(ri & m) | (D~~(ri&top == top) &~ m)
|
||||||
|
}
|
||||||
|
r = V~~packQ{...each{proc, split{k/2, xs}}}
|
||||||
|
}
|
||||||
|
if (width{V} > 128) { # Lane axis wasn't packed, need to shuffle to bottom
|
||||||
|
def tr{E,a, r} = shuf{[1<<a]E, r, tr_iota{shiftright{a-1, iota{a}}}}
|
||||||
|
def lc = k > 4
|
||||||
|
if (lc) r = tr{u64,2, r}
|
||||||
|
r = tr{ty_u{128/k}, lb{k} + (not lc), r}
|
||||||
|
}
|
||||||
|
xv += k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def extract_column_modperm{x0, r0, nv, l, el, vl} = {
|
||||||
|
# Build modular permutations
|
||||||
|
def V = [vl]u8
|
||||||
|
def H = [vl/2]u16
|
||||||
|
p2 := ctz{l}; l >>= p2 # Decompose into l<<p2 with odd l
|
||||||
|
e := p2 + promote{ux,el} # Absorb into element size for most computation
|
||||||
|
l8 := cast_i{u8, l}
|
||||||
|
li := cast_i{u8, l + 2 * ((l-1) + (l&2))} # Inverse mod vl
|
||||||
|
elo:= V**(u8~~1<<e - 1)
|
||||||
|
ie := iota{V} & elo
|
||||||
|
kmul := make{H, 2*iota{vl/2}} &~ H~~elo
|
||||||
|
def mu16{k} = {
|
||||||
|
k16 := H ** k
|
||||||
|
prd := shuf{V~~(kmul * k16), 0,0}
|
||||||
|
if (e == 0) prd += V~~(k16 << 8)
|
||||||
|
(prd & V**(vl-1)) + ie
|
||||||
|
}
|
||||||
|
si := mu16{l8}
|
||||||
|
sii := mu16{li}
|
||||||
|
def swap_ms = if (vl == 16) ({x}=>x) else {
|
||||||
|
ms := (V**16 & sii) == (V**16 &~ iota{V})
|
||||||
|
{x} => homBlend{x, shuf{[4]u64, x, 2,3,0,1}, ms}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Blend masks
|
||||||
|
def mg = { # Iteration i should select where mg == V**i
|
||||||
|
ss := (si < V**(l8<<e)) & (ie == V**0)
|
||||||
|
vs := V**0xff - scan_assoc_id0{+}{ss}
|
||||||
|
swap_ms{shuf{[16]u8, vs, sii}}
|
||||||
|
}
|
||||||
|
mgo := mg - V**(l8 & 3)
|
||||||
|
mgm := (mgo - V**1) & V**3
|
||||||
|
m4s := @collect (i to 3) mgm == V**i
|
||||||
|
|
||||||
|
# Main loop
|
||||||
|
def loop{output} = {
|
||||||
|
xv := *V~~x0
|
||||||
|
rv := *V~~r0
|
||||||
|
# Each iteration handles l vectors from x
|
||||||
|
@for (i to nv<<p2) {
|
||||||
|
# 1 or 3 initial vectors
|
||||||
|
r := load{xv,0}; ++xv
|
||||||
|
if ((l & 2) != 0) {
|
||||||
|
def {m0, _, m2} = m4s
|
||||||
|
re := homBlend{load{xv,0}, load{xv,1}, m2}
|
||||||
|
r = homBlend{re, r, m0}
|
||||||
|
xv += 2
|
||||||
|
}
|
||||||
|
# Then the rest in groups of 4
|
||||||
|
mh := mgo
|
||||||
|
@for (l/4) {
|
||||||
|
{l0, ...ls} := each{load{xv,.}, iota{4}}
|
||||||
|
r4 := fold{homBlend, l0, ls, m4s}
|
||||||
|
r = homBlend{r, r4, mh < V**4}
|
||||||
|
mh -= V**4; xv += 4
|
||||||
|
}
|
||||||
|
def write{r} = {
|
||||||
|
store{rv, 0, shuf{[16]u8, r, si}}
|
||||||
|
++rv
|
||||||
|
}
|
||||||
|
output{r, i, write}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Handle odd and even strides separately
|
||||||
|
if (p2 == 0) {
|
||||||
|
loop{{r, i, write} => write{swap_ms{r}}}
|
||||||
|
} else {
|
||||||
|
# Store results
|
||||||
|
def add_res = {
|
||||||
|
ra := V**0 # Accumulator
|
||||||
|
i := make{V, iota{vl}%16}
|
||||||
|
o := V**(u8~~1<<el)
|
||||||
|
bl := i & elo < o
|
||||||
|
sh := i - o
|
||||||
|
{r} => { ra = homBlend{shuf{[16]u8, ra, sh}, r, bl} }
|
||||||
|
}
|
||||||
|
il := make{V, iota{vl} % 16}
|
||||||
|
# Shuffle to undo interleaving of add_res
|
||||||
|
def __shr{x:(V), sh if hasarch{'X86_64'}} = V~~(H~~x >> sh)
|
||||||
|
def __shl{x:(V), sh if hasarch{'X86_64'}} = V~~(H~~x << sh)
|
||||||
|
def __shr{x:(V), sh:T if hasarch{'AARCH64'} and not isvec{T}} = x << V**cast_i{u8,-sh}
|
||||||
|
def __shl{x:(V), sh:T if hasarch{'AARCH64'} and not isvec{T}} = x << V**cast_i{u8, sh}
|
||||||
|
def uz_lane = {
|
||||||
|
l := V**(u8~~1<<el - 1)
|
||||||
|
h := V**(16 - u8~~16>>p2)
|
||||||
|
dz := (il & l) | (h &~ il)>>(4-e) # low, high->middle
|
||||||
|
dz |= (il &~ (l | h))<<p2 # middle->high
|
||||||
|
shuf{[16]u8, ., dz}
|
||||||
|
}
|
||||||
|
# Adjust modular permutation to apply after unzipping
|
||||||
|
si = uz_lane{(si & V**(vl - u8~~1<<e)) >> p2}
|
||||||
|
si ^= il &~ V**((16 - u8~~1<<e) >> p2)
|
||||||
|
# Cross-lane follow-up
|
||||||
|
def cross = if (vl == 16) { {x}=>x } else {
|
||||||
|
si = shuf{[4]u64, si, 0,1,0,1}
|
||||||
|
assert{p2 <= 2}
|
||||||
|
cr := make{[8]u32, tr_iota{0,2,1}}
|
||||||
|
if (p2 > 1) cr = make{[8]u32, tr_iota{2,0,1}}
|
||||||
|
shuf{[8]u32, ., cr}
|
||||||
|
}
|
||||||
|
# Run, writing every 1<<p2 steps
|
||||||
|
plo := usz~~1<<p2 - 1
|
||||||
|
loop{{r, i, write} => {
|
||||||
|
def ra = add_res{r}
|
||||||
|
if ((plo &~ i) == 0) write{cross{uz_lane{ra}}}
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Select one element out of every l, element width 1<<el bytes
|
||||||
|
# Maximum of n result values, return actual number written
|
||||||
|
fn extract_column(x0:*void, r0:*void, n:usz, l:usz, el:u8) : usz = {
|
||||||
|
n <<= el
|
||||||
|
def vl = arch_defvw / 8
|
||||||
|
def thr = min{vl+2, 20}
|
||||||
|
if ((not has_simd) or n < vl or l > usz~~thr>>el or l<<el >= thr) return{0}
|
||||||
|
nv := n / vl
|
||||||
|
if (has_simd and (l & (l-1)) == 0) {
|
||||||
|
def try_unzip{T, k} = if (k < thr and l == k) {
|
||||||
|
extract_column_pow2{T, x0, r0, nv, k}
|
||||||
|
goto{'ret'}
|
||||||
|
}
|
||||||
|
# 10 loops: i8 2,4,8,16; i16 2,4,8; i32 2,4; i64 2
|
||||||
|
@unroll (ek to 4) if (el == ek) {
|
||||||
|
def T = ty_u{8<<ek}
|
||||||
|
@unroll (p from 1 to 5-ek) try_unzip{T, 1<<p}
|
||||||
|
}
|
||||||
|
return{0}
|
||||||
|
setlabel{'ret'}
|
||||||
|
} else if (hasarch{'SSE4.1'} or hasarch{'AARCH64'}) {
|
||||||
|
extract_column_modperm{x0, r0, nv, l, el, vl}
|
||||||
|
} else return{0}
|
||||||
|
(usz~~vl>>el) * nv
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Short-row boolean folds: main challenge is bit packing
|
# Short-row boolean folds: main challenge is bit packing
|
||||||
def fold_rows_bit_lt64{
|
def fold_rows_bit_lt64{
|
||||||
op, run_loop2, run_loop4, pext_res, mult_in,
|
op, run_loop2, run_loop4, pext_res, mult_in,
|
||||||
@ -184,7 +367,7 @@ def fold_rows_bit_lt64{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn select_rows_bit_lt64(xp:*u64, rp:*u64, n:usz, l:usz, o:usz) : void = {
|
fn extract_column_bit_lt64(xp:*u64, rp:*u64, n:usz, l:usz, o:usz) : void = {
|
||||||
assert{l < 64}; assert{o < l} # Row length, and offset within row
|
assert{l < 64}; assert{o < l} # Row length, and offset within row
|
||||||
def run_loop2{loop} = loop{{a,b} => a>>o}
|
def run_loop2{loop} = loop{{a,b} => a>>o}
|
||||||
def run_loop4{m, t, loop} = loop{{x} => x<<(l-1-o)}
|
def run_loop4{m, t, loop} = loop{{x} => x<<(l-1-o)}
|
||||||
@ -329,4 +512,5 @@ fn or_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, op_and:u1) : void = {
|
|||||||
}
|
}
|
||||||
export{'si_xor_rows_bit', xor_rows_bit}
|
export{'si_xor_rows_bit', xor_rows_bit}
|
||||||
export{'si_or_rows_bit', or_rows_bit}
|
export{'si_or_rows_bit', or_rows_bit}
|
||||||
export{'si_select_cells_bit_lt64', select_rows_bit_lt64}
|
export{'si_select_cells_bit_lt64', extract_column_bit_lt64}
|
||||||
|
export{'si_select_cells_byte', extract_column}
|
||||||
|
|||||||
@ -693,7 +693,7 @@ def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = {
|
|||||||
def fixed_loop{k} = {
|
def fixed_loop{k} = {
|
||||||
assert{wv == k}
|
assert{wv == k}
|
||||||
while (1) {
|
while (1) {
|
||||||
# e.g. 01234567 to 05316427 on each byte for k==3, ew==8
|
# e.g. 01234567 to 03614725 on each byte for k==3, ew==8
|
||||||
xv := get_perm_x{}
|
xv := get_perm_x{}
|
||||||
# Overhang from previous 64-bit elements
|
# Overhang from previous 64-bit elements
|
||||||
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
|
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
|
||||||
|
|||||||
@ -74,16 +74,7 @@ export{'si_scan_min_i16', scan_idem_id{i16, min}}; export{'si_scan_max_i16', sca
|
|||||||
export{'si_scan_min_i32', scan_idem_id{i32, min}}; export{'si_scan_max_i32', scan_idem_id{i32, max}}
|
export{'si_scan_min_i32', scan_idem_id{i32, min}}; export{'si_scan_max_i32', scan_idem_id{i32, max}}
|
||||||
|
|
||||||
# Assumes identity is 0
|
# Assumes identity is 0
|
||||||
def scan_assoc{op} = {
|
def scan_plus = scan_assoc_id0{+}
|
||||||
def shl0{v:[_]T, k} = vec_shift_right_128{v, k/width{T}} # Lanewise
|
|
||||||
def shl0{v:V, k==128 if hasarch{'AVX2'}} = {
|
|
||||||
# Broadcast end of lane 0 to entire lane 1
|
|
||||||
l:= V~~make{[8]i32,0,0,0,-1,0,0,0,0} & spread{v}
|
|
||||||
sel{[8]i32, l, make{[8]i32, 3*(3<iota{8})}}
|
|
||||||
}
|
|
||||||
prefix_byshift{op, shl0}
|
|
||||||
}
|
|
||||||
def scan_plus = scan_assoc{+}
|
|
||||||
|
|
||||||
# Associative scan
|
# Associative scan
|
||||||
def scan_assoc_0 = scan_scal
|
def scan_assoc_0 = scan_scal
|
||||||
|
|||||||
@ -66,3 +66,13 @@ def make_scan_idem{(f64), op, up} = {
|
|||||||
sc
|
sc
|
||||||
}
|
}
|
||||||
def make_scan_idem{T, op} = make_scan_idem{T, op, 1}
|
def make_scan_idem{T, op} = make_scan_idem{T, op, 1}
|
||||||
|
|
||||||
|
def scan_assoc_id0{op} = {
|
||||||
|
def shl0{v:[_]T, k} = vec_shift_right_128{v, k/width{T}} # Lanewise
|
||||||
|
def shl0{v:V, k==128 if hasarch{'AVX2'}} = {
|
||||||
|
# Broadcast end of lane 0 to entire lane 1
|
||||||
|
l:= V~~make{[8]i32,0,0,0,-1,0,0,0,0} & spread{v}
|
||||||
|
sel{[8]i32, l, make{[8]i32, 3*(3<iota{8})}}
|
||||||
|
}
|
||||||
|
prefix_byshift{op, shl0}
|
||||||
|
}
|
||||||
|
|||||||
@ -4,6 +4,8 @@ include './f64'
|
|||||||
include './mask'
|
include './mask'
|
||||||
include './bitops'
|
include './bitops'
|
||||||
|
|
||||||
|
def avx2 = hasarch{'AVX2'}
|
||||||
|
|
||||||
# Group l (power of 2) elements into paired groups of length o
|
# Group l (power of 2) elements into paired groups of length o
|
||||||
# e.g. pairs{2, iota{8}} = {{0,1,4,5}, {2,3,6,7}}
|
# e.g. pairs{2, iota{8}} = {{0,1,4,5}, {2,3,6,7}}
|
||||||
def pairs{o, x} = {
|
def pairs{o, x} = {
|
||||||
@ -22,31 +24,82 @@ def permute_pass{o, x} = {
|
|||||||
merge{h{0,2}, h{1,3}}
|
merge{h{0,2}, h{1,3}}
|
||||||
}
|
}
|
||||||
def unpack_to{f, l, x} = {
|
def unpack_to{f, l, x} = {
|
||||||
def pass = if (f) permute_pass else unpack_pass
|
def pass = if (avx2 and f) permute_pass else unpack_pass
|
||||||
pass{l, if (l==1) x else unpack_to{0, l/2, x}}
|
pass{l, if (l==1) x else unpack_to{0, l/2, x}}
|
||||||
}
|
}
|
||||||
# Last pass for square kernel packed in halves
|
# Last pass for square kernel packed in halves
|
||||||
def shuf_pass{x} = each{{v} => shuf{[4]i64, v, 0,2,1,3}, x}
|
def halved_pass{n, x} = {
|
||||||
|
if (not avx2) unpack_pass{n/2, x}
|
||||||
|
else each{{v} => shuf{[4]i64, v, 0,2,1,3}, x}
|
||||||
|
}
|
||||||
|
|
||||||
# Square kernel where width is a full vector
|
# Square kernel where width is a full vector
|
||||||
def transpose_square{VT, l, x if hasarch{'AVX2'}} = unpack_to{1, l/2, x}
|
def transpose_square{VT, l, x if avx2} = unpack_to{1, l/2, x}
|
||||||
|
|
||||||
def load2{a:*T, b:*T} = pair{load{a}, load{b}}
|
def load2{a:*T, b:*T} = match (width{T}) {
|
||||||
def store2{a:*T, b:*T, v:T2 if w128i{T} and w256{T2}} = {
|
{64} => {
|
||||||
each{{p, i} => store{p, 0, T~~half{v,i}}, tup{a,b}, iota{2}}
|
def v = each{{p}=>loadLow{*[2]u64~~p, 64}, tup{a,b}}
|
||||||
|
n_d{T}~~zip{...v, 0}
|
||||||
|
}
|
||||||
|
{128} => pair{load{a}, load{b}}
|
||||||
}
|
}
|
||||||
def load_k {VT, src, l, w if w256{VT}} = each{{i} =>load {*VT~~(src+i*w), 0 }, iota{l}}
|
def store2{a:*T, b:*T, v:T2 if 2*width{T} == width{T2}} = match (width{T}) {
|
||||||
def store_k{VT, dst, x, l, h if w256{VT}} = each{{i,v}=>store{*VT~~(dst+i*h), 0, VT~~v}, iota{l}, x}
|
{ 64} => each{{p, v} => storeLow{*u64~~p, 64, [2]u64~~v}, tup{a,b}, tup{v, shuf{u64, v, 1,0}}}
|
||||||
def load_k {VT, src, l, w if w128{VT}} = each{{i} =>{p:=src+ i*w; load2 {*VT~~p, *VT~~(p+l*w) }}, iota{l}}
|
{128} => each{{p, i} => store{p, 0, T~~half{v,i}}, tup{a,b}, iota{2}}
|
||||||
def store_k{VT, dst, x, l, h if w128{VT}} = each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+ h), v}}, iota{l}, x}
|
}
|
||||||
|
def store1of2{a:*T, v:T2 if 2*width{T} == width{T2}} = match (width{T}) {
|
||||||
|
{ 64} => storeLow{*u64~~a, 64, [2]u64~~v}
|
||||||
|
{128} => store{a, 0, T~~half{v,0}}
|
||||||
|
}
|
||||||
|
def load_k {VT, src, l, w} = each{{i} =>load {*VT~~(src+i*w), 0 }, iota{l}}
|
||||||
|
def store_k{VT, dst, x, l, h} = each{{i,v}=>store{*VT~~(dst+i*h), 0, VT~~v}, iota{l}, x}
|
||||||
|
def load_k {VT, src, l, w if width{VT} < arch_defvw} = each{{i} =>{p:=src+ i*w; load2 {*VT~~p, *VT~~(p+l*w) }}, iota{l}}
|
||||||
|
def store_k{VT, dst, x, l, h if width{VT} < arch_defvw} = each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+ h), v}}, iota{l}, x}
|
||||||
|
|
||||||
# Transpose kernel of size kw,kh in size w,h array
|
# Transpose kernel of size kw,kh in size w,h array
|
||||||
def kernel{src:*T, dst:*T, kw, kh, w, h} = {
|
def kernel_part{part_w}{src:*T, dst:*T, kw, kh, w, h} = {
|
||||||
def n = (kw*kh*width{T}) / 256 # Number of vectors
|
def n = (kw*kh*width{T}) / arch_defvw # Number of vectors
|
||||||
def xvs = load_k{[kw]T, src, n, w}
|
def xvs = load_k{[kw]T, src, n, w}
|
||||||
def xt = unpack_to{n==kh, n/2, xvs} # Transpose n by n
|
def xt = unpack_to{n==kh, n/2, xvs} # Transpose n by n
|
||||||
def rvs = if (n==kw) xt else shuf_pass{xt} # To kh by kh for packed square
|
def rvs = if (n==kw) xt else halved_pass{n,xt} # To kh by kh for packed square
|
||||||
store_k{[kh]T, dst, rvs, n, h}
|
def stores = store_k{[kh]T, ..., h}
|
||||||
|
if (same{part_w, 0}) {
|
||||||
|
stores{dst, rvs, n}
|
||||||
|
} else {
|
||||||
|
# Write w results, kw/2 <= n < kw
|
||||||
|
d := dst
|
||||||
|
def vd = kw / n # Number of writes for each output vector (1 or 2)
|
||||||
|
def store_slice{rv, len} = {
|
||||||
|
stores{d, slice{rv,0,len}, len}
|
||||||
|
d += len*vd*h
|
||||||
|
}
|
||||||
|
store_slice{rvs, n/2} # Unconditionally store first half
|
||||||
|
rt := slice{rvs,n/2} # Remaining tail
|
||||||
|
def wtail{b} = {
|
||||||
|
if ((part_w & (vd*b)) != 0) {
|
||||||
|
store_slice{rt, b}
|
||||||
|
slice{rt,0,b} = slice{rt,b,2*b}
|
||||||
|
}
|
||||||
|
if (b>1) wtail{b/2}
|
||||||
|
}
|
||||||
|
wtail{n/4}
|
||||||
|
if (vd>1 and (part_w & 1) != 0) store1of2{*[kh]T~~d, select{rt,0}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
def kernel = kernel_part{0}
|
||||||
|
|
||||||
|
def kernel_part_h{part_h}{src:*T==i8, dst:*T, kw==16, kh==16, w, h} = {
|
||||||
|
def n = (kw*kh*width{T}) / arch_defvw
|
||||||
|
def VT = [kw]T
|
||||||
|
off := part_h - kh/2
|
||||||
|
def xvs = @unroll (i to n) { s := src + i*w; load2{*VT~~s, *VT~~(s+off*w)} }
|
||||||
|
def rvs = halved_pass{n, unpack_to{0, n/2, xvs}}
|
||||||
|
@unroll (j to 2) {
|
||||||
|
def is = 2*iota{2} + j
|
||||||
|
d := dst + j*off
|
||||||
|
def store_q{v, i} = { store{*u64~~d, 0, extract{v,i}}; d += h }
|
||||||
|
each{{r} => each{store_q{[4]u64~~r,.}, is}, rvs}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -153,37 +206,264 @@ def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h, ws, hs} = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Interleave n values of type T from x0 and x1 into r
|
# Unzip 2*n values
|
||||||
|
def uninterleave{r0:*T, r1:*T, xp:*T, n} = {
|
||||||
|
@for (r0, r1 over i to n) {
|
||||||
|
r0 = load{xp, i*2}; r1 = load{xp, i*2+1}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
# Zip n values of type T from each of x0 and x1 into r
|
||||||
fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = {
|
fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = {
|
||||||
rp := *T~~r0
|
rp := *T~~r0
|
||||||
@for (x0 in *T~~x0, x1 in *T~~x1 over i to n) {
|
@for (x0 in *T~~x0, x1 in *T~~x1 over i to n) {
|
||||||
store{rp, i*2, x0}; store{rp, i*2+1, x1}
|
store{rp, i*2, x0}; store{rp, i*2+1, x1}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
# SIMD implementations
|
||||||
fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
|
def uninterleave{r0:*T, r1:*T, xp:*T, n if has_simd and (not hasarch{'X86_64'} or width{T}>=32 or hasarch{'SSSE3'})} = {
|
||||||
# Scalar transpose defined in C
|
def l = arch_defvw / width{T}
|
||||||
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
|
def V = [l]T
|
||||||
def call_base{...a} = emit{void, merge{'transpose_',ts}, ...a, ws, hs}
|
rv0 := *V~~r0; rv1 := *V~~r1; xv := *V~~xp
|
||||||
|
nv := n / l
|
||||||
rp:*T = *T~~r0
|
def uz = if (not hasarch{'X86_64'}) unzip else ({...xs} => {
|
||||||
xp:*T = *T~~x0
|
def reinterpret{V, xs if ktup{xs}} = each{~~{V,.}, xs}
|
||||||
if (hasarch{'AVX2'} and w>=k and h>=k) {
|
def q = tr_quads{arch_defvw/128}
|
||||||
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
|
def k = flat_table{+, iota{2}, 2 * iota{64 / width{T}}}
|
||||||
} else {
|
def px = each{shuf{., k}, xs}
|
||||||
if (h==2 and h==hs) interleave{T}(r0, x0, *void~~(xp+ws), w)
|
V~~each{q, zip128{...re_el{u64,V}~~px}}
|
||||||
else if (w==2 and w==ws) @for (r0 in rp, r1 in rp+hs over i to h) { r0 = load{xp, i*2}; r1 = load{xp, i*2+1} }
|
})
|
||||||
else call_base{rp, xp, w, h}
|
@for (r0 in rv0, r1 in rv1 over i to nv) {
|
||||||
|
tup{r0, r1} = uz{...each{load{xv+2*i, .}, iota{2}}}
|
||||||
|
}
|
||||||
|
if (n % l > 0) {
|
||||||
|
xb := xv + 2*nv
|
||||||
|
x0 := load{xb}
|
||||||
|
x1 := V**0; if (n&(l/2) != 0) x1 = load{xb, 1}
|
||||||
|
mask := maskOf{V, n%l}
|
||||||
|
each{homMaskStoreF{., mask, .}, tup{rv0+nv,rv1+nv}, uz{x0,x1}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn interleave{T if has_simd}(r:*void, x0:*void, x1:*void, n:u64) : void = {
|
||||||
|
def l = arch_defvw / width{T}
|
||||||
|
def V = [l]T
|
||||||
|
xv0 := *V~~x0; xv1 := *V~~x1; rv := *V~~r
|
||||||
|
nv := n / l
|
||||||
|
def q = tr_quads{arch_defvw/128}
|
||||||
|
@for (x0 in xv0, x1 in xv1 over i to nv) {
|
||||||
|
each{store{rv+2*i, ., .}, iota{2}, zip128{q{x0},q{x1}}}
|
||||||
|
}
|
||||||
|
if (n % l > 0) {
|
||||||
|
def xs = each{load{.,nv}, tup{xv0,xv1}}
|
||||||
|
def get_r = zip128{...each{q, xs}, .}
|
||||||
|
r0 := get_r{0}
|
||||||
|
rb := rv + 2*nv
|
||||||
|
nr := 2*n; m := nr%l
|
||||||
|
mask := maskOf{V, m}
|
||||||
|
if (nr&l == 0) {
|
||||||
|
homMaskStoreF{rb, mask, r0}
|
||||||
|
} else {
|
||||||
|
store{rb, 0, r0}
|
||||||
|
if (m > 0) homMaskStoreF{rb+1, mask, get_r{1}}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def transpose{T, k} = transpose{T, k, k}
|
# Utilities for kernels based on modular permutation
|
||||||
|
def rotcol{xs, mg:I} = {
|
||||||
|
def w = length{xs}
|
||||||
|
@unroll (kl to ceil_log2{w}) { def k = 1<<kl
|
||||||
|
def vk = I**k; def m = (mg & vk) == vk
|
||||||
|
def bl{x,y} = { x = homBlend{x,y,m} }
|
||||||
|
x0 := select{xs, 0}
|
||||||
|
def xord = select{xs, k*iota{w} % w}
|
||||||
|
each{bl, xord, shiftleft{xord, tup{x0}}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
def get_modperm_lane_shuf{c} = {
|
||||||
|
def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c}
|
||||||
|
{x, i} => cross{shuf{16, x, i}, i}
|
||||||
|
}
|
||||||
|
def tr_quads = match { {1}=>({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} }
|
||||||
|
|
||||||
exportT{'simd_transpose', tup{
|
def loop_fixed_height{xp, rp, w, k, st, kern} = {
|
||||||
transpose{i8 , 16},
|
@for_mult_max{k, w-k} (i to w+(-w)%k) kern{xp+i, rp+i*st}
|
||||||
transpose{i16, 8, 16},
|
}
|
||||||
transpose{i32, 8},
|
def loop_fixed_width{xp, rp, h, k, st, kern} = {
|
||||||
transpose{i64, 4}
|
@for_mult_max{k, h-k} (i to h+(-h)%k) kern{xp+i*st, rp+i}
|
||||||
}}
|
}
|
||||||
|
|
||||||
exportT{'interleave_fns', each{interleave, tup{i8, i16, i32, i64}}}
|
# Transpose a contiguous kernel of width w*p from x to r with stride rst
|
||||||
|
def modular_kernel{w,p if w%2==1 and 2%p==0}{xp:*T, rp0:*T, rst:(u64)} = {
|
||||||
|
def h = arch_defvw / 8
|
||||||
|
def ih = iota{h}; def iw = iota{w}
|
||||||
|
def I = [h]u8; def V = I
|
||||||
|
def e = width{T} / 8
|
||||||
|
# Load a shape h,w slice of x, but consider as shape w,h
|
||||||
|
def xsp = each{load{*V~~xp, .}, iw}
|
||||||
|
# Modular permutation of (reshaped argument) columns
|
||||||
|
xs := select{xsp, find_index{h/e*iw % w, iw}}
|
||||||
|
# Rotate each column by its index
|
||||||
|
rotcol{reverse{xs}, make{I, (ih // e) % w}}
|
||||||
|
# Modular permutation of rows, and write to result
|
||||||
|
rp := rp0
|
||||||
|
mp := make{I, ih%e + (p*(ih - ih%e) + (ih//(h/p))*e)%16*w%h}
|
||||||
|
mi := I**e
|
||||||
|
def perm = if (h==16) shuf else {
|
||||||
|
def sh = get_modperm_lane_shuf{I**16}
|
||||||
|
def q = tr_quads{p}; if (p>1) mp = q{mp}
|
||||||
|
{x, i} => q{sh{x, i}}
|
||||||
|
}
|
||||||
|
def perm_store{x} = {
|
||||||
|
match (p) {
|
||||||
|
{1} => store{*V~~rp, 0, perm{x, mp}}
|
||||||
|
{2} => { def U = [h/2]u8; store2{*U~~rp, *U~~(rp+w*rst), perm{x, mp}} }
|
||||||
|
}
|
||||||
|
rp += rst; mp += mi
|
||||||
|
if (hasarch{'AARCH64'}) mp &= I**15 # Implicit on x86, value stays below h+w
|
||||||
|
}
|
||||||
|
each{perm_store, xs}
|
||||||
|
}
|
||||||
|
def modular_kernel{2,2}{xp:*T, rp0:*T, rst:(u64)} = {
|
||||||
|
def h = arch_defvw / 8
|
||||||
|
def V = [h]u8; def U = n_h{V}
|
||||||
|
def ih = iota{h}%16; def e = width{T} / 8
|
||||||
|
# Permutation to unzip by 4 within each lane
|
||||||
|
uz := make{V, (4*(ih//e) + ih//(16/4))*e%16 + ih%e}
|
||||||
|
# Unzipping code for the resulting 4-byte units
|
||||||
|
def {st, proc, zipx} = match (h) {
|
||||||
|
{16} => tup{2, {x} => [4]f32~~x, {xs,i} => V~~zip128{...xs,i}}
|
||||||
|
{32} => tup{1, shuf{[4]u64, ., 0,2,1,3}, {xs,i} => shuf{[4]f32, xs, i + tup{0,2,0,2}}}
|
||||||
|
}
|
||||||
|
def xsp = each{load{*V~~xp, .}, iota{2}}
|
||||||
|
xs := each{proc, each{shuf{16, ., uz}, xsp}}
|
||||||
|
@unroll (i to 2) {
|
||||||
|
rp := rp0 + st*i*rst
|
||||||
|
store2{*U~~rp, *U~~(rp + (2/st)*rst), zipx{xs,i}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = {
|
||||||
|
def p = if (wk%2) 1 else 2; def w = wk/p
|
||||||
|
def vl = arch_defvw / (p*width{T})
|
||||||
|
loop_fixed_width{xp, rp, h, vl, wk, modular_kernel{w,p}{., ., hs}}
|
||||||
|
}
|
||||||
|
def transpose_fixed_width{rp:*T, xp:*T, 2, h, hs} = {
|
||||||
|
uninterleave{rp, rp+hs, xp, h}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Transpose a kernel of height w*p from x with stride xst to contiguous r
|
||||||
|
# w and h are named for the result, not argument, to match modular_kernel
|
||||||
|
def modular_kernel_rev{w,p if w%2==1 and 2%p==0}{xp0:*T, rp:*T, xst:(u64)} = {
|
||||||
|
def h = arch_defvw / 8
|
||||||
|
def ih = iota{h}; def iw = iota{w}
|
||||||
|
def I = [h]u8; def V = I
|
||||||
|
def e = width{T} / 8
|
||||||
|
# Read rows, modular permutation on each
|
||||||
|
def rotbit{x, l,m,h} = x%l + (x-x%l)*(h/m)%h + x//m*l
|
||||||
|
def wi = w + 2 * ((w-1) + (w&2)) # Inverse mod 32
|
||||||
|
def mpd = rotbit{ih%e + (ih - ih%e)%16*wi%h, e,e*p,h}
|
||||||
|
mp := make{I, if (h==16 or p==1) mpd else rotbit{mpd, 8,16,h}}
|
||||||
|
def rot_mp = {
|
||||||
|
def rot_lane = shuf{16, ., make{I, (ih-e)%16}}
|
||||||
|
def cross = if (h==16) ({x}=>x) else ^{make{I, 16*(ih%16<e)}, .}
|
||||||
|
{mp} => cross{rot_lane{mp}}
|
||||||
|
}
|
||||||
|
def perm = if (h==16) shuf else {
|
||||||
|
def sh = get_modperm_lane_shuf{I**16}
|
||||||
|
def q = tr_quads{p}
|
||||||
|
{x, i} => sh{q{x}, i}
|
||||||
|
}
|
||||||
|
xp := xp0
|
||||||
|
xs := @collect (w) {
|
||||||
|
x := match (p) {
|
||||||
|
{1} => perm{load{*V~~xp, 0}, mp}
|
||||||
|
{2} => { def U = [h/2]u8; perm{load2{*U~~xp, *U~~(xp+w*xst)}, mp} }
|
||||||
|
}
|
||||||
|
xp += xst; mp = rot_mp{mp}
|
||||||
|
x
|
||||||
|
}
|
||||||
|
# Rotate each column by its index
|
||||||
|
rotcol{xs, make{I, (ih // e) % w}}
|
||||||
|
# Permute vectors and store
|
||||||
|
each{store{*V~~rp, ., .}, iw, select{xs, h/e*iw % w}}
|
||||||
|
}
|
||||||
|
def modular_kernel_rev{2,2}{xp0:*T, rp:*T, xst:(u64)} = {
|
||||||
|
def V = [arch_defvw / width{T}]T; def U = n_h{V}
|
||||||
|
xl := @unroll (i to 2) {
|
||||||
|
xp := xp0 + i*xst
|
||||||
|
x := load2{*U~~xp, *U~~(xp + 2*xst)}
|
||||||
|
if (arch_defvw==128) x else shuf{[8]u32, x, tr_iota{1,2,0}}
|
||||||
|
}
|
||||||
|
xs := unpack_typed{...unpack_typed{...xl}}
|
||||||
|
each{store{*V~~rp, ., .}, iota{2}, each{~~{V,.},xs}}
|
||||||
|
}
|
||||||
|
def transpose_fixed_height{rp:*T, xp:*T, w, ws, hk} = {
|
||||||
|
def p = if (hk%2) 1 else 2; def h = hk/p
|
||||||
|
def vl = arch_defvw / (p*width{T})
|
||||||
|
loop_fixed_height{xp, rp, w, vl, hk, modular_kernel_rev{h,p}{., ., ws}}
|
||||||
|
}
|
||||||
|
def transpose_fixed_height{rp:*T, xp:*T, w, ws, 2} = {
|
||||||
|
interleave{T}(*void~~rp, *void~~xp, *void~~(xp+ws), w)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
|
||||||
|
rp:*T = *T~~r0
|
||||||
|
xp:*T = *T~~x0
|
||||||
|
def wT = width{T}
|
||||||
|
def vl = arch_defvw / wT
|
||||||
|
# Transposes with code dedicated to a particular width or height
|
||||||
|
def try_fixed_dim{tr, l, lst, nl, l_max} = {
|
||||||
|
def incl{l} = if (k>4) 1 else l!=4
|
||||||
|
if (l<l_max and incl{l} and l==lst) {
|
||||||
|
if (l == 2) { tr{2}; return{} }
|
||||||
|
def has_blend = hasarch{'SSE4.1'} or hasarch{'AARCH64'}
|
||||||
|
if (has_blend and nl>=vl/2 and (l%2==0 or nl>=vl)) {
|
||||||
|
def try{ls} = {
|
||||||
|
def i = length{ls}>>1
|
||||||
|
if (l < select{ls,i}) try{slice{ls, 0,i}} else try{slice{ls, i}}
|
||||||
|
}
|
||||||
|
def try{{lk}} = tr{lk}
|
||||||
|
try{replicate{{i} => i<l_max and incl{i}, slice{iota{8},3}}}
|
||||||
|
return{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
# Small width: fixed, or over-reading partial kernel
|
||||||
|
def use_part_w = wT<=16 and k>max{4,8/wT}
|
||||||
|
def w_max = 1 + (if (use_part_w) max{4, k/2-1} else 7)
|
||||||
|
if (has_simd and w < max{w_max, k}) {
|
||||||
|
try_fixed_dim{transpose_fixed_width {rp, xp, ., h, hs}, w, ws, h, w_max}
|
||||||
|
if (use_part_w and h>=kh and w>=k/2 and w<k) {
|
||||||
|
loop_fixed_width{xp, rp, h, kh, ws, kernel_part{w}{., ., k, kh, ws, hs}}
|
||||||
|
return{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
# Small height: fixed, or kernel with overlapping writes
|
||||||
|
# Overlapping is slower than over-reading, so it's only used when needed
|
||||||
|
if (has_simd and h < max{8, k}) {
|
||||||
|
try_fixed_dim{transpose_fixed_height{rp, xp, w, ws, .}, h, hs, w, 8}
|
||||||
|
if (k>8 and w>=k and h>=kh/2) {
|
||||||
|
loop_fixed_height{xp, rp, w, k, hs, kernel_part_h{h}{., ., k, kh, ws, hs}}
|
||||||
|
return{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
# Scalar transpose defined in C
|
||||||
|
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
|
||||||
|
def call_base{...a} = emit{void, merge{'transpose_',ts}, ...a, ws, hs}
|
||||||
|
# Full kernels
|
||||||
|
# May have w<k or h<k if not has_blend, or >2D transpose with w!=ws or h!=hs
|
||||||
|
if (has_simd and k!=0 and w>=k and h>=k) {
|
||||||
|
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
|
||||||
|
} else {
|
||||||
|
call_base{rp, xp, w, h}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def transpose{T, k if knum{k}} = transpose{T, tup{k, k}}
|
||||||
|
|
||||||
|
def tr_types = tup{i8, i16, i32, i64}
|
||||||
|
def tr_kernels = if (not avx2) tup{ 8, 8, 4, 0 }
|
||||||
|
else tup{16, tup{8, 16}, 8, 4 }
|
||||||
|
|
||||||
|
exportT{'simd_transpose', each{transpose, tr_types, tr_kernels}}
|
||||||
|
|
||||||
|
exportT{'interleave_fns', each{interleave, tr_types}}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user