uCBQN/src/builtins/transpose.c
2024-09-13 19:39:08 +03:00

547 lines
17 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Transpose and Reorder Axes (⍉)
// Transpose
// One length-2 axis: dedicated code
// Boolean: pdep or emulation for height 2; pext for width 2
// SHOULD use a generic implementation if BMI2 not present
// SHOULD optimize other short lengths with pdep/pext and shuffles
// Boolean 𝕩: convert to integer
// SHOULD have bit matrix transpose kernel
// CPU sizes: native or SIMD code
// Large SIMD kernels used when they fit, overlapping for odd sizes
// i8: 16×16; i16: 16×8; i32: 8×8; f64: 4×4
// COULD use half-width or smaller kernels to improve odd sizes
// Scalar transpose or loop used for overhang of 1
// SHOULD add NEON
// Reorder Axes
// If 𝕨 indicates the identity permutation, return 𝕩
// Simplify: remove length-1 axes; coalesce adjacent and trailing axes
// Empty result or trivial reordering: reshape 𝕩
// Large cells: slow outer loop plus mut_copy
// CPU-sized cells, large last 𝕩 and result axes: strided 2D transposes
// Otherwise, generate indices and select with +⌜ and ⊏
// SHOULD generate for a cell and virtualize the rest to save space
// COULD decompose axis permutations to use 2D transpose when possible
// COULD convert boolean to integer for some axis reorderings
// SHOULD have a small-subarray transposer using one or a few shuffles
// ⍉⁼𝕩: data movement of ⍉ with different shape logic
// 𝕨⍉⁼𝕩: compute inverse 𝕨, length 1+⌈´𝕨
// Under Transpose supports invertible cases
// SHOULD implement Under with duplicate axes, maybe as Under Select
// ⍉˘𝕩 and k⍉˘𝕩 for number k: convert to 0‿a⍉𝕩
// SHOULD convert ⍉ with rank to a Reorder Axes call
// COULD implement fast ⍉⍟n
#include "../core.h"
#include "../utils/mut.h"
#include "../utils/talloc.h"
#include "../builtins.h"
#include "../utils/calls.h"
#ifdef __BMI2__
#if !SLOW_PDEP
#define FAST_PDEP 1
#endif
#include <immintrin.h>
#if USE_VALGRIND
#define _pdep_u64 vg_pdep_u64
#endif
#endif
// DECL_BASE(T) defines transpose_T, the portable transpose kernel for elements
// of type T. It copies a block bw elements wide and bh tall: the source element
// in row y, column x (read from xp[y*w+x]) is stored at rp[x*h+y], so w and h
// are the row strides of the source and of the result, in elements. The longer
// side of the block is always the inner loop. The PLAINLOOP/NOVECTORIZE
// annotations are kept as-is (presumably compiler loop hints).
#define DECL_BASE(T) \
static NOINLINE void transpose_##T(void* rv, void* xv, u64 bw, u64 bh, u64 w, u64 h) { \
  T* dst = rv; T* src = xv; \
  if (bw > bh) { \
    PLAINLOOP for (ux row=0; row<bh; row++) \
      NOVECTORIZE for (ux col=0; col<bw; col++) dst[col*h+row] = src[row*w+col]; \
  } else { \
    PLAINLOOP for (ux col=0; col<bw; col++) \
      NOVECTORIZE for (ux row=0; row<bh; row++) dst[col*h+row] = src[row*w+col]; \
  } \
}
DECL_BASE(i8) DECL_BASE(i16) DECL_BASE(i32) DECL_BASE(i64)
#undef DECL_BASE
// Signature shared by all 2D transpose kernels:
// (result, source, block width, block height, source row stride, result row stride),
// matching transpose_T(rv, xv, bw, bh, w, h) above.
typedef void (*TranspFn)(void*,void*,u64,u64,u64,u64);
#if SINGELI
// Use the Singeli-generated SIMD kernels instead of the scalar ones. The
// included file presumably also provides interleave_fns, which
// interleave_cells uses under SINGELI — TODO confirm.
#define transposeFns simd_transpose
#define SINGELI_FILE transpose
#include "../utils/includeSingeli.h"
#else
// Portable fallback, indexed by log2 of the element width in bytes:
// 1, 2, 4 and 8-byte elements.
static TranspFn const transposeFns[] = {
transpose_i8, transpose_i16, transpose_i32, transpose_i64
};
#endif
// Interleave two bit sequences into rp: result bit 2k is bit k of x0v and
// result bit 2k+1 is bit k of x1v, for n result bits in total. Both inputs are
// read as whole 32-bit words (BIT_N(n) of each), so any bits of the last result
// word at positions >= n are derived from whatever input bits lie past n/2.
static void interleave_bits(u64* rp, void* x0v, void* x1v, usz n) {
  u32* evn = (u32*)x0v;
  u32* odd = (u32*)x1v;
  usz words = BIT_N(n);
  for (usz k = 0; k < words; k++) {
    #if FAST_PDEP
      // BMI2: deposit each 32-bit word into the even / odd bit positions
      rp[k] = _pdep_u64(evn[k], 0x5555555555555555) | _pdep_u64(odd[k], 0xAAAAAAAAAAAAAAAA);
    #else
      // Portable: spread bit i of each 32-bit word to bit 2i by repeatedly
      // splitting groups in half and shifting the upper half left
      u64 a = evn[k];
      u64 b = odd[k];
      #define SPREAD_STEP(SH, M) a = (a | a<<SH) & M; b = (b | b<<SH) & M;
      SPREAD_STEP(16, 0x0000ffff0000ffff)
      SPREAD_STEP( 8, 0x00ff00ff00ff00ff)
      SPREAD_STEP( 4, 0x0f0f0f0f0f0f0f0f)
      SPREAD_STEP( 2, 0x3333333333333333)
      SPREAD_STEP( 1, 0x5555555555555555)
      #undef SPREAD_STEP
      rp[k] = a | b<<1;
    #endif
  }
}
// Return x itself when arr_bptr already yields a direct B* element pointer for
// it; otherwise return a freshly copied HArr. Consumes x.
B toBPtrAny(B x) {
  return arr_bptr(x)==NULL? taga(cpyHArr(x)) : x;
}
// Convert x to an array whose element type is re (an el_* code) by dispatching
// to the matching to*Any converter. Consumes x. For re==el_B the result is
// guaranteed to work with TO_BPTR. Any other value of re is undefined (UD).
NOINLINE
B toElTypeArr(u8 re, B x) {
  switch (re) {
    case el_bit: return toBitAny(x);
    case el_i8:  return toI8Any(x);
    case el_i16: return toI16Any(x);
    case el_i32: return toI32Any(x);
    case el_f64: return toF64Any(x);
    case el_c8:  return toC8Any(x);
    case el_c16: return toC16Any(x);
    case el_c32: return toC32Any(x);
    case el_B:   return toBPtrAny(x);
    default: UD;
  }
}
// interleave arrays, 𝕨≍⎉(-xk)𝕩
// Consumes w and x. xsh is the shape of x and both arguments have rank xr (>=1).
// The first xk axes give n = ×´xk↑xsh cells of csz = ×´xk↓xsh elements each.
// The result has shape (xk↑xsh)∾2∾(xk↓xsh): each cell of w is followed by the
// matching cell of x. NOTE(review): w's shape is assumed to equal xsh (only its
// rank is asserted here) — confirm callers guarantee this.
B interleave_cells(B w, B x, ur xr, ur xk, usz* xsh) { // consumes w,x
assert(RNK(w)==xr && xr>=1);
u8 we = TI(w,elType);
u8 xe = TI(x,elType);
// Element type of the result
u8 re = we==xe? we : el_or(we, xe);
// Entry point for branches that need both arguments in type re: convert them
// and start over, passing the (possibly new) shape pointer of the converted x
if (0) { goto to_equal_types; to_equal_types:;
// delay doing this until it's known that there will be code that can utilize it
if (re!=we) w = toElTypeArr(re, w);
if (re!=xe) x = toElTypeArr(re, x);
return interleave_cells(w, x, xr, xk, SH(x));
}
Arr *r;
// log2 of the result element width in bits
u8 xlw = elwBitLog(re);
// Number of cells per argument
usz n = shProd(xsh, 0, xk);
// Elements per cell
usz csz = shProd(xsh, xk, xr);
// Elements in the result
usz ia = 2*n*csz;
if (csz & (csz-1)) {
// Cell size not a power of two: no specialized kernel
goto generic;
} else if (csz==1 && xlw==0) { // we & xe are trivially el_bit
// Single-bit cells: plain bit interleave
u64* rp; r=m_bitarrp(&rp, ia);
interleave_bits(rp, bitany_ptr(w), bitany_ptr(x), ia);
#if SINGELI
} else if (csz==1 && re==el_B) {
// Single generic elements: interleave the B values with the 64-bit-cell kernel
// (index 3 = CTZ(64)-3), then increment each since both arrays now share them
if (we!=xe) goto to_equal_types;
B* wp = TO_BPTR(w); B* xp = TO_BPTR_RUN(x, xsh = SH(x));
HArr_p p = m_harrUv(ia); // Debug build complains with harrUp
interleave_fns[3](p.a, wp, xp, n);
for (usz i=0; i<ia; i++) inc(p.a[i]);
NOGC_E;
r = (Arr*) p.c;
goto add_fill;
} else if (csz<=64>>xlw && csz<<xlw>=8) { // Require CPU-sized cells
// Cells of 8 to 64 bits: interleave whole cells with the kernel for that width
if (we!=xe) goto to_equal_types;
assert(re!=el_B);
void* rv;
if (xlw==0) { u64* rp; r = m_bitarrp(&rp, ia); rv=rp; }
else rv = m_tyarrp(&r,elWidth(re),ia,el2t(re));
interleave_fns[CTZ(csz<<xlw)-3](rv, tyany_ptr(w), tyany_ptr(x), n);
#endif
} else { generic:;
// Fallback: copy cells alternately from w and x into a mutable array of type re
MAKE_MUT_INIT(rm, ia, re); MUTG_INIT(rm);
for (ux o = 0; o < n*csz; o+= csz) {
mut_copyG(rm, o*2, w, o, csz);
mut_copyG(rm, o*2+csz, x, o, csz);
}
r = a(mut_fv(rm));
goto add_fill;
}
// Reached only from the generic and el_B paths: attach the combined fill
if (0) { add_fill:;
if (SFNS_FILLS) r = a(qWithFill(taga(r), fill_both(w, x)));
}
// Result shape: xk↑xsh, then 2, then xk↓xsh
usz* sh = arr_shAlloc(r, xr+1);
shcpy(sh, xsh, xk); sh[xk]=2; shcpy(sh+xk+1, xsh+xk, xr-xk);
decG(w); decG(x);
return taga(r);
}
// Transpose the contiguous data at xv, viewed as h rows of w elements, into rv
// as w rows of h elements, using the kernel for xe's element width. xe must be
// a fixed-width type: not boolean and not generic (callers move B values by
// passing el_f64).
static void transpose_move(void* rv, void* xv, u8 xe, usz w, usz h) {
  assert(xe!=el_bit);
  assert(xe!=el_B);
  TranspFn kernel = transposeFns[elwByteLog(xe)];
  // Whole-matrix transpose: block size equals the row strides
  kernel(rv, xv, w, h, w, h);
}
// Return an array with data from x transposed as though it's shape h,w
// Shape of result needs to be set afterwards!
// *px holds x on entry and is updated when x is converted (generic or
// boolean-to-i8 cases); the caller still owns *px and must release it.
// Callers pass ia = w*h.
static Arr* transpose_noshape(B* px, usz ia, usz w, usz h) {
B x = *px;
u8 xe = TI(x,elType);
Arr* r;
if (xe==el_B) {
// Generic elements: move the B values with the 8-byte (f64) kernel and keep
// x's fill
B xf = getFillR(x);
B* xp = TO_BPTR_RUN(x, *px = x);
HArr_p p = m_harrUv(ia); // Debug build complains with harrUp
transpose_move(p.a, xp, el_f64, w, h);
for (usz xi=0; xi<ia; xi++) inc(p.a[xi]); // TODO don't inc when there's a method of freeing a HArr without freeing its elements
NOGC_E;
r=a(qWithFill(p.b, xf));
} else if (xe==el_bit) {
if (h==2) {
// Two rows: interleave row 0 (from the start of x) with row 1 (a slice
// starting at bit w)
u64* rp; r=m_bitarrp(&rp, ia);
Arr* x1o = TI(x,slice)(incG(x),w,w);
interleave_bits(rp, bitany_ptr(x), bitanyv_ptr(x1o), ia);
mm_free((Value*)x1o);
#ifdef __BMI2__
} else if (w==2) {
// Two columns: even bits form result row 0 and odd bits row 1; each 64-bit
// source word yields 32 bits of each row. Row 1 is gathered in a temporary
// buffer and then copied in at bit offset h.
u64* xp = bitany_ptr(x);
u64* r0; r=m_bitarrp(&r0, ia);
TALLOC(u64, r1, BIT_N(h));
for (usz i=0; i<BIT_N(ia); i++) {
u64 v = xp[i];
((u32*)r0)[i] = _pext_u64(v, 0x5555555555555555);
((u32*)r1)[i] = _pext_u64(v, 0xAAAAAAAAAAAAAAAA);
}
bit_cpyN(r0, h, r1, 0, h);
TFREE(r1);
#endif
} else {
// Other boolean shapes: widen to i8, transpose bytes, then pack back to bits
*px = x = taga(cpyI8Arr(x)); xe=el_i8;
void* rv = m_tyarrp(&r,elWidth(xe),ia,el2t(xe));
void* xv = tyany_ptr(x);
transpose_move(rv, xv, xe, w, h);
r = (Arr*)cpyBitArr(taga(r));
}
} else {
// Fixed-width numbers and characters: transpose directly
void* rv = m_tyarrp(&r,elWidth(xe),ia,el2t(xe));
void* xv = tyany_ptr(x);
transpose_move(rv, xv, xe, w, h);
}
return r;
}
// Monadic ⍉: move the first axis of 𝕩 to the end, e.g. shape a‿b‿c becomes
// b‿c‿a. An atom becomes a unit array; arrays of rank 0 or 1 are returned
// unchanged. Consumes x.
B transp_c1(B t, B x) {
if (RARE(isAtm(x))) return m_unit(x);
ur xr = RNK(x);
if (xr<=1) return x;
usz ia = IA(x);
usz* xsh = SH(x);
usz h = xsh[0];
// Empty, leading axis of length 1, or all other axes of length 1: element
// order is unchanged, so only the shape is rewritten.
// NOTE(review): xsh is still read after cpyWithShape(x); this relies on r
// keeping x's shape object until arr_shReplace — confirm.
if (ia==0 || h==1 || h==ia /*w==1*/) {
Arr* r = cpyWithShape(x);
ShArr* sh = m_shArr(xr);
shcpy(sh->a, xsh+1, xr-1);
sh->a[xr-1] = h;
arr_shReplace(r, xr, sh);
return taga(r);
}
// View 𝕩 as h rows of w elements (w is the product of the trailing axes)
usz w = xsh[1] * shProd(xsh, 2, xr);
Arr* r = transpose_noshape(&x, ia, w, h);
// x may have been replaced by a converted copy, so its shape is re-read via SH(x)
usz* rsh = arr_shAlloc(r, xr);
if (xr==2) rsh[0] = w; else shcpy(rsh, SH(x)+1, xr-1);
rsh[xr-1] = h;
decG(x); return taga(r);
}
// Builtins implemented in other files that the index-and-select fallback of
// transp_c2 calls (presumably via the C1/C2/M1C2 macros).
B mul_c2(B t, B w, B x);
B ud_c1(B t, B x);
B tbl_c2(Md1D* d, B w, B x);
B select_c2(B t, B w, B x);
// Give ra the result shape. When rr <= 1, sh is not read (it may be
// uninitialized) and arr_shVec presumably makes ra a list of its elements.
// Otherwise, the rr-axis shape sh is attached with arr_shSetUO.
static void shSet(Arr* ra, ur rr, ShArr* sh) {
  if (RARE(rr <= 1)) {
    arr_shVec(ra);
    return;
  }
  arr_shSetUO(ra, rr, sh);
}
// Dyadic ⍉ (Reorder Axes): axis i of 𝕩 goes to result axis 𝕨[i]. Axes of 𝕩
// beyond ≠𝕨 fill the remaining result axes in order. Several 𝕩 axes sent to
// the same result axis produce a diagonal whose length is the minimum of
// their lengths. Consumes w and x.
B transp_c2(B t, B w, B x) {
usz wia=1;
if (isArr(w)) {
if (RNK(w)>1) thrM("⍉: 𝕨 must have rank at most 1");
wia = IA(w);
if (wia==0) { decG(w); return isArr(x)? x : m_unit(x); }
}
ur xr;
if (isAtm(x) || (xr=RNK(x))<wia) thrM("⍉: Length of 𝕨 must be at most rank of 𝕩");
// Axis permutation
// One scratch buffer holds p (xr values of type ur), then (rounded up,
// presumably to usz alignment) rsh (xr), and space for the strides st and
// loop counters ri. The extra usz absorbs the alignment padding.
TALLOC(u8, alloc, xr*(sizeof(ur) + 3*sizeof(usz)) + sizeof(usz)); // ur* p, usz* rsh, usz* st, usz* ri
ur* p = (ur*)alloc;
if (isAtm(w)) {
// Atom 𝕨: only the first axis moves, to position a
usz a=o2s(w);
if (a>=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr);
// Moving the first axis to the end is monadic ⍉
if (a==xr-1) { TFREE(alloc); return C1(transp, x); }
p[0] = a;
} else {
SGetU(w)
for (usz i=0; i<wia; i++) {
usz a=o2s(GetU(w, i));
if (a>=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr);
p[i] = a;
}
decG(w);
}
B r;
// Compute shape for the given axes
// rsh[j] starts as no_sh (unset) and becomes the minimum length of the 𝕩 axes
// sent to result axis j. dup counts 𝕨 entries repeating a result axis, id
// counts entries with 𝕨[i]=i, and max is the largest result axis named in 𝕨.
usz* xsh = SH(x);
usz* rsh = ptr_roundUpToEl((usz*)(p + xr)); // Length xr
usz dup = 0, max = 0, id = 0;
usz no_sh = -(usz)1;
for (usz j=0; j<xr; j++) rsh[j] = no_sh;
for (usz i=0; i<wia; i++) {
ur j=p[i];
usz xl=xsh[i], l=rsh[j];
dup += l!=no_sh;
id += i==j;
max = j>max? j : max;
if (xl<l) rsh[j]=xl;
}
// 𝕨 is a prefix of the identity permutation: 𝕩 is the result
if (id == wia) { r = x; goto ret; }
// Fill in remaining axes and check for missing ones
// rr is the result rank; every result axis up to max must be named in 𝕨
ur rr = xr-dup;
if (max >= rr) thrF("⍉: Skipped result axis");
if (wia<xr) for (usz j=0, i=wia; j<rr; j++) if (rsh[j]==no_sh) {
p[i] = j;
rsh[j] = xsh[i];
i++;
}
// Create shape object, saving unprocessed result shape
// sh is only allocated when rr>1; shSet and the final reshape do not read it
// otherwise (it is initialized only under GCC)
ShArr* sh ONLY_GCC(=0);
if (LIKELY(rr > 1)) { // Not all duplicates
sh = m_shArr(rr);
shcpy(sh->a, rsh, rr);
}
// Empty result
if (IA(x) == 0) {
Arr* ra = emptyWithFill(getFillR(x));
shSet(ra, rr, sh);
decG(x);
r = taga(ra); goto ret;
}
// Add up stride for each axis
// Result axes na and above come from the trailing 𝕩 axes na+dup.. in order,
// so they stay together as cells of csz elements. st[j] is the 𝕩 stride, in
// elements, of result axis j (summed over the 𝕩 axes of a diagonal).
ur na = max + 1; // Number of result axes that moved
usz* st = rsh + xr; // Length na
PLAINLOOP for (usz j=0; j<na; j++) st[j] = 0;
usz csz = shProd(xsh, na+dup, xr);
PLAINLOOP for (usz i=na+dup, c=csz; i--; ) { st[p[i]]+=c; c*=xsh[i]; }
// Simplify axis structure
// p is unused now; work only on csz, rsh, and st
// Walking down from the last axis, length-1 axes are dropped. An axis whose
// stride equals the extent of everything below it is merged into that lower
// axis (or into csz). Kept axes are packed at the top of the old rsh/st
// ranges, written downward, so writes never overtake reads.
usz *lp = &csz; usz sz = csz;
usz na0=na; usz* rsh0=rsh; usz* st0=st; rsh+=na0; st+=na0; na=0;
PLAINLOOP for (usz i=na0; i--; ) {
usz l = rsh0[i]; if (l==1) continue; // Ignore
usz s = st0[i]; if (s==sz) { *lp*=l; sz*=l; continue; } // Combine with lower
na++; *--rsh=l; *--st=s; lp=rsh; sz=l*s;
}
// Turned out trivial
// The result is the first csz elements of 𝕩 in their existing order
if (na == 0) {
Arr* ra = TI(x,slice)(x, 0, csz);
shSet(ra, rr, sh);
r = taga(ra); goto ret;
}
u8 xe = TI(x,elType);
// AXIS_LOOP visits all index tuples of the first N_AX axes of rsh in
// row-major order. i advances by I_INC per step; j is the matching 𝕩 element
// index (sum of index times stride). The last axis is the explicit inner loop,
// and ri holds the counters of the outer axes. It reads ria from the
// enclosing scope.
#define AXIS_LOOP(N_AX, I_INC, DO_INNER) \
ur a0 = N_AX - 1; \
usz* ri = st+na; PLAINLOOP for (usz i=0; i<a0; i++) ri[i]=0;\
usz l = rsh[a0]; \
for (usz i=0, j0=0;;) { \
/* Hardcode one innermost loop: assume N_AX>=1 */ \
ur a = a0; \
usz str = st[a]; \
for (usz k=0; k<l; k++) { \
usz j=j0+k*str; DO_INNER; \
i += I_INC; \
} \
if (i == ria) break; \
/* Update result index starting with last axis finished */\
while (1) { \
str = st[--a]; \
j0 += str; \
if (LIKELY(++ri[a] < rsh[a])) break; \
ri[a] = 0; \
j0 -= rsh[a] * str; \
} \
}
u8 xlw = elwBitLog(xe);
// Large cells: copy each cell of csz elements with mut_copyG
if (csz >= (32*8) >> xlw) { // cell >= 32 bytes
usz ria = csz * shProd(rsh, 0, na);
MAKE_MUT_INIT(rm, ria, xe); MUTG_INIT(rm);
AXIS_LOOP(na, csz, mut_copyG(rm, i, x, j, csz));
Arr* ra = mut_fp(rm);
shSet(ra, rr, sh);
r = withFill(taga(ra), getFillR(x));
decG(x); goto ret;
}
#undef AXIS_LOOP
if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<<xlw>=8 // CPU-sized cells
&& xe!=el_B && na>=2) {
// If some result axis has stride 1 (guaranteed if dup==0), then it
// corresponds to the last argument axis and we have a strided
// transpose swapping that with the last result axis
// NOTE(review): st holds strides in elements, and they appear to all be
// multiples of csz (the selection path below divides them by csz), so this
// search for a stride-1 axis can only succeed when csz==1. Larger
// power-of-two cells, which the kernel choice via CTZ(csz<<xlw) seems meant
// to support, always fall through to skip_2d. Enabling them would also need
// ws expressed in cells (st[rai]/csz). Confirm intent before changing.
usz rai = na-1;
usz xai=rai; while (st[--xai]!=1) if (xai==0) goto skip_2d;
// Too small to be worth the kernel
if (rsh[xai]*rsh[rai] < (256*8) >> xlw) goto skip_2d;
TranspFn tran = transposeFns[CTZ(csz<<xlw)-3];
usz rf = shProd(rsh, 0, na);
Arr* ra;
u8* rp = m_tyarrlbp(&ra,xlw,rf*csz,el2t(xe));
u8* xp = tyany_ptr(x);
// Kernel arguments: block size w×h is the length of result axes xai and rai.
// ws is the 𝕩 stride of axis rai and hs the result stride of axis xai.
usz w = rsh[xai]; usz ws = st[rai];
usz h = rsh[rai]; usz hs = shProd(rsh, xai+1, rai) * h;
if (na == 2) {
tran(rp, xp, w, h, ws, hs);
} else {
// From here on, csz and st are in bytes. This happens after the last
// goto skip_2d, so the selection path always sees element strides.
csz = (csz<<xlw) / 8; // Convert to bytes
// Each kernel call fills all of result axis xai. When the loop over the
// axes between xai and rai wraps around, i skips the remaining w-1 rows.
usz i_skip = (w-1)*hs*csz;
// Value of i after the last kernel call
usz end = rf*csz - i_skip;
ur a0 = na - 1;
if (xlw<3) PLAINLOOP for (usz i=0; i<na; i++) st[i] >>= 3-xlw;
else if (xlw>3) PLAINLOOP for (usz i=0; i<na; i++) st[i] <<= xlw-3;
usz* ri = st+na; PLAINLOOP for (usz i=0; i<a0; i++) ri[i]=0;
for (usz i=0, j=0;;) {
tran(rp+i, xp+j, w, h, ws, hs);
i += h*csz;
if (i == end) break;
for (ur a = a0;;) {
if (--a == xai) { assert(a!=0); i+=i_skip; --a; }
usz str = st[a];
j += str;
if (LIKELY(++ri[a] < rsh[a])) break;
ri[a] = 0;
j -= rsh[a] * str;
}
}
}
shSet(ra, rr, sh);
r = taga(ra);
decG(x); goto ret;
}
skip_2d:;
// Reshape x for selection
// View 𝕩 as a table of IA(x)/csz rows, each one cell of csz elements
ShArr* zsh = m_shArr(2);
zsh->a[0] = IA(x)/csz;
zsh->a[1] = csz;
x = taga(arr_shSetUG(customizeShape(x), 2, zsh));
// (+⌜´st×⟜↕¨rsh)⊏⥊𝕩
B ind = bi_N;
for (ur k=na; k--; ) {
B v = C2(mul, m_usz(st[k]/csz), C1(ud, m_f64(rsh[k])));
if (q_N(ind)) ind = v;
else ind = M1C2(tbl, add, v, ind);
}
r = C2(select, ind, x);
// Replace the selection's shape with the saved result shape (a list if rr<=1)
Arr* ra = cpyWithShape(r); r = taga(ra);
if (rr>1) arr_shReplace(ra, rr, sh);
else { decSh((Value*)ra); arr_shVec(ra); }
ret:;
TFREE(alloc);
return r;
}
// ⍉⁼ (inverse of monadic ⍉): move the last axis of 𝕩 to the front, e.g. shape
// a‿b‿c becomes c‿a‿b. Arrays of rank 0 or 1 are returned unchanged; an atom
// is an error. Consumes x.
B transp_im(B t, B x) {
if (isAtm(x)) thrM("⍉⁼: 𝕩 must not be an atom");
ur xr = RNK(x);
if (xr<=1) return x;
usz ia = IA(x);
usz* xsh = SH(x);
usz w = xsh[xr-1];
// Empty, last axis of length 1, or all other axes of length 1: element order
// is unchanged, so only the shape is rewritten.
// NOTE(review): xsh is still read after cpyWithShape(x); this relies on r
// keeping x's shape object until arr_shReplace — confirm.
if (ia==0 || w==1 || w==ia /*h==1*/) {
Arr* r = cpyWithShape(x);
ShArr* sh = m_shArr(xr);
sh->a[0] = w;
shcpy(sh->a+1, xsh, xr-1);
arr_shReplace(r, xr, sh);
return taga(r);
}
// View 𝕩 as h rows of w elements (h is the product of the leading axes)
usz h = xsh[0] * shProd(xsh, 1, xr-1);
Arr* r = transpose_noshape(&x, ia, w, h);
// x may have been replaced by a converted copy, so its shape is re-read via SH(x)
usz* rsh = arr_shAlloc(r, xr);
rsh[0] = w;
if (xr==2) rsh[1] = h; else shcpy(rsh+1, SH(x), xr-1);
decG(x); return taga(r);
}
// Under monadic ⍉ (𝔽⌾⍉): transpose 𝕩, apply o, then undo the transpose with
// ⍉⁼. Consumes x (through transp_c1).
B transp_uc1(B t, B o, B x) {
  B transposed = transp_c1(t, x);
  B changed = c1(o, transposed);
  return transp_im(m_f64(0), changed);
}
// Consumes w; return bi_N if w contained duplicates
// Builds the left argument v for which v⍉(w⍉y) restores y's axis order. It is
// used for 𝕨⍉⁼𝕩 and for Under. xr is the rank of 𝕩 (callers pass 0 for an atom).
//  - atom w=a: v is the i32 list 1‿2‿…‿a (empty when a=0)
//  - list w: v has length 1+⌈´w. v[k] is the index of k in w; result axes
//    missing from w get the unused indices ≠w, 1+≠w, … in order.
// Throws if w has rank >1, is longer than xr, or names an axis >= xr.
static B invert_transp_w(B w, ur xr) {
if (isAtm(w)) {
if (xr<1) thrM("⍉⁼: Length of 𝕨 must be at most rank of 𝕩");
usz a=o2s(w);
if (a>=xr) thrF("⍉⁼: Axis %s does not exist (%i≡=𝕩)", a, xr);
i32* wp; w = m_i32arrv(&wp, a);
PLAINLOOP for (usz i=0; i<a; i++) wp[i] = i+1;
} else {
if (RNK(w)>1) thrM("⍉⁼: 𝕨 must have rank at most 1");
usz wia = IA(w);
// Empty w is the identity reordering; return it as is
if (wia==0) return w;
if (xr<wia) thrM("⍉⁼: Length of 𝕨 must be at most rank of 𝕩");
SGetU(w)
// p[a] = index of axis a in w, or xr if a does not appear
TALLOC(ur, p, xr);
for (usz i=0; i<xr; i++) p[i]=xr;
usz max = 0;
for (usz i=0; i<wia; i++) {
usz a=o2s(GetU(w, i));
if (a>=xr) thrF("⍉⁼: Axis %s does not exist (%i≡=𝕩)", a, xr);
if (p[a]!=xr) { TFREE(p); decG(w); return bi_N; } // Handled by caller
max = a>max? a : max;
p[a] = i;
}
decG(w);
// Inverse covers result axes 0..max; unnamed ones take indices from wia upward
usz n = max+1;
i32* wp; w = m_i32arrv(&wp, n);
for (usz i=0, j=wia; i<n; i++) wp[i] = p[i]<xr? p[i] : j++;
TFREE(p);
}
return w;
}
// 𝕨⍉⁼𝕩: reorder 𝕩's axes with the inverse of the permutation 𝕨. Duplicate
// axes in 𝕨 have no inverse and are an error; so is an atom 𝕩.
B transp_ix(B t, B w, B x) {
  if (isAtm(x)) thrM("⍉⁼: 𝕩 must not be an atom");
  B inv = invert_transp_w(w, RNK(x));
  if (q_N(inv)) thrM("⍉⁼: Duplicate axes");
  return C2(transp, inv, x);
}
// Under dyadic ⍉ (𝔽⌾(𝕨⊸⍉)): reorder axes with 𝕨, apply o, then reorder back
// with the inverse of 𝕨. If 𝕨 repeats an axis there is no inverse, and the
// generic def_fn_ucw handles the call instead. w and x are passed on to
// transp_c2 / def_fn_ucw.
B transp_ucw(B t, B o, B w, B x) {
  ur xr = isAtm(x)? 0 : RNK(x);
  // Extra reference: invert_transp_w consumes its argument, but w is used below
  B inv = invert_transp_w(inc(w), xr);
  if (q_N(inv)) return def_fn_ucw(t, o, w, x); // Duplicate axes
  B reordered = C2(transp, w, x);
  return C2(transp, inv, c1(o, reordered));
}
void transp_init(void) {
c(BFn,bi_transp)->im = transp_im;
c(BFn,bi_transp)->ix = transp_ix;
c(BFn,bi_transp)->uc1 = transp_uc1;
c(BFn,bi_transp)->ucw = transp_ucw;
}