Merge pull request #123 from mlochbaum/modperm

CPU-sized select single column and transpose with modular permutations
This commit is contained in:
dzaima 2024-11-06 22:23:20 +02:00 committed by GitHub
commit 163853439e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 604 additions and 110 deletions

@ -1 +1 @@
Subproject commit d432cb710911457169da5342d27ce8adffd5dd1a
Subproject commit 17c512727dbcb6d58a2adadac4661bc9c43920d2

View File

@ -219,8 +219,8 @@ NOINLINE B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr) { //
// fast special-case implementations
extern void (*const si_select_cells_bit_lt64)(u64*,u64*,usz,usz,usz); // from fold.c (fold.singeli)
static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind {leaf? <∘⊑; ⊏}⎉¯k x; TODO probably can share some parts with takedrop_highrank and/or call ⊏?
B select_cells_single(usz ind, B x, usz cam, usz l, usz csz, bool leaf); // from select.c
static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind {leaf? <∘⊑; ⊏}⎉¯k x
ur xr = RNK(x);
assert(xr>1 && k<xr);
usz* xsh = SH(x);
@ -228,58 +228,15 @@ static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind
usz l = xsh[k];
assert(0<=ind && ind<l);
assert(cam*l*csz == IA(x));
Arr* ra;
usz take = leaf? 1 : csz;
if (l==1 && take==csz) {
ra = cpyWithShape(incG(x));
arr_shErase(ra, 1);
} else {
u8 xe = TI(x,elType);
u8 ewl= elwBitLog(xe);
u8 xl = leaf? ewl : multWidthLog(csz, ewl);
usz ria = cam*take;
if (xl>=7 || (xl<3 && xl>0)) { // generic case
MAKE_MUT_INIT(rm, ria, TI(x,elType)); MUTG_INIT(rm);
usz jump = l * csz;
usz xi = take*ind;
usz ri = 0;
for (usz i = 0; i < cam; i++) {
mut_copyG(rm, ri, x, xi, take);
xi+= jump;
ri+= take;
}
ra = mut_fp(rm);
} else if (xe==el_B) {
assert(take == 1);
SGet(x)
HArr_p rp = m_harrUv(ria);
for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*l+ind);
NOGC_E; ra = (Arr*)rp.c;
} else {
void* rp = m_tyarrlbp(&ra, ewl, ria, el2t(xe));
void* xp = tyany_ptr(x);
switch(xl) {
case 0:
#if SINGELI
if (l < 64) si_select_cells_bit_lt64(xp, rp, cam, l, ind);
else
#endif
for (usz i=0; i<cam; i++) bitp_set(rp, i, bitp_get(xp, i*l+ind));
break;
case 3: PLAINLOOP for (usz i=0; i<cam; i++) ((u8* )rp)[i] = ((u8* )xp)[i*l+ind]; break;
case 4: PLAINLOOP for (usz i=0; i<cam; i++) ((u16*)rp)[i] = ((u16*)xp)[i*l+ind]; break;
case 5: PLAINLOOP for (usz i=0; i<cam; i++) ((u32*)rp)[i] = ((u32*)xp)[i*l+ind]; break;
case 6: PLAINLOOP for (usz i=0; i<cam; i++) ((f64*)rp)[i] = ((f64*)xp)[i*l+ind]; break;
}
}
}
B r = select_cells_single(ind, x, cam, l, csz, leaf);
Arr* ra = a(r);
usz* rsh = arr_shAlloc(ra, leaf? k : xr-1);
if (rsh) {
shcpy(rsh, xsh, k);
if (!leaf) shcpy(rsh+k, xsh+k+1, xr-1-k);
}
decG(x);
return taga(ra);
return r;
}
static void set_column_typed(void* rp, B v, u8 e, ux p, ux stride, ux n) { // may write to all elements 0 ≤ i < stride×n, and after that too for masked stores

View File

@ -40,8 +40,15 @@
// Sparse initialization if 𝕨 is much smaller than 𝕩
// COULD call Mark Firsts (∊) for very short 𝕨 to avoid allocation
// Select Cells - inds⊸⊏⎉1 x
// Squeeze indices if too wide for given x
// Select Cells - inds⊸⊏⎉1 𝕩
// Squeeze indices if too wide for given 𝕩
// Single index: (also used for monadic ⊏˘ ⊣˝˘ ⊢˝˘)
// Selecting a column of bits:
// Row size <64: extract as with fold-cells
// Selecting a column of 1, 2, 4, or 8-byte elements:
// Short cells: pack vectors from 𝕩, or blend and permute
// Long cells: dedicated scalar loop for each type
// Otherwise, loop with mutable copy
// Boolean indices:
// Short inds and short cells: Widen to i8
// Otherwise: bitsel call per cell
@ -57,7 +64,7 @@
// COULD generate full list of indices via arith
// 1-element cells: use (≠inds)/⥊x after checking ∧´inds∊0‿¯1
// Used for ⌽⎉1
// SHOULD use for atom⊸⊏⎉k, /⎉k, ⌽⎉k, ↑⎉k, ↓⎉k, ↕⎉k, ⍉⎉k, probably more
// SHOULD use for /⎉k, ⌽⎉k, ↑⎉k, ↓⎉k, ↕⎉k, ⍉⎉k, probably more
#include "../core.h"
#include "../utils/talloc.h"
@ -575,6 +582,62 @@ static void* m_arrv_same(B* r, usz ia, B src) { // makes a new array with same e
B slash_c2(B, B, B);
B select_cells_base(B inds, B x0, ux csz, ux cam);
extern void (*const si_select_cells_bit_lt64)(u64*,u64*,usz,usz,usz); // from fold.c (fold.singeli)
extern usz (*const si_select_cells_byte)(void*,void*,usz,usz,u8);
B select_cells_single(usz ind, B x, usz cam, usz l, usz csz, bool leaf) { // ⥊ ind {leaf? <∘⊑; ⊏}˘ cam‿l‿csz ⥊ x
usz take = leaf? 1 : csz;
Arr* ra;
if (l==1 && take==csz) {
ra = cpyWithShape(incG(x));
arr_shErase(ra, 1);
} else {
u8 xe = TI(x,elType);
u8 ewl= elwBitLog(xe);
u8 xl = leaf? ewl : multWidthLog(csz, ewl);
usz ria = cam*take;
if (xl>=7 || (xl<3 && xl>0)) { // generic case
MAKE_MUT_INIT(rm, ria, TI(x,elType)); MUTG_INIT(rm);
usz jump = l * csz;
usz xi = take*ind;
usz ri = 0;
for (usz i = 0; i < cam; i++) {
mut_copyG(rm, ri, x, xi, take);
xi+= jump;
ri+= take;
}
ra = mut_fp(rm);
} else if (xe==el_B) {
assert(take == 1);
SGet(x)
HArr_p rp = m_harrUv(ria);
for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*l+ind);
NOGC_E; ra = (Arr*)rp.c;
} else {
void* rp = m_tyarrlbp(&ra, ewl, ria, el2t(xe));
void* xp = tyany_ptr(x);
if (xl == 0) {
#if SINGELI
if (l < 64) si_select_cells_bit_lt64(xp, rp, cam, l, ind);
else
#endif
for (usz i=0; i<cam; i++) bitp_set(rp, i, bitp_get(xp, i*l+ind));
} else {
usz i0 = 0;
#if SINGELI
i0 = si_select_cells_byte((u8*)xp + (ind<<(xl-3)), rp, cam, l, xl-3);
#endif
switch(xl) { default: UD;
case 3: PLAINLOOP for (usz i=i0; i<cam; i++) ((u8* )rp)[i] = ((u8* )xp)[i*l+ind]; break;
case 4: PLAINLOOP for (usz i=i0; i<cam; i++) ((u16*)rp)[i] = ((u16*)xp)[i*l+ind]; break;
case 5: PLAINLOOP for (usz i=i0; i<cam; i++) ((u32*)rp)[i] = ((u32*)xp)[i*l+ind]; break;
case 6: PLAINLOOP for (usz i=i0; i<cam; i++) ((f64*)rp)[i] = ((f64*)xp)[i*l+ind]; break;
}
}
}
}
return taga(ra);
}
#define CLZC(X) (64-(CLZ((u64)(X))))
@ -851,6 +914,11 @@ B select_rows_B(B x, ux csz, ux cam, B inds) { // consumes inds,x; ⥊ inds⊸
ux in = IA(inds);
if (in == 0) return taga(emptyArr(x, 1));
u8 ie = TI(inds,elType);
if (in == 1) {
B w = IGetU(inds,0); if (!isF64(w)) goto generic;
B r = select_cells_single(WRAP(o2i64(w), csz, thrF("⊏: Indexing out-of-bounds (%R∊𝕨, %s≡≠𝕩)", w, csz)), x, cam, csz, 1, false);
decG(x); decG(inds); return r;
}
if (csz<=2? ie!=el_bit : csz<=128? ie>el_i8 : !elInt(ie)) {
inds = num_squeeze(inds);
ie = TI(inds,elType);

View File

@ -1,18 +1,22 @@
// Transpose and Reorder Axes (⍉)
// Transpose
// One length-2 axis: dedicated code
// Boolean: pdep or emulation for height 2; pext for width 2
// SHOULD use a generic implementation if BMI2 not present
// SHOULD optimize other short lengths with pdep/pext and shuffles
// Boolean 𝕩: convert to integer
// pdep or emulation for height 2; pext for width 2
// SHOULD switch to shuffles and modular permutation
// SHOULD have bit matrix transpose kernel
// CPU sizes: native or SIMD code
// Large SIMD kernels used when they fit, overlapping for odd sizes
// i8: 16×16; i16: 16×8; i32: 8×8; f64: 4×4
// COULD use half-width or smaller kernels to improve odd sizes
// Scalar transpose or loop used for overhang of 1
// SHOULD add NEON
// SSE, NEON i8: 8×8 ; i16: 8×8; i32: 4×4; f64: scalar
// AVX2 i8: 16×16; i16: 16×8; i32: 8×8; f64: 4×4
// COULD use half-width or smaller kernels to improve odd sizes
// Scalar transpose or loop used for overhang of 1
// Partial kernels for at least half-kernel height/width
// Short width: over-read, then skip some writes
// Short height, i8 AVX2 only: overlap input rows and output words
// Smaller heights and widths: dedicated kernels
// Zipping, unzipping, shuffling for powers of 2
// Modular permutation for odd numbers, possibly times 2
// Reorder Axes
// If 𝕨 indicates the identity permutation, return 𝕩

View File

@ -74,6 +74,189 @@ fn fold_assoc_0{T==f64, op if has_simd}(x:*T, len:u64) : T = {
export{'si_sum_f64', fold_assoc_0{f64,+}}
def extract_column_pow2{T, x0, r0, nv, k} = {
def V = [arch_defvw / width{T}]T
xv := *V~~x0
@for (r in *V~~r0 over i to nv) {
xs := each{load{xv, .}, iota{k}}
def unzip0{w} = if (not hasarch{'X86_64'}) {
unzip{..., 0} # Sane instruction set
} else {
if (w <= 16) {
# Pack instructions
m := make{V, - (iota{vcount{V}}%k == 0)}
xs = each{&{m, .}, xs} # Mask off high bits
def D = el_m{V}
{a, b} => packQ{D~~a, D~~b}
} else {
# Two-vector shuffles
# Could also be used for 1/2-byte with ending gap >= 4 bytes,
# less instructions but it doesn't seem faster
def c = 128/w
def sh = shuf{[c]ty_f{w}, ., 2*iota{c} % c}
{...ab} => sh{ab}
}
}
if (not hasarch{'X86_64'} or T != u16 or hasarch{'SSE4.1'}) {
r = tree_fold{unzip0{width{T}}, xs}
} else {
# No unsigned saturation: sign-extend then use unsigned
def D = [4]i32
def f = tree_fold{unzip0{32}, .}
def proc{hx} = {
ri := D~~f{hx}
top := D**(1<<15); m := D**(1<<16 - 1)
(ri & m) | (D~~(ri&top == top) &~ m)
}
r = V~~packQ{...each{proc, split{k/2, xs}}}
}
if (width{V} > 128) { # Lane axis wasn't packed, need to shuffle to bottom
def tr{E,a, r} = shuf{[1<<a]E, r, tr_iota{shiftright{a-1, iota{a}}}}
def lc = k > 4
if (lc) r = tr{u64,2, r}
r = tr{ty_u{128/k}, lb{k} + (not lc), r}
}
xv += k
}
}
def extract_column_modperm{x0, r0, nv, l, el, vl} = {
# Build modular permutations
def V = [vl]u8
def H = [vl/2]u16
p2 := ctz{l}; l >>= p2 # Decompose into l<<p2 with odd l
e := p2 + promote{ux,el} # Absorb into element size for most computation
l8 := cast_i{u8, l}
li := cast_i{u8, l + 2 * ((l-1) + (l&2))} # Inverse mod vl
elo:= V**(u8~~1<<e - 1)
ie := iota{V} & elo
kmul := make{H, 2*iota{vl/2}} &~ H~~elo
def mu16{k} = {
k16 := H ** k
prd := shuf{V~~(kmul * k16), 0,0}
if (e == 0) prd += V~~(k16 << 8)
(prd & V**(vl-1)) + ie
}
si := mu16{l8}
sii := mu16{li}
def swap_ms = if (vl == 16) ({x}=>x) else {
ms := (V**16 & sii) == (V**16 &~ iota{V})
{x} => homBlend{x, shuf{[4]u64, x, 2,3,0,1}, ms}
}
# Blend masks
def mg = { # Iteration i should select where mg == V**i
ss := (si < V**(l8<<e)) & (ie == V**0)
vs := V**0xff - scan_assoc_id0{+}{ss}
swap_ms{shuf{[16]u8, vs, sii}}
}
mgo := mg - V**(l8 & 3)
mgm := (mgo - V**1) & V**3
m4s := @collect (i to 3) mgm == V**i
# Main loop
def loop{output} = {
xv := *V~~x0
rv := *V~~r0
# Each iteration handles l vectors from x
@for (i to nv<<p2) {
# 1 or 3 initial vectors
r := load{xv,0}; ++xv
if ((l & 2) != 0) {
def {m0, _, m2} = m4s
re := homBlend{load{xv,0}, load{xv,1}, m2}
r = homBlend{re, r, m0}
xv += 2
}
# Then the rest in groups of 4
mh := mgo
@for (l/4) {
{l0, ...ls} := each{load{xv,.}, iota{4}}
r4 := fold{homBlend, l0, ls, m4s}
r = homBlend{r, r4, mh < V**4}
mh -= V**4; xv += 4
}
def write{r} = {
store{rv, 0, shuf{[16]u8, r, si}}
++rv
}
output{r, i, write}
}
}
# Handle odd and even strides separately
if (p2 == 0) {
loop{{r, i, write} => write{swap_ms{r}}}
} else {
# Store results
def add_res = {
ra := V**0 # Accumulator
i := make{V, iota{vl}%16}
o := V**(u8~~1<<el)
bl := i & elo < o
sh := i - o
{r} => { ra = homBlend{shuf{[16]u8, ra, sh}, r, bl} }
}
il := make{V, iota{vl} % 16}
# Shuffle to undo interleaving of add_res
def __shr{x:(V), sh if hasarch{'X86_64'}} = V~~(H~~x >> sh)
def __shl{x:(V), sh if hasarch{'X86_64'}} = V~~(H~~x << sh)
def __shr{x:(V), sh:T if hasarch{'AARCH64'} and not isvec{T}} = x << V**cast_i{u8,-sh}
def __shl{x:(V), sh:T if hasarch{'AARCH64'} and not isvec{T}} = x << V**cast_i{u8, sh}
def uz_lane = {
l := V**(u8~~1<<el - 1)
h := V**(16 - u8~~16>>p2)
dz := (il & l) | (h &~ il)>>(4-e) # low, high->middle
dz |= (il &~ (l | h))<<p2 # middle->high
shuf{[16]u8, ., dz}
}
# Adjust modular permutation to apply after unzipping
si = uz_lane{(si & V**(vl - u8~~1<<e)) >> p2}
si ^= il &~ V**((16 - u8~~1<<e) >> p2)
# Cross-lane follow-up
def cross = if (vl == 16) { {x}=>x } else {
si = shuf{[4]u64, si, 0,1,0,1}
assert{p2 <= 2}
cr := make{[8]u32, tr_iota{0,2,1}}
if (p2 > 1) cr = make{[8]u32, tr_iota{2,0,1}}
shuf{[8]u32, ., cr}
}
# Run, writing every 1<<p2 steps
plo := usz~~1<<p2 - 1
loop{{r, i, write} => {
def ra = add_res{r}
if ((plo &~ i) == 0) write{cross{uz_lane{ra}}}
}}
}
}
# Select one element out of every l, element width 1<<el bytes
# Maximum of n result values, return actual number written
fn extract_column(x0:*void, r0:*void, n:usz, l:usz, el:u8) : usz = {
n <<= el
def vl = arch_defvw / 8
def thr = min{vl+2, 20}
if ((not has_simd) or n < vl or l > usz~~thr>>el or l<<el >= thr) return{0}
nv := n / vl
if (has_simd and (l & (l-1)) == 0) {
def try_unzip{T, k} = if (k < thr and l == k) {
extract_column_pow2{T, x0, r0, nv, k}
goto{'ret'}
}
# 10 loops: i8 2,4,8,16; i16 2,4,8; i32 2,4; i64 2
@unroll (ek to 4) if (el == ek) {
def T = ty_u{8<<ek}
@unroll (p from 1 to 5-ek) try_unzip{T, 1<<p}
}
return{0}
setlabel{'ret'}
} else if (hasarch{'SSE4.1'} or hasarch{'AARCH64'}) {
extract_column_modperm{x0, r0, nv, l, el, vl}
} else return{0}
(usz~~vl>>el) * nv
}
# Short-row boolean folds: main challenge is bit packing
def fold_rows_bit_lt64{
op, run_loop2, run_loop4, pext_res, mult_in,
@ -184,7 +367,7 @@ def fold_rows_bit_lt64{
}
}
fn select_rows_bit_lt64(xp:*u64, rp:*u64, n:usz, l:usz, o:usz) : void = {
fn extract_column_bit_lt64(xp:*u64, rp:*u64, n:usz, l:usz, o:usz) : void = {
assert{l < 64}; assert{o < l} # Row length, and offset within row
def run_loop2{loop} = loop{{a,b} => a>>o}
def run_loop4{m, t, loop} = loop{{x} => x<<(l-1-o)}
@ -329,4 +512,5 @@ fn or_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, op_and:u1) : void = {
}
export{'si_xor_rows_bit', xor_rows_bit}
export{'si_or_rows_bit', or_rows_bit}
export{'si_select_cells_bit_lt64', select_rows_bit_lt64}
export{'si_select_cells_bit_lt64', extract_column_bit_lt64}
export{'si_select_cells_byte', extract_column}

View File

@ -693,7 +693,7 @@ def rep_const_bool_small_odd{W=[wl](u64), max_wv, wv, get_perm_x, output} = {
def fixed_loop{k} = {
assert{wv == k}
while (1) {
# e.g. 01234567 to 05316427 on each byte for k==3, ew==8
# e.g. 01234567 to 03614725 on each byte for k==3, ew==8
xv := get_perm_x{}
# Overhang from previous 64-bit elements
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word

View File

@ -74,16 +74,7 @@ export{'si_scan_min_i16', scan_idem_id{i16, min}}; export{'si_scan_max_i16', sca
export{'si_scan_min_i32', scan_idem_id{i32, min}}; export{'si_scan_max_i32', scan_idem_id{i32, max}}
# Assumes identity is 0
def scan_assoc{op} = {
def shl0{v:[_]T, k} = vec_shift_right_128{v, k/width{T}} # Lanewise
def shl0{v:V, k==128 if hasarch{'AVX2'}} = {
# Broadcast end of lane 0 to entire lane 1
l:= V~~make{[8]i32,0,0,0,-1,0,0,0,0} & spread{v}
sel{[8]i32, l, make{[8]i32, 3*(3<iota{8})}}
}
prefix_byshift{op, shl0}
}
def scan_plus = scan_assoc{+}
def scan_plus = scan_assoc_id0{+}
# Associative scan
def scan_assoc_0 = scan_scal

View File

@ -66,3 +66,13 @@ def make_scan_idem{(f64), op, up} = {
sc
}
def make_scan_idem{T, op} = make_scan_idem{T, op, 1}
def scan_assoc_id0{op} = {
def shl0{v:[_]T, k} = vec_shift_right_128{v, k/width{T}} # Lanewise
def shl0{v:V, k==128 if hasarch{'AVX2'}} = {
# Broadcast end of lane 0 to entire lane 1
l:= V~~make{[8]i32,0,0,0,-1,0,0,0,0} & spread{v}
sel{[8]i32, l, make{[8]i32, 3*(3<iota{8})}}
}
prefix_byshift{op, shl0}
}

View File

@ -4,6 +4,8 @@ include './f64'
include './mask'
include './bitops'
def avx2 = hasarch{'AVX2'}
# Group l (power of 2) elements into paired groups of length o
# e.g. pairs{2, iota{8}} = {{0,1,4,5}, {2,3,6,7}}
def pairs{o, x} = {
@ -22,31 +24,82 @@ def permute_pass{o, x} = {
merge{h{0,2}, h{1,3}}
}
def unpack_to{f, l, x} = {
def pass = if (f) permute_pass else unpack_pass
def pass = if (avx2 and f) permute_pass else unpack_pass
pass{l, if (l==1) x else unpack_to{0, l/2, x}}
}
# Last pass for square kernel packed in halves
def shuf_pass{x} = each{{v} => shuf{[4]i64, v, 0,2,1,3}, x}
def halved_pass{n, x} = {
if (not avx2) unpack_pass{n/2, x}
else each{{v} => shuf{[4]i64, v, 0,2,1,3}, x}
}
# Square kernel where width is a full vector
def transpose_square{VT, l, x if hasarch{'AVX2'}} = unpack_to{1, l/2, x}
def transpose_square{VT, l, x if avx2} = unpack_to{1, l/2, x}
def load2{a:*T, b:*T} = pair{load{a}, load{b}}
def store2{a:*T, b:*T, v:T2 if w128i{T} and w256{T2}} = {
each{{p, i} => store{p, 0, T~~half{v,i}}, tup{a,b}, iota{2}}
def load2{a:*T, b:*T} = match (width{T}) {
{64} => {
def v = each{{p}=>loadLow{*[2]u64~~p, 64}, tup{a,b}}
n_d{T}~~zip{...v, 0}
}
{128} => pair{load{a}, load{b}}
}
def load_k {VT, src, l, w if w256{VT}} = each{{i} =>load {*VT~~(src+i*w), 0 }, iota{l}}
def store_k{VT, dst, x, l, h if w256{VT}} = each{{i,v}=>store{*VT~~(dst+i*h), 0, VT~~v}, iota{l}, x}
def load_k {VT, src, l, w if w128{VT}} = each{{i} =>{p:=src+ i*w; load2 {*VT~~p, *VT~~(p+l*w) }}, iota{l}}
def store_k{VT, dst, x, l, h if w128{VT}} = each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+ h), v}}, iota{l}, x}
def store2{a:*T, b:*T, v:T2 if 2*width{T} == width{T2}} = match (width{T}) {
{ 64} => each{{p, v} => storeLow{*u64~~p, 64, [2]u64~~v}, tup{a,b}, tup{v, shuf{u64, v, 1,0}}}
{128} => each{{p, i} => store{p, 0, T~~half{v,i}}, tup{a,b}, iota{2}}
}
def store1of2{a:*T, v:T2 if 2*width{T} == width{T2}} = match (width{T}) {
{ 64} => storeLow{*u64~~a, 64, [2]u64~~v}
{128} => store{a, 0, T~~half{v,0}}
}
def load_k {VT, src, l, w} = each{{i} =>load {*VT~~(src+i*w), 0 }, iota{l}}
def store_k{VT, dst, x, l, h} = each{{i,v}=>store{*VT~~(dst+i*h), 0, VT~~v}, iota{l}, x}
def load_k {VT, src, l, w if width{VT} < arch_defvw} = each{{i} =>{p:=src+ i*w; load2 {*VT~~p, *VT~~(p+l*w) }}, iota{l}}
def store_k{VT, dst, x, l, h if width{VT} < arch_defvw} = each{{i,v}=>{p:=dst+2*i*h; store2{*VT~~p, *VT~~(p+ h), v}}, iota{l}, x}
# Transpose kernel of size kw,kh in size w,h array
def kernel{src:*T, dst:*T, kw, kh, w, h} = {
def n = (kw*kh*width{T}) / 256 # Number of vectors
def kernel_part{part_w}{src:*T, dst:*T, kw, kh, w, h} = {
def n = (kw*kh*width{T}) / arch_defvw # Number of vectors
def xvs = load_k{[kw]T, src, n, w}
def xt = unpack_to{n==kh, n/2, xvs} # Transpose n by n
def rvs = if (n==kw) xt else shuf_pass{xt} # To kh by kh for packed square
store_k{[kh]T, dst, rvs, n, h}
def rvs = if (n==kw) xt else halved_pass{n,xt} # To kh by kh for packed square
def stores = store_k{[kh]T, ..., h}
if (same{part_w, 0}) {
stores{dst, rvs, n}
} else {
# Write w results, kw/2 <= n < kw
d := dst
def vd = kw / n # Number of writes for each output vector (1 or 2)
def store_slice{rv, len} = {
stores{d, slice{rv,0,len}, len}
d += len*vd*h
}
store_slice{rvs, n/2} # Unconditionally store first half
rt := slice{rvs,n/2} # Remaining tail
def wtail{b} = {
if ((part_w & (vd*b)) != 0) {
store_slice{rt, b}
slice{rt,0,b} = slice{rt,b,2*b}
}
if (b>1) wtail{b/2}
}
wtail{n/4}
if (vd>1 and (part_w & 1) != 0) store1of2{*[kh]T~~d, select{rt,0}}
}
}
def kernel = kernel_part{0}
def kernel_part_h{part_h}{src:*T==i8, dst:*T, kw==16, kh==16, w, h} = {
def n = (kw*kh*width{T}) / arch_defvw
def VT = [kw]T
off := part_h - kh/2
def xvs = @unroll (i to n) { s := src + i*w; load2{*VT~~s, *VT~~(s+off*w)} }
def rvs = halved_pass{n, unpack_to{0, n/2, xvs}}
@unroll (j to 2) {
def is = 2*iota{2} + j
d := dst + j*off
def store_q{v, i} = { store{*u64~~d, 0, extract{v,i}}; d += h }
each{{r} => each{store_q{[4]u64~~r,.}, is}, rvs}
}
}
@ -153,37 +206,264 @@ def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h, ws, hs} = {
}
}
# Interleave n values of type T from x0 and x1 into r
# Unzip 2*n values
def uninterleave{r0:*T, r1:*T, xp:*T, n} = {
@for (r0, r1 over i to n) {
r0 = load{xp, i*2}; r1 = load{xp, i*2+1}
}
}
# Zip n values of type T from each of x0 and x1 into r
fn interleave{T}(r0:*void, x0:*void, x1:*void, n:u64) : void = {
rp := *T~~r0
@for (x0 in *T~~x0, x1 in *T~~x1 over i to n) {
store{rp, i*2, x0}; store{rp, i*2+1, x1}
}
}
fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
# Scalar transpose defined in C
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
def call_base{...a} = emit{void, merge{'transpose_',ts}, ...a, ws, hs}
rp:*T = *T~~r0
xp:*T = *T~~x0
if (hasarch{'AVX2'} and w>=k and h>=k) {
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
} else {
if (h==2 and h==hs) interleave{T}(r0, x0, *void~~(xp+ws), w)
else if (w==2 and w==ws) @for (r0 in rp, r1 in rp+hs over i to h) { r0 = load{xp, i*2}; r1 = load{xp, i*2+1} }
else call_base{rp, xp, w, h}
# SIMD implementations
def uninterleave{r0:*T, r1:*T, xp:*T, n if has_simd and (not hasarch{'X86_64'} or width{T}>=32 or hasarch{'SSSE3'})} = {
def l = arch_defvw / width{T}
def V = [l]T
rv0 := *V~~r0; rv1 := *V~~r1; xv := *V~~xp
nv := n / l
def uz = if (not hasarch{'X86_64'}) unzip else ({...xs} => {
def reinterpret{V, xs if ktup{xs}} = each{~~{V,.}, xs}
def q = tr_quads{arch_defvw/128}
def k = flat_table{+, iota{2}, 2 * iota{64 / width{T}}}
def px = each{shuf{., k}, xs}
V~~each{q, zip128{...re_el{u64,V}~~px}}
})
@for (r0 in rv0, r1 in rv1 over i to nv) {
tup{r0, r1} = uz{...each{load{xv+2*i, .}, iota{2}}}
}
if (n % l > 0) {
xb := xv + 2*nv
x0 := load{xb}
x1 := V**0; if (n&(l/2) != 0) x1 = load{xb, 1}
mask := maskOf{V, n%l}
each{homMaskStoreF{., mask, .}, tup{rv0+nv,rv1+nv}, uz{x0,x1}}
}
}
fn interleave{T if has_simd}(r:*void, x0:*void, x1:*void, n:u64) : void = {
def l = arch_defvw / width{T}
def V = [l]T
xv0 := *V~~x0; xv1 := *V~~x1; rv := *V~~r
nv := n / l
def q = tr_quads{arch_defvw/128}
@for (x0 in xv0, x1 in xv1 over i to nv) {
each{store{rv+2*i, ., .}, iota{2}, zip128{q{x0},q{x1}}}
}
if (n % l > 0) {
def xs = each{load{.,nv}, tup{xv0,xv1}}
def get_r = zip128{...each{q, xs}, .}
r0 := get_r{0}
rb := rv + 2*nv
nr := 2*n; m := nr%l
mask := maskOf{V, m}
if (nr&l == 0) {
homMaskStoreF{rb, mask, r0}
} else {
store{rb, 0, r0}
if (m > 0) homMaskStoreF{rb+1, mask, get_r{1}}
}
}
}
def transpose{T, k} = transpose{T, k, k}
# Utilities for kernels based on modular permutation
def rotcol{xs, mg:I} = {
def w = length{xs}
@unroll (kl to ceil_log2{w}) { def k = 1<<kl
def vk = I**k; def m = (mg & vk) == vk
def bl{x,y} = { x = homBlend{x,y,m} }
x0 := select{xs, 0}
def xord = select{xs, k*iota{w} % w}
each{bl, xord, shiftleft{xord, tup{x0}}}
}
}
def get_modperm_lane_shuf{c} = {
def cross{s, i} = homBlend{s, shuf{[4]u64, s, 2,3,0,1}, i&c == c}
{x, i} => cross{shuf{16, x, i}, i}
}
def tr_quads = match { {1}=>({x}=>x); {2}=>shuf{[4]u64, ., 0,2,1,3} }
exportT{'simd_transpose', tup{
transpose{i8 , 16},
transpose{i16, 8, 16},
transpose{i32, 8},
transpose{i64, 4}
}}
def loop_fixed_height{xp, rp, w, k, st, kern} = {
@for_mult_max{k, w-k} (i to w+(-w)%k) kern{xp+i, rp+i*st}
}
def loop_fixed_width{xp, rp, h, k, st, kern} = {
@for_mult_max{k, h-k} (i to h+(-h)%k) kern{xp+i*st, rp+i}
}
exportT{'interleave_fns', each{interleave, tup{i8, i16, i32, i64}}}
# Transpose a contiguous kernel of width w*p from x to r with stride rst
def modular_kernel{w,p if w%2==1 and 2%p==0}{xp:*T, rp0:*T, rst:(u64)} = {
def h = arch_defvw / 8
def ih = iota{h}; def iw = iota{w}
def I = [h]u8; def V = I
def e = width{T} / 8
# Load a shape h,w slice of x, but consider as shape w,h
def xsp = each{load{*V~~xp, .}, iw}
# Modular permutation of (reshaped argument) columns
xs := select{xsp, find_index{h/e*iw % w, iw}}
# Rotate each column by its index
rotcol{reverse{xs}, make{I, (ih // e) % w}}
# Modular permutation of rows, and write to result
rp := rp0
mp := make{I, ih%e + (p*(ih - ih%e) + (ih//(h/p))*e)%16*w%h}
mi := I**e
def perm = if (h==16) shuf else {
def sh = get_modperm_lane_shuf{I**16}
def q = tr_quads{p}; if (p>1) mp = q{mp}
{x, i} => q{sh{x, i}}
}
def perm_store{x} = {
match (p) {
{1} => store{*V~~rp, 0, perm{x, mp}}
{2} => { def U = [h/2]u8; store2{*U~~rp, *U~~(rp+w*rst), perm{x, mp}} }
}
rp += rst; mp += mi
if (hasarch{'AARCH64'}) mp &= I**15 # Implicit on x86, value stays below h+w
}
each{perm_store, xs}
}
def modular_kernel{2,2}{xp:*T, rp0:*T, rst:(u64)} = {
def h = arch_defvw / 8
def V = [h]u8; def U = n_h{V}
def ih = iota{h}%16; def e = width{T} / 8
# Permutation to unzip by 4 within each lane
uz := make{V, (4*(ih//e) + ih//(16/4))*e%16 + ih%e}
# Unzipping code for the resulting 4-byte units
def {st, proc, zipx} = match (h) {
{16} => tup{2, {x} => [4]f32~~x, {xs,i} => V~~zip128{...xs,i}}
{32} => tup{1, shuf{[4]u64, ., 0,2,1,3}, {xs,i} => shuf{[4]f32, xs, i + tup{0,2,0,2}}}
}
def xsp = each{load{*V~~xp, .}, iota{2}}
xs := each{proc, each{shuf{16, ., uz}, xsp}}
@unroll (i to 2) {
rp := rp0 + st*i*rst
store2{*U~~rp, *U~~(rp + (2/st)*rst), zipx{xs,i}}
}
}
def transpose_fixed_width{rp:*T, xp:*T, wk, h, hs} = {
def p = if (wk%2) 1 else 2; def w = wk/p
def vl = arch_defvw / (p*width{T})
loop_fixed_width{xp, rp, h, vl, wk, modular_kernel{w,p}{., ., hs}}
}
def transpose_fixed_width{rp:*T, xp:*T, 2, h, hs} = {
uninterleave{rp, rp+hs, xp, h}
}
# Transpose a kernel of height w*p from x with stride xst to contiguous r
# w and h are named for the result, not argument, to match modular_kernel
def modular_kernel_rev{w,p if w%2==1 and 2%p==0}{xp0:*T, rp:*T, xst:(u64)} = {
def h = arch_defvw / 8
def ih = iota{h}; def iw = iota{w}
def I = [h]u8; def V = I
def e = width{T} / 8
# Read rows, modular permutation on each
def rotbit{x, l,m,h} = x%l + (x-x%l)*(h/m)%h + x//m*l
def wi = w + 2 * ((w-1) + (w&2)) # Inverse mod 32
def mpd = rotbit{ih%e + (ih - ih%e)%16*wi%h, e,e*p,h}
mp := make{I, if (h==16 or p==1) mpd else rotbit{mpd, 8,16,h}}
def rot_mp = {
def rot_lane = shuf{16, ., make{I, (ih-e)%16}}
def cross = if (h==16) ({x}=>x) else ^{make{I, 16*(ih%16<e)}, .}
{mp} => cross{rot_lane{mp}}
}
def perm = if (h==16) shuf else {
def sh = get_modperm_lane_shuf{I**16}
def q = tr_quads{p}
{x, i} => sh{q{x}, i}
}
xp := xp0
xs := @collect (w) {
x := match (p) {
{1} => perm{load{*V~~xp, 0}, mp}
{2} => { def U = [h/2]u8; perm{load2{*U~~xp, *U~~(xp+w*xst)}, mp} }
}
xp += xst; mp = rot_mp{mp}
x
}
# Rotate each column by its index
rotcol{xs, make{I, (ih // e) % w}}
# Permute vectors and store
each{store{*V~~rp, ., .}, iw, select{xs, h/e*iw % w}}
}
def modular_kernel_rev{2,2}{xp0:*T, rp:*T, xst:(u64)} = {
def V = [arch_defvw / width{T}]T; def U = n_h{V}
xl := @unroll (i to 2) {
xp := xp0 + i*xst
x := load2{*U~~xp, *U~~(xp + 2*xst)}
if (arch_defvw==128) x else shuf{[8]u32, x, tr_iota{1,2,0}}
}
xs := unpack_typed{...unpack_typed{...xl}}
each{store{*V~~rp, ., .}, iota{2}, each{~~{V,.},xs}}
}
def transpose_fixed_height{rp:*T, xp:*T, w, ws, hk} = {
def p = if (hk%2) 1 else 2; def h = hk/p
def vl = arch_defvw / (p*width{T})
loop_fixed_height{xp, rp, w, vl, hk, modular_kernel_rev{h,p}{., ., ws}}
}
def transpose_fixed_height{rp:*T, xp:*T, w, ws, 2} = {
interleave{T}(*void~~rp, *void~~xp, *void~~(xp+ws), w)
}
fn transpose{T, {k, kh}}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
rp:*T = *T~~r0
xp:*T = *T~~x0
def wT = width{T}
def vl = arch_defvw / wT
# Transposes with code dedicated to a particular width or height
def try_fixed_dim{tr, l, lst, nl, l_max} = {
def incl{l} = if (k>4) 1 else l!=4
if (l<l_max and incl{l} and l==lst) {
if (l == 2) { tr{2}; return{} }
def has_blend = hasarch{'SSE4.1'} or hasarch{'AARCH64'}
if (has_blend and nl>=vl/2 and (l%2==0 or nl>=vl)) {
def try{ls} = {
def i = length{ls}>>1
if (l < select{ls,i}) try{slice{ls, 0,i}} else try{slice{ls, i}}
}
def try{{lk}} = tr{lk}
try{replicate{{i} => i<l_max and incl{i}, slice{iota{8},3}}}
return{}
}
}
}
# Small width: fixed, or over-reading partial kernel
def use_part_w = wT<=16 and k>max{4,8/wT}
def w_max = 1 + (if (use_part_w) max{4, k/2-1} else 7)
if (has_simd and w < max{w_max, k}) {
try_fixed_dim{transpose_fixed_width {rp, xp, ., h, hs}, w, ws, h, w_max}
if (use_part_w and h>=kh and w>=k/2 and w<k) {
loop_fixed_width{xp, rp, h, kh, ws, kernel_part{w}{., ., k, kh, ws, hs}}
return{}
}
}
# Small height: fixed, or kernel with overlapping writes
# Overlapping is slower than over-reading, so it's only used when needed
if (has_simd and h < max{8, k}) {
try_fixed_dim{transpose_fixed_height{rp, xp, w, ws, .}, h, hs, w, 8}
if (k>8 and w>=k and h>=kh/2) {
loop_fixed_height{xp, rp, w, k, hs, kernel_part_h{h}{., ., k, kh, ws, hs}}
return{}
}
}
# Scalar transpose defined in C
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
def call_base{...a} = emit{void, merge{'transpose_',ts}, ...a, ws, hs}
# Full kernels
# May have w<k or h<k if not has_blend, or >2D transpose with w!=ws or h!=hs
if (has_simd and k!=0 and w>=k and h>=k) {
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
} else {
call_base{rp, xp, w, h}
}
}
def transpose{T, k if knum{k}} = transpose{T, tup{k, k}}
def tr_types = tup{i8, i16, i32, i64}
def tr_kernels = if (not avx2) tup{ 8, 8, 4, 0 }
else tup{16, tup{8, 16}, 8, 4 }
exportT{'simd_transpose', each{transpose, tr_types, tr_kernels}}
exportT{'interleave_fns', each{interleave, tr_types}}