Merge pull request #71 from mlochbaum/transpose

Reorder Axes
This commit is contained in:
dzaima 2023-03-31 16:29:07 +03:00 committed by GitHub
commit 5367845753
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 472 additions and 178 deletions

View File

@ -576,7 +576,7 @@ cachedBin‿linkerCache ← {
"xa""src/builtins/arithd.c""dyarith", "xa""src/builtins/cmp.c""cmp", "xa""src/builtins/squeeze.c""squeeze"
"x.""src/builtins/select.c""select", "x.""src/builtins/fold.c""fold", "x.""src/builtins/scan.c""scan"
"x.""src/builtins/scan.c""neq", "x.""src/builtins/slash.c""slash", "x.""src/builtins/slash.c""constrep"
"x.""src/builtins/transpose.c""transpose"
"xa""src/builtins/transpose.c""transpose"
objs

View File

@ -322,6 +322,12 @@ static NOINLINE B match_cells(bool ne, B w, B x, ur wr, ur xr, usz len) {
return r;
}
B transp_c2(B, B, B);
static B transp_cells(ur ax, B x) {
i8* wp; B w=m_i8arrv(&wp, 2); wp[0]=0; wp[1]=ax;
return C2(transp, w, x);
}
B shape_c1(B, B);
B fold_rows(Md1D* d, B x); // From fold.c
B cell_c1(Md1D* d, B x) { B f = d->f;
@ -362,6 +368,7 @@ B cell_c1(Md1D* d, B x) { B f = d->f;
B xf = getFillR(x);
if (!noFill(xf)) return shift_cells(xf, x, TI(x,elType), rtid);
}
if (rtid==n_transp) return xr<=2? x : transp_cells(xr-1, x);
if (TY(f) == t_md1D) {
Md1D* fd = c(Md1D,f);
u8 rtid = fd->m1->flags-1;
@ -425,6 +432,7 @@ B cell_c2(Md1D* d, B w, B x) { B f = d->f;
}
if (rtid==n_take && xr>1 && isF64(w)) return takedrop_highrank(1, m_hVec2(m_f64(SH(x)[0]), w), x);
if (rtid==n_drop && xr>1 && isF64(w)) return takedrop_highrank(0, m_hVec2(m_f64(0), w), x);
if (rtid==n_transp && q_usz(w)) { usz a=o2sG(w); if (a<xr-1) return transp_cells(a+1, x); }
}
S_SLICES(x)
M_HARR(r, cam);

View File

@ -1,7 +1,43 @@
// Transpose and Reorder Axes (⍉)
// Transpose
// One length-2 axis: dedicated code
// Boolean: pdep for height 2; pext for width 2
// SHOULD use a generic implementation if BMI2 not present
// SHOULD optimize other short lengths with pdep/pext and shuffles
// Boolean 𝕩: convert to integer
// SHOULD have bit matrix transpose kernel
// CPU sizes: native or SIMD code
// Large SIMD kernels used when they fit, overlapping for odd sizes
// i8: 16×16; i16: 16×8; i32: 8×8; f64: 4×4
// COULD use half-width or smaller kernels to improve odd sizes
// Scalar transpose or loop used for overhang of 1
// Reorder Axes
// If 𝕨 indicates the identity permutation, return 𝕩
// Simplify: remove length-1 axes; coalesce adjacent and trailing axes
// Empty result or trivial reordering: reshape 𝕩
// Large cells: slow outer loop plus mut_copy
// CPU-sized cells, large last 𝕩 and result axes: strided 2D transposes
// Otherwise, generate indices and select with +⌜ and ⊏
// SHOULD generate for a cell and virtualize the rest to save space
// COULD decompose axis permutations to use 2D transpose when possible
// COULD convert boolean to integer for some axis reorderings
// SHOULD have a small-subarray transposer using one or a few shuffles
// ⍉⁼𝕩: data movement of ⍉ with different shape logic
// 𝕨⍉⁼𝕩: compute inverse 𝕨, length 1+⌈´𝕨
// Under Transpose supports invertible cases
// SHOULD implement Under with duplicate axes, maybe as Under Select
// ⍉˘𝕩 and k⍉˘𝕩 for number k: convert to 0‿a⍉𝕩
// SHOULD convert ⍉ with rank to a Reorder Axes call
// COULD implement fast ⍉⍟n
#include "../core.h"
#include "../utils/each.h"
#include "../utils/talloc.h"
#include "../builtins.h"
#include "../utils/calls.h"
#ifdef __BMI2__
#include <immintrin.h>
@ -10,25 +46,83 @@
#endif
#endif
#define TRANSPOSE_LOOP( DST, SRC, W, H) PLAINLOOP for(usz y=0;y< H;y++) NOVECTORIZE for(usz x=0;x< W;x++) DST[x*H+y] = SRC[xi++]
#define TRANSPOSE_BLOCK(DST, SRC, BW, BH, W, H) PLAINLOOP for(usz y=0;y<BH;y++) NOVECTORIZE for(usz x=0;x<BW;x++) DST[x*H+y] = SRC[y*W+x]
#define DECL_BASE(T) \
static NOINLINE void transpose_##T(void* rv, void* xv, u64 bw, u64 bh, u64 w, u64 h) { \
T* rp=rv; T* xp=xv; \
PLAINLOOP for(usz y=0;y<bh;y++) NOVECTORIZE for(usz x=0;x<bw;x++) rp[x*h+y] = xp[y*w+x]; \
}
DECL_BASE(i8) DECL_BASE(i16) DECL_BASE(i32) DECL_BASE(i64)
#undef DECL_BASE
#if SINGELI_X86_64
#define DECL_BASE(T) \
static NOINLINE void base_transpose_##T(T* rp, T* xp, u64 bw, u64 bh, u64 w, u64 h) { \
TRANSPOSE_BLOCK(rp, xp, bw, bh, w, h); \
}
DECL_BASE(i8) DECL_BASE(i16) DECL_BASE(i32) DECL_BASE(i64)
#undef DECL_BASE
typedef void (*TranspFn)(void*,void*,u64,u64,u64,u64);
#if SINGELI
#define transposeFns simd_transpose
#define SINGELI_FILE transpose
#include "../utils/includeSingeli.h"
#define TRANSPOSE_SIMD(T, DST, SRC, W, H) simd_transpose_##T(DST, SRC, W, H)
#else
#define TRANSPOSE_SIMD(T, DST, SRC, W, H) TRANSPOSE_LOOP(DST, SRC, W, H)
static TranspFn transposeFns[] = {
transpose_i8, transpose_i16, transpose_i32, transpose_i64
};
#endif
extern B rt_transp;
static void transpose_move(void* rv, void* xv, u8 xe, usz w, usz h) {
assert(xe!=el_bit); assert(xe!=el_B);
transposeFns[elWidthLogBits(xe)-3](rv, xv, w, h, w, h);
}
// Return an array with data from x transposed as though it's shape h,w
// Shape of result needs to be set afterwards!
static Arr* transpose_noshape(B* px, usz ia, usz w, usz h) {
B x = *px;
u8 xe = TI(x,elType);
Arr* r;
if (xe==el_B) {
B xf = getFillR(x);
B* xp = TO_BPTR(x);
HArr_p p = m_harrUv(ia); // Debug build complains with harrUp
transpose_move(p.a, xp, el_f64, w, h);
for (usz xi=0; xi<ia; xi++) inc(p.a[xi]); // TODO don't inc when there's a method of freeing a HArr without freeing its elements
NOGC_E;
r=a(qWithFill(p.b, xf));
} else if (xe==el_bit) {
#ifdef __BMI2__
if (h==2) {
u32* x0 = (u32*)bitarr_ptr(x);
u64* rp; r=m_bitarrp(&rp, ia);
Arr* x1o = TI(x,slice)(inc(x),w,w);
u32* x1 = (u32*) ((TyArr*)x1o)->a;
for (usz i=0; i<BIT_N(ia); i++) rp[i] = _pdep_u64(x0[i], 0x5555555555555555) | _pdep_u64(x1[i], 0xAAAAAAAAAAAAAAAA);
mm_free((Value*)x1o);
} else if (w==2) {
u64* xp = bitarr_ptr(x);
u64* r0; r=m_bitarrp(&r0, ia);
TALLOC(u64, r1, BIT_N(h));
for (usz i=0; i<BIT_N(ia); i++) {
u64 v = xp[i];
((u32*)r0)[i] = _pext_u64(v, 0x5555555555555555);
((u32*)r1)[i] = _pext_u64(v, 0xAAAAAAAAAAAAAAAA);
}
bit_cpy(r0, h, r1, 0, h);
TFREE(r1);
} else
#endif
{
*px = x = taga(cpyI8Arr(x)); xe=el_i8;
void* rv = m_tyarrp(&r,elWidth(xe),ia,el2t(xe));
void* xv = tyany_ptr(x);
transpose_move(rv, xv, xe, w, h);
r = (Arr*)cpyBitArr(taga(r));
}
} else {
void* rv = m_tyarrp(&r,elWidth(xe),ia,el2t(xe));
void* xv = tyany_ptr(x);
transpose_move(rv, xv, xe, w, h);
}
return r;
}
B transp_c1(B t, B x) {
if (RARE(isAtm(x))) return m_atomUnit(x);
ur xr = RNK(x);
@ -37,8 +131,7 @@ B transp_c1(B t, B x) {
usz ia = IA(x);
usz* xsh = SH(x);
usz h = xsh[0];
if (ia==0 || h==1) {
no_reorder:;
if (ia==0 || h==1 || h==ia /*w==1*/) {
Arr* r = cpyWithShape(x);
ShArr* sh = m_shArr(xr);
shcpy(sh->a, xsh+1, xr-1);
@ -47,116 +140,300 @@ B transp_c1(B t, B x) {
return taga(r);
}
usz w = xsh[1] * shProd(xsh, 2, xr);
if (w==1) goto no_reorder;
Arr* r;
usz xi = 0;
u8 xe = TI(x,elType);
bool toBit = false;
if (h==2) {
if (xe==el_B) {
B* xp = TO_BPTR(x);
B* x0 = xp; B* x1 = x0+w;
HArr_p rp = m_harrUp(ia);
for (usz i=0; i<w; i++) { rp.a[i*2] = inc(x0[i]); rp.a[i*2+1] = inc(x1[i]); }
NOGC_E;
r = (Arr*) rp.c;
} else {
#ifndef __BMI2__
if (xe==el_bit) { x = taga(cpyI8Arr(x)); xsh=SH(x); xe=el_i8; toBit=true; }
void* rp = m_tyarrp(&r,elWidth(xe),ia,el2t(xe));
#else
void* rp = m_tyarrlbp(&r,elWidthLogBits(xe),ia,el2t(xe));
#endif
void* xp = tyany_ptr(x);
switch(xe) { default: UD;
#ifdef __BMI2__
case el_bit:;
u32* x0 = xp;
Arr* x1o = TI(x,slice)(inc(x),w,w);
u32* x1 = (u32*) ((TyArr*)x1o)->a;
for (usz i=0; i<BIT_N(ia); i++) ((u64*)rp)[i] = _pdep_u64(x0[i], 0x5555555555555555) | _pdep_u64(x1[i], 0xAAAAAAAAAAAAAAAA);
mm_free((Value*)x1o);
break;
#endif
case el_i8: case el_c8: { u8* x0=xp; u8* x1=x0+w; for (usz i=0; i<w; i++) { ((u8* )rp)[i*2] = x0[i]; ((u8* )rp)[i*2+1] = x1[i]; } } break;
case el_i16:case el_c16: { u16* x0=xp; u16* x1=x0+w; for (usz i=0; i<w; i++) { ((u16*)rp)[i*2] = x0[i]; ((u16*)rp)[i*2+1] = x1[i]; } } break;
case el_i32:case el_c32: { u32* x0=xp; u32* x1=x0+w; for (usz i=0; i<w; i++) { ((u32*)rp)[i*2] = x0[i]; ((u32*)rp)[i*2+1] = x1[i]; } } break;
case el_f64: { u64* x0=xp; u64* x1=x0+w; for (usz i=0; i<w; i++) { ((u64*)rp)[i*2] = x0[i]; ((u64*)rp)[i*2+1] = x1[i]; } } break;
}
}
} else if (w==2 && xe!=el_B) {
#ifndef __BMI2__
if (xe==el_bit) { x = taga(cpyI8Arr(x)); xsh=SH(x); xe=el_i8; toBit=true; }
#endif
void* rp = m_tyarrlbp(&r,elWidthLogBits(xe),ia,el2t(xe));
void* xp = tyany_ptr(x);
switch(xe) { default: UD;
#if __BMI2__
case el_bit:;
u64* r0 = rp; TALLOC(u64, r1, BIT_N(h));
for (usz i=0; i<BIT_N(ia); i++) {
u64 v = ((u64*)xp)[i];
((u32*)r0)[i] = _pext_u64(v, 0x5555555555555555);
((u32*)r1)[i] = _pext_u64(v, 0xAAAAAAAAAAAAAAAA);
}
bit_cpy(r0, h, r1, 0, h);
TFREE(r1);
break;
#endif
case el_i8: case el_c8: { u8* r0=rp; u8* r1=r0+h; for (usz i=0; i<h; i++) { r0[i] = ((u8* )xp)[i*2]; r1[i] = ((u8* )xp)[i*2+1]; } } break;
case el_i16:case el_c16: { u16* r0=rp; u16* r1=r0+h; for (usz i=0; i<h; i++) { r0[i] = ((u16*)xp)[i*2]; r1[i] = ((u16*)xp)[i*2+1]; } } break;
case el_i32:case el_c32: { u32* r0=rp; u32* r1=r0+h; for (usz i=0; i<h; i++) { r0[i] = ((u32*)xp)[i*2]; r1[i] = ((u32*)xp)[i*2+1]; } } break;
case el_f64: { f64* r0=rp; f64* r1=r0+h; for (usz i=0; i<h; i++) { r0[i] = ((f64*)xp)[i*2]; r1[i] = ((f64*)xp)[i*2+1]; } } break;
}
} else {
switch(xe) { default: UD;
case el_bit: x = taga(cpyI8Arr(x)); xsh=SH(x); xe=el_i8; toBit=true; // fallthough
case el_i8: case el_c8: { u8* xp=tyany_ptr(x); u8* rp = m_tyarrp(&r,1,ia,el2t(xe)); TRANSPOSE_SIMD( i8, rp, xp, w, h); break; }
case el_i16:case el_c16: { u16* xp=tyany_ptr(x); u16* rp = m_tyarrp(&r,2,ia,el2t(xe)); TRANSPOSE_SIMD(i16, rp, xp, w, h); break; }
case el_i32:case el_c32: { u32* xp=tyany_ptr(x); u32* rp = m_tyarrp(&r,4,ia,el2t(xe)); TRANSPOSE_SIMD(i32, rp, xp, w, h); break; }
case el_f64: { f64* xp=f64any_ptr(x); f64* rp; r=m_f64arrp(&rp,ia); TRANSPOSE_SIMD(i64, rp, xp, w, h); break; }
case el_B: { // can't be bothered to implement a bitarr transpose
B xf = getFillR(x);
B* xp = TO_BPTR(x);
HArr_p p = m_harrUp(ia);
for(usz y=0;y<h;y++) for(usz x=0;x<w;x++) p.a[x*h+y] = inc(xp[xi++]); // TODO inc afterwards, but don't when there's a method of freeing a HArr without freeing its elements
NOGC_E;
usz* rsh = arr_shAlloc((Arr*)p.c, xr);
if (xr==2) {
rsh[0] = w;
rsh[1] = h;
} else {
shcpy(rsh, xsh+1, xr-1);
rsh[xr-1] = h;
}
decG(x); return qWithFill(p.b, xf);
}
}
}
Arr* r = transpose_noshape(&x, ia, w, h);
usz* rsh = arr_shAlloc(r, xr);
if (xr==2) {
rsh[0] = w;
rsh[1] = h;
} else {
shcpy(rsh, xsh+1, xr-1);
rsh[xr-1] = h;
}
decG(x); return taga(toBit? (Arr*)cpyBitArr(taga(r)) : r);
if (xr==2) rsh[0] = w; else shcpy(rsh, SH(x)+1, xr-1);
rsh[xr-1] = h;
decG(x); return taga(r);
}
B transp_c2(B t, B w, B x) { return c2rt(transp, w, x); }
B mul_c2(B,B,B);
B ud_c1(B,B);
B tbl_c2(Md1D*,B,B);
B select_c2(B,B,B);
static void shSet(Arr* ra, ur rr, ShArr* sh) {
if (RARE(rr <= 1)) arr_shVec(ra);
else arr_shSetU(ra, rr, sh);
}
B transp_c2(B t, B w, B x) {
usz wia=1;
if (isArr(w)) {
if (RNK(w)>1) thrM("⍉: 𝕨 must have rank at most 1");
wia = IA(w);
if (wia==0) { decG(w); return isArr(x)? x : m_atomUnit(x); }
}
ur xr;
if (isAtm(x) || (xr=RNK(x))<wia) thrM("⍉: Length of 𝕨 must be at most rank of 𝕩");
// Axis permutation
TALLOC(u8, alloc, xr*(sizeof(ur) + 3*sizeof(usz))); // ur* p, usz* rsh, usz* st, usz* ri
ur* p = (ur*)alloc;
if (isAtm(w)) {
usz a=o2s(w);
if (a>=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr);
if (a==xr-1) { TFREE(alloc); return C1(transp, x); }
p[0] = a;
} else {
SGetU(w)
for (usz i=0; i<wia; i++) {
usz a=o2s(GetU(w, i));
if (a>=xr) thrF("⍉: Axis %s does not exist (%i≡=𝕩)", a, xr);
p[i] = a;
}
decG(w);
}
B r;
// Compute shape for the given axes
usz* xsh = SH(x);
usz* rsh = (usz*)(p + xr); // Length xr
usz dup = 0, max = 0, id = 0;
usz no_sh = -(usz)1;
for (usz j=0; j<xr; j++) rsh[j] = no_sh;
for (usz i=0; i<wia; i++) {
ur j=p[i];
usz xl=xsh[i], l=rsh[j];
dup += l!=no_sh;
id += i==j;
max = j>max? j : max;
if (xl<l) rsh[j]=xl;
}
if (id == wia) { r = x; goto ret; }
// Fill in remaining axes and check for missing ones
ur rr = xr-dup;
if (max >= rr) thrF("⍉: Skipped result axis");
if (wia<xr) for (usz j=0, i=wia; j<rr; j++) if (rsh[j]==no_sh) {
p[i] = j;
rsh[j] = xsh[i];
i++;
}
// Create shape object, saving unprocessed result shape
ShArr* sh;
if (LIKELY(rr > 1)) { // Not all duplicates
sh = m_shArr(rr);
shcpy(sh->a, rsh, rr);
}
// Empty result
if (IA(x) == 0) {
Arr* ra = m_fillarrpEmpty(getFillQ(x));
shSet(ra, rr, sh);
decG(x);
r = taga(ra); goto ret;
}
// Add up stride for each axis
ur na = max + 1; // Number of result axes that moved
usz* st = rsh + xr; // Length na
for (usz j=0; j<na; j++) st[j] = 0;
usz csz = shProd(xsh, na+dup, xr);
for (usz i=na+dup, c=csz; i--; ) { st[p[i]]+=c; c*=xsh[i]; }
// Simplify axis structure
// p is unused now; work only on csz, rsh, and st
usz *lp = &csz; usz sz = csz;
usz na0=na; usz* rsh0=rsh; usz* st0=st; rsh+=na0; st+=na0; na=0;
for (usz i=na0; i--; ) {
usz l = rsh0[i]; if (l==1) continue; // Ignore
usz s = st0[i]; if (s==sz) { *lp*=l; sz*=l; continue; } // Combine with lower
na++; *--rsh=l; *--st=s; lp=rsh; sz=l*s;
}
// Turned out trivial
if (na == 0) {
Arr* ra = TI(x,slice)(x, 0, csz);
shSet(ra, rr, sh);
r = taga(ra); goto ret;
}
u8 xe = TI(x,elType);
#define AXIS_LOOP(N_AX, I_INC, DO_INNER) \
ur a0 = N_AX - 1; \
usz* ri = st+na; for (usz i=0; i<a0; i++) ri[i]=0; \
usz l = rsh[a0]; \
for (usz i=0, j0=0;;) { \
/* Hardcode one innermost loop: assume N_AX>=1 */ \
ur a = a0; \
usz str = st[a]; \
for (usz k=0; k<l; k++) { \
usz j=j0+k*str; DO_INNER; \
i += I_INC; \
} \
if (i == ria) break; \
/* Update result index starting with last axis finished */\
while (1) { \
str = st[--a]; \
j0 += str; \
if (LIKELY(++ri[a] < rsh[a])) break; \
ri[a] = 0; \
j0 -= rsh[a] * str; \
} \
}
u8 xlw = elWidthLogBits(xe);
if (csz >= (32*8) >> xlw) { // cell >= 32 bytes
usz ria = csz * shProd(rsh, 0, na);
MAKE_MUT_INIT(rm, ria, xe); MUTG_INIT(rm);
AXIS_LOOP(na, csz, mut_copyG(rm, i, x, j, csz));
Arr* ra = mut_fp(rm);
shSet(ra, rr, sh);
r = withFill(taga(ra), getFillQ(x));
decG(x); goto ret;
}
#undef AXIS_LOOP
if ((csz & (csz-1))==0 && csz<=64>>xlw && csz<<xlw>=8 // CPU-sized cells
&& xe!=el_B && na>=2) {
// If some result axis has stride 1 (guaranteed if dup==0), then it
// corresponds to the last argument axis and we have a strided
// transpose swapping that with the last result axis
usz rai = na-1;
usz xai=rai; while (st[--xai]!=1) if (xai==0) goto skip_2d;
if (rsh[xai]*rsh[rai] < (256*8) >> xlw) goto skip_2d;
TranspFn tran = transposeFns[CTZ(csz<<xlw)-3];
usz rf = shProd(rsh, 0, na);
Arr* ra;
u8* rp = m_tyarrlbp(&ra,xlw,rf*csz,el2t(xe));
u8* xp = tyany_ptr(x);
usz w = rsh[xai]; usz ws = st[rai];
usz h = rsh[rai]; usz hs = shProd(rsh, xai+1, rai) * h;
if (na == 2) {
tran(rp, xp, w, h, ws, hs);
} else {
csz = (csz<<xlw) / 8; // Convert to bytes
usz i_skip = (w-1)*hs*csz;
usz end = rf*csz - i_skip;
ur a0 = na - 1;
if (xlw<3) for (usz i=0; i<na; i++) st[i] >>= 3-xlw;
else if (xlw>3) for (usz i=0; i<na; i++) st[i] <<= xlw-3;
usz* ri = st+na; for (usz i=0; i<a0; i++) ri[i]=0;
for (usz i=0, j=0;;) {
tran(rp+i, xp+j, w, h, ws, hs);
i += h*csz;
if (i == end) break;
for (ur a = a0;;) {
if (--a == xai) { assert(a!=0); i+=i_skip; --a; }
usz str = st[a];
j += str;
if (LIKELY(++ri[a] < rsh[a])) break;
ri[a] = 0;
j -= rsh[a] * str;
}
}
}
shSet(ra, rr, sh);
r = taga(ra);
decG(x); goto ret;
}
skip_2d:;
// Reshape x for selection
ShArr* zsh = m_shArr(2);
zsh->a[0] = IA(x)/csz;
zsh->a[1] = csz;
Arr* z = TI(x,slice)(x, 0, IA(x));
arr_shSetU(z, 2, zsh);
x = taga(z);
// (+⌜´st×⟜↕¨rsh)⊏⥊𝕩
B ind = bi_N;
for (ur k=na; k--; ) {
B v = C2(mul, m_usz(st[k]/csz), C1(ud, m_f64(rsh[k])));
if (q_N(ind)) ind = v;
else ind = M1C2(tbl, add, v, ind);
}
r = C2(select, ind, x);
Arr* ra = cpyWithShape(r); r = taga(ra);
if (rr>1) arr_shReplace(ra, rr, sh);
else { decSh((Value*)ra); arr_shVec(ra); }
ret:;
TFREE(alloc);
return r;
}
B transp_im(B t, B x) {
if (isAtm(x)) thrM("⍉⁼: 𝕩 must not be an atom");
if (RNK(x)<=2) return transp_c1(t, x);
return def_fn_im(bi_transp, x);
ur xr = RNK(x);
if (xr<=1) return x;
usz ia = IA(x);
usz* xsh = SH(x);
usz w = xsh[xr-1];
if (ia==0 || w==1 || w==ia /*h==1*/) {
Arr* r = cpyWithShape(x);
ShArr* sh = m_shArr(xr);
sh->a[0] = w;
shcpy(sh->a+1, xsh, xr-1);
arr_shReplace(r, xr, sh);
return taga(r);
}
usz h = xsh[0] * shProd(xsh, 1, xr-1);
Arr* r = transpose_noshape(&x, ia, w, h);
usz* rsh = arr_shAlloc(r, xr);
rsh[0] = w;
if (xr==2) rsh[1] = h; else shcpy(rsh+1, SH(x), xr-1);
decG(x); return taga(r);
}
B transp_uc1(B t, B o, B x) { return transp_im(m_f64(0), c1(o, transp_c1(t, x))); }
B transp_uc1(B t, B o, B x) {
return transp_im(m_f64(0), c1(o, transp_c1(t, x)));
}
// Consumes w; return bi_N if w contained duplicates
static B invert_transp_w(B w, ur xr) {
if (isAtm(w)) {
if (xr<1) thrM("⍉⁼: Length of 𝕨 must be at most rank of 𝕩");
usz a=o2s(w);
if (a>=xr) thrF("⍉⁼: Axis %s does not exist (%i≡=𝕩)", a, xr);
i32* wp; w = m_i32arrv(&wp, a);
PLAINLOOP for (usz i=0; i<a; i++) wp[i] = i+1;
} else {
if (RNK(w)>1) thrM("⍉⁼: 𝕨 must have rank at most 1");
usz wia = IA(w);
if (wia==0) return w;
if (xr<wia) thrM("⍉⁼: Length of 𝕨 must be at most rank of 𝕩");
SGetU(w)
TALLOC(ur, p, xr);
for (usz i=0; i<xr; i++) p[i]=xr;
usz max = 0;
for (usz i=0; i<wia; i++) {
usz a=o2s(GetU(w, i));
if (a>=xr) thrF("⍉⁼: Axis %s does not exist (%i≡=𝕩)", a, xr);
if (p[a]!=xr) { TFREE(p); decG(w); return bi_N; } // Handled by caller
max = a>max? a : max;
p[a] = i;
}
decG(w);
usz n = max+1;
i32* wp; w = m_i32arrv(&wp, n);
for (usz i=0, j=wia; i<n; i++) wp[i] = p[i]<xr? p[i] : j++;
TFREE(p);
}
return w;
}
B transp_ix(B t, B w, B x) {
if (isAtm(x)) thrM("⍉⁼: 𝕩 must not be an atom");
w = invert_transp_w(w, RNK(x));
if (q_N(w)) thrM("⍉⁼: Duplicate axes");
return C2(transp, w, x);
}
B transp_ucw(B t, B o, B w, B x) {
B wi = invert_transp_w(inc(w), isAtm(x)? 0 : RNK(x));
if (q_N(wi)) return def_fn_ucw(t, o, w, x); // Duplicate axes
return C2(transp, wi, c1(o, C2(transp, w, x)));
}
void transp_init(void) {
c(BFn,bi_transp)->uc1 = transp_uc1;
c(BFn,bi_transp)->im = transp_im;
c(BFn,bi_transp)->ix = transp_ix;
c(BFn,bi_transp)->uc1 = transp_uc1;
c(BFn,bi_transp)->ucw = transp_ucw;
}

View File

@ -107,7 +107,7 @@ B comp_currSrc;
B comp_currRe; // ⟨REPL mode ⋄ scope ⋄ compiler ⋄ runtime ⋄ glyphs ⋄ sysval names ⋄ sysval values⟩
B rt_undo, rt_select, rt_slash, rt_insert, rt_depth,
rt_group, rt_under, rt_find, rt_transp;
rt_group, rt_under, rt_find;
Block* load_compObj(B x, B src, B path, Scope* sc) { // consumes x,src
SGet(x)
usz xia = IA(x);
@ -442,7 +442,6 @@ void load_init() { // very last init function
rt_group = Get(rtObjRaw, n_group );
rt_under = Get(rtObjRaw, n_under );
rt_find = Get(rtObjRaw, n_find );
rt_transp = Get(rtObjRaw, n_transp );
rt_depth = Get(rtObjRaw, n_depth );
rt_insert = Get(rtObjRaw, n_insert );
@ -485,7 +484,7 @@ void load_init() { // very last init function
}
load_rtObj = frtObj;
load_compArg = m_hVec2(load_rtObj, incG(bi_sys));
rt_select=rt_slash=rt_group=rt_find=rt_transp=rt_invFnReg=rt_invFnSwap = incByG(bi_invalidFn, 7);
rt_select=rt_slash=rt_group=rt_find=rt_invFnReg=rt_invFnSwap = incByG(bi_invalidFn, 7);
rt_undo=rt_insert = incByG(bi_invalidMd1, 2);
rt_under=rt_depth = incByG(bi_invalidMd2, 2);
rt_invFnRegFn=rt_invFnSwapFn = invalidFn_c1;
@ -497,7 +496,6 @@ void load_init() { // very last init function
gc_add(rt_group);
gc_add(rt_under);
gc_add(rt_find);
gc_add(rt_transp);
gc_add(rt_depth);
gc_add(rt_insert);

View File

@ -1,12 +1,3 @@
def w256{T} = 0
def w256{T & isvec{T}} = width{T}==256
def w256{T,w} = 0
def w256{T,w & w256{T}} = elwidth{T}==w
def w256i = genchk{w256, {T} => isint{T}}
def w256s = genchk{w256, {T} => issigned{T}}
def w256u = genchk{w256, {T} => isunsigned{T}}
def w256f = genchk{w256, {T} => isfloat{T}}
def v2i{x:T & w256{T}} = [32]u8 ~~ x # for compact casting for the annoying intrinsic type system
def v2f{x:T & w256{T}} = [8]f32 ~~ x
def v2d{x:T & w256{T}} = [4]f64 ~~ x

View File

@ -11,15 +11,6 @@ def exportT{name, fs} = { v:*type{tupsel{0,fs}} = fs; export{name, v} }
def elwidth{T} = width{eltype{T}}
def genchk{B, F} = {
def r{T} = 0
def r{T & B{T}} = F{eltype{T}}
def r{T,w} = 0
def r{T,w & B{T}} = F{eltype{T}} & (elwidth{T}==w)
def r{T & ~isvec{T}} = 0
r
}
# ceiling divide
def cdiv{a,b} = (a+b-1)/b
@ -61,6 +52,34 @@ def anyInt{x} = 0
def anyInt{x & knum{x}} = (x>>0) == x
def anyInt{x & isreg{x}|isconst{x}} = isint{x}
# vector width/type checks
def w128{T} = 0
def w128{T & isvec{T}} = width{T}==128
def w128{T,w} = 0
def w128{T,w & w128{T}} = elwidth{T}==w
def w256{T} = 0
def w256{T & isvec{T}} = width{T}==256
def w256{T,w} = 0
def w256{T,w & w256{T}} = elwidth{T}==w
# width+type checks
def genchk{B, F} = {
def r{T} = 0
def r{T & B{T}} = F{eltype{T}}
def r{T,w} = 0
def r{T,w & B{T}} = F{eltype{T}} & (elwidth{T}==w)
def r{T & ~isvec{T}} = 0
r
}
def w128i = genchk{w128, {T} => isint{T}}
def w128s = genchk{w128, {T} => issigned{T}}
def w128u = genchk{w128, {T} => isunsigned{T}}
def w128f = genchk{w128, {T} => isfloat{T}}
def w256i = genchk{w256, {T} => isint{T}}
def w256s = genchk{w256, {T} => issigned{T}}
def w256u = genchk{w256, {T} => isunsigned{T}}
def w256f = genchk{w256, {T} => isfloat{T}}
def trunc{T, x:U & isint{T} & isint{U} & T<=U} = emit{T, '', x}
def trunc{T, x & knum{x}} = cast{T, x}

View File

@ -1,12 +1,3 @@
def w128{T} = 0
def w128{T & isvec{T}} = width{T}==128
def w128{T,w} = 0
def w128{T,w & w128{T}} = elwidth{T}==w
def w128i = genchk{w128, {T} => isint{T}}
def w128s = genchk{w128, {T} => issigned{T}}
def w128u = genchk{w128, {T} => isunsigned{T}}
def w128f = genchk{w128, {T} => isfloat{T}}
def v2i{x:T & w128{T}} = [16]u8 ~~ x # for compact casting for the annoying intrinsic type system
def v2f{x:T & w128{T}} = [4]f32 ~~ x
def v2d{x:T & w128{T}} = [2]f64 ~~ x

View File

@ -65,16 +65,8 @@ def for_mult_max{k, m}{vars,begin,end,block} = {
}
}
fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = {
# Scalar transpose defined in C
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
def call_base{...a} = emit{void, merge{'base_transpose_',ts}, ...a, w, h}
rp:*T = *T~~r0
xp:*T = *T~~x0
if (w<k or h<k) { call_base{rp, xp, w, h}; return{} }
def at{x,y} = tup{xp + y*w + x, rp + x*h + y}
def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h, ws, hs} = {
def at{x,y} = tup{xp + y*ws + x, rp + x*hs + y}
# Cache line info
def line_bytes = 64
@ -87,7 +79,7 @@ fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = {
we := w; if (use_overlap{wo}) we += k - wo
wm := w - k
if (line_elts > 2*k or h&(line_elts-1) != 0) {
if (line_elts > 2*k or h&(line_elts-1) != 0 or h != hs) {
ho := h%k
# Effective height, like we for w
he := h; if (use_overlap{ho}) he += k - ho
@ -96,7 +88,7 @@ fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = {
# Main transpose
@for_mult_max{kh, h-kh} (y to he) {
@for_mult_max{k, wm} (x to we) {
kernel{...at{x,y}, k, kh, w, h}
kernel{...at{x,y}, k, kh, ws, hs}
}
}
# Half-row(s) for non-square i16 case
@ -106,12 +98,12 @@ fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = {
@for (yi to n) {
y:u64 = 0; if (yi == n-1) y = h - e
@for_mult_max{k, wm} (x to we) {
kernel{...at{x,y}, k, k, w, h}
kernel{...at{x,y}, k, k, ws, hs}
}
}
}
# Base transpose used if overlap wasn't
if (ho!=0 and he==h) { hs := h-ho; call_base{rp+hs, xp+w*hs, w, ho} }
if (ho!=0 and he==h) { hd := h-ho; call_base{rp+hd, xp+ws*hd, w, ho} }
} else {
# Result rows are aligned with each other so it's possible to
# write a full cache line at a time
@ -125,49 +117,67 @@ fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64) : void = {
def vt{i} = transpose_square{VT, k, each{loadx, k*i + iota{k}}}
each{tup, ...each{vt, iota{line_vecs}}}
}
ro := tail{6, -u64~~r0} / (width{T}/8) # Offset to align within cache line; assume elt-aligned
wh := w*h
ro := tail{6, -u64~~rp} / (width{T}/8) # Offset to align within cache line; assume elt-aligned
wh := ws*h
yn := h
if (ro != 0) {
ra := line_elts - ro
y := h - ra
rpe := rp + y + (w-1)*h # Cache aligned
rpe := rp + y + (w-1)*hs # Cache aligned
# Part of first and last result row aren't covered by the split loop
def trtail{dst, src, len} = @for (i to len) store{dst, i, load{src, w*i}}
def trtail{dst, src, len} = @for (i to len) store{dst, i, load{src, ws*i}}
trtail{rp, xp, ro}
trtail{rpe, xp + y*w + w-1, ra}
trtail{rpe, xp + y*ws + w-1, ra}
# Transpose first few rows and last few rows together
@for_mult_max{k, wm} (x to we) {
{xpo,rpo} := at{x, y}
o := w*y + x
o := ws*y + x
def loadx{_} = {
l:=load{*VT~~(xp+o)}
o+=w; if (o>wh-k) o -= wh-1 # Jump from last source row to first, shifting right 1
o+=ws; if (o>wh-k) o -= wh-1 # Jump from last source row to first, shifting right 1
l
}
def rls = get_lines{loadx} # 4 rows of 2 vectors each
each{{i,v} => {p:=rpo+i*h; if (i<3 or p<rpe) store_line{*VT~~p, v}}, iota{k}, rls}
each{{i,v} => {p:=rpo+i*hs; if (i<3 or p<rpe) store_line{*VT~~p, v}}, iota{k}, rls}
}
--yn # One strip handled
}
@for_mult{line_elts} (y0 to yn) { y := y0 + ro
@for_mult_max{k, wm} (x to we) {
{xpo,rpo} := at{x, y}
def rls = get_lines{{i} => load{*VT~~(xpo+i*w), 0}}
each{{i,v} => store_line{*VT~~(rpo+i*h), v}, iota{k}, rls}
def rls = get_lines{{i} => load{*VT~~(xpo+i*ws), 0}}
each{{i,v} => store_line{*VT~~(rpo+i*hs), v}, iota{k}, rls}
}
}
}
if (we==w) @for(ws from w-wo to w) {
xpo:=xp+ws; rpo:=rp+h*ws
@for (i to h) store{rpo, i, load{xpo, w*i}}
if (we==w) @for(wd from w-wo to w) {
xpo:=xp+wd; rpo:=rp+hs*wd
@for (i to h) store{rpo, i, load{xpo, ws*i}}
}
}
fn transpose{T, k, kh}(r0:*void, x0:*void, w:u64, h:u64, ws:u64, hs:u64) : void = {
# Scalar transpose defined in C
def ts = if (T==i8) 'i8' else if (T==i16) 'i16' else if (T==i32) 'i32' else 'i64'
def call_base{...a} = emit{void, merge{'transpose_',ts}, ...a, ws, hs}
rp:*T = *T~~r0
xp:*T = *T~~x0
if (hasarch{'X86_64'} and w>=k and h>=k) {
transpose_with_kernel{T, k, kh, call_base, rp, xp, w, h, ws, hs}
} else {
if (h==2 and h==hs) @for (x0 in xp, x1 in xp+ws over i to w) { store{rp, i*2, x0}; store{rp, i*2+1, x1} }
else if (w==2 and w==ws) @for (r0 in rp, r1 in rp+hs over i to h) { r0 = load{xp, i*2}; r1 = load{xp, i*2+1} }
else call_base{rp, xp, w, h}
}
}
def transpose{T, k} = transpose{T, k, k}
export{'simd_transpose_i8', transpose{i8 , 16}}
export{'simd_transpose_i16', transpose{i16, 8, 16}}
export{'simd_transpose_i32', transpose{i32, 8}}
export{'simd_transpose_i64', transpose{i64, 4}}
exportT{'simd_transpose', tup{
transpose{i8 , 16},
transpose{i16, 8, 16},
transpose{i32, 8},
transpose{i64, 4}
}}