fast transpose on shape n‿2 𝕩

This commit is contained in:
dzaima 2023-03-03 18:53:18 +02:00
parent 76d26db4c2
commit 717074a2f8
2 changed files with 27 additions and 3 deletions

View File

@ -1296,7 +1296,7 @@ B transp_c1(B t, B x) {
r = (Arr*) rp.c;
} else {
#ifndef __BMI2__
if (xe==el_bit) { x = taga(cpyI8Arr(x)); xsh=SH(x); xe=el_i8; }
if (xe==el_bit) { x = taga(cpyI8Arr(x)); xsh=SH(x); xe=el_i8; toBit=true; }
void* rp = m_tyarrp(&r,elWidth(xe),ia,el2t(xe));
#else
void* rp = m_tyarrlbp(&r,elWidthLogBits(xe),ia,el2t(xe));
@ -1318,6 +1318,30 @@ B transp_c1(B t, B x) {
case el_f64: { u64* x0=xp; u64* x1=x0+w; for (usz i=0; i<w; i++) { ((u64*)rp)[i*2] = x0[i]; ((u64*)rp)[i*2+1] = x1[i]; } } break;
}
}
} else if (w==2 && xe!=el_B) {
#ifndef __BMI2__
if (xe==el_bit) { x = taga(cpyI8Arr(x)); xsh=SH(x); xe=el_i8; toBit=true; }
#endif
void* rp = m_tyarrlbp(&r,elWidthLogBits(xe),ia,el2t(xe));
void* xp = tyany_ptr(x);
switch(xe) { default: UD;
#if __BMI2__
case el_bit:
u64* r0 = rp; TALLOC(u64, r1, BIT_N(h));
for (usz i=0; i<BIT_N(ia); i++) {
u64 v = ((u64*)xp)[i];
((u32*)r0)[i] = _pext_u64(v, 0x5555555555555555);
((u32*)r1)[i] = _pext_u64(v, 0xAAAAAAAAAAAAAAAA);
}
bit_cpy(r0, h, r1, 0, h);
TFREE(r1);
break;
#endif
case el_i8: case el_c8: { u8* r0=rp; u8* r1=r0+h; for (usz i=0; i<h; i++) { r0[i] = ((u8* )xp)[i*2]; r1[i] = ((u8* )xp)[i*2+1]; } } break;
case el_i16:case el_c16: { u16* r0=rp; u16* r1=r0+h; for (usz i=0; i<h; i++) { r0[i] = ((u16*)xp)[i*2]; r1[i] = ((u16*)xp)[i*2+1]; } } break;
case el_i32:case el_c32: { u32* r0=rp; u32* r1=r0+h; for (usz i=0; i<h; i++) { r0[i] = ((u32*)xp)[i*2]; r1[i] = ((u32*)xp)[i*2+1]; } } break;
case el_f64: { f64* r0=rp; f64* r1=r0+h; for (usz i=0; i<h; i++) { r0[i] = ((f64*)xp)[i*2]; r1[i] = ((f64*)xp)[i*2+1]; } } break;
}
} else {
switch(xe) { default: UD;
case el_bit: x = taga(cpyI8Arr(x)); xsh=SH(x); xe=el_i8; toBit=true; // fallthough

View File

@ -62,10 +62,10 @@ void tailVerifyAlloc(void* ptr, u64 filled, i64 logAlloc, u8 type) {
if (type==t_talloc) ((u64*)((u8*)ptr + end - 8))[0] = filled-8; // -8 because TALLOCP does a +8
}
void verifyEnd(void* ptr, u64 sz, u64 start, u64 end) {
if (end+64>sz) { printf("Bad used range: "N64u".."N64u", allocation size "N64u"\n", start, end, sz); exit(1); }
if (end+64>sz) { printf("Bad used range: "N64u".."N64u", allocation size "N64u"\n", start, end, sz); __builtin_trap(); }
}
void tailVerifyReinit(void* ptr, u64 filled, u64 end) {
if(filled>end || filled<=8) { printf("Bad reinit arguments: "N64u".."N64u"\n", filled, end); exit(1); }
if(filled>end || filled<=8) { printf("Bad reinit arguments: "N64u".."N64u"\n", filled, end); __builtin_trap(); }
verifyEnd(ptr, mm_size(ptr), filled, end);
tailVerifyInit(ptr, filled, end, mm_size(ptr));
}