use BMI2 for widening to <32-bit cells

This commit is contained in:
dzaima 2023-02-10 23:03:46 +02:00
parent 53737ab3fb
commit fd1178bc0e

View File

@ -8,6 +8,10 @@
#define bitselFns simd_bitsel
#endif
#if defined(__BMI2__) && !SLOW_PDEP
#define FAST_PDEP 1
#endif
NOINLINE Arr* allZeroes(usz ia) { u64* rp; Arr* r = m_bitarrp(&rp, ia); for (usz i = 0; i < BIT_N(ia); i++) rp[i] = 0; return r; }
NOINLINE Arr* allOnes (usz ia) { u64* rp; Arr* r = m_bitarrp(&rp, ia); for (usz i = 0; i < BIT_N(ia); i++) rp[i] = ~0ULL; return r; }
@ -137,13 +141,33 @@ static NOINLINE B zeroPadToCellBits0(B x, usz lr, usz cam, usz pcsz, usz ncsz) {
// TODO widen 8/16-bit cells to 16/32 via cpyC(16|32)Arr
if (ncsz<=64 && (ncsz&(ncsz-1)) == 0) {
u64 msk = (1ull<<pcsz)-1;
switch(ncsz) { default: UD;
case 8: for (ux i=0; i<cam; i++) ((u8* )rp)[i] = rbuu64(xp, i*pcsz)&msk; break;
case 16: for (ux i=0; i<cam; i++) ((u16*)rp)[i] = rbuu64(xp, i*pcsz)&msk; break;
case 32: for (ux i=0; i<cam; i++) ((u32*)rp)[i] = rbuu64(xp, i*pcsz)&msk; break;
case 64: for (ux i=0; i<cam; i++) ((u64*)rp)[i] = rbuu64(xp, i*pcsz)&msk; break;
}
u64 tmsk = (1ull<<pcsz)-1;
#if FAST_PDEP
if (ncsz<32) {
assert(ncsz==8 || ncsz==16);
bool c8 = ncsz==8;
u64 msk0 = tmsk * (c8? 0x0101010101010101 : 0x0001000100010001);
ux am = c8? cam/8 : cam/4;
u32 count = POPC(msk0);
// printf("widen base %04lx %016lx count=%d am=%zu\n", tmsk, msk0, count, am);
for (ux i=0; i<am; i++) { *(u64*)rp = _pdep_u64(rbuu64(xp, i*count), msk0); rp++; }
u32 tb = c8? cam&7 : (cam&3)<<1;
if (tb) {
u64 msk1 = msk0 & ((1ull<<tb*8)-1);
// printf("widen tail %4d %016lx count=%d\n", tb, msk1, POPC(msk1));
*(u64*)rp = _pdep_u64(rbuu64(xp, am*count), msk1);
}
}
else if (ncsz==32) for (ux i=0; i<cam; i++) ((u32*)rp)[i] = rbuu64(xp, i*pcsz)&tmsk;
else for (ux i=0; i<cam; i++) ((u64*)rp)[i] = rbuu64(xp, i*pcsz)&tmsk;
#else
switch(ncsz) { default: UD;
case 8: for (ux i=0; i<cam; i++) ((u8* )rp)[i] = rbuu64(xp, i*pcsz)&tmsk; break;
case 16: for (ux i=0; i<cam; i++) ((u16*)rp)[i] = rbuu64(xp, i*pcsz)&tmsk; break;
case 32: for (ux i=0; i<cam; i++) ((u32*)rp)[i] = rbuu64(xp, i*pcsz)&tmsk; break;
case 64: for (ux i=0; i<cam; i++) ((u64*)rp)[i] = rbuu64(xp, i*pcsz)&tmsk; break;
}
#endif
} else {
assert((ncsz&63) == 0 && ncsz-pcsz < 64 && (pcsz&63) != 0);
ux pfu64 = pcsz>>6; // previous full u64 count in cell
@ -172,10 +196,6 @@ NOINLINE B widenBitArr(B x, ur axis) {
return zeroPadToCellBits0(x, axis, shProd(SH(x), 0, axis), pcsz, ncsz);
}
#if defined(__BMI2__) && !SLOW_PDEP
#define FAST_PDEP 1
#endif
B narrowWidenedBitArr(B x, ur axis, ur cr, usz* csh) { // for now assumes the bits to be dropped are zero, origCellBits is a multiple of 8, and that there's at most 63 padding bits
if (TI(x,elType)!=el_bit) return taga(cpyBitArr(x));
@ -213,12 +233,12 @@ B narrowWidenedBitArr(B x, ur axis, ur cr, usz* csh) { // for now assumes the bi
u64 msk0 = tmsk * (c8? 0x0101010101010101 : 0x0001000100010001);
ux am = c8? cam/8 : cam/4;
u32 count = POPC(msk0);
// printf("base %04lx %016lx count=%d am=%zu\n", tmsk, msk0, count, am);
// printf("narrow base %04lx %016lx count=%d am=%zu\n", tmsk, msk0, count, am);
for (ux i=0; i<am; i++) { ab_add(&ab, _pext_u64(*(u64*)xp, msk0), count); xp+= 8; }
u32 tb = c8? cam&7 : (cam&3)<<1;
if (tb) {
u64 msk1 = msk0 & ((1ull<<tb*8)-1);
// printf("tail %4d %016lx count=%d\n", tb, msk1, POPC(msk1));
// printf("narrow tail %4d %016lx count=%d\n", tb, msk1, POPC(msk1));
ab_add(&ab, _pext_u64(*(u64*)xp, msk1), POPC(msk1));
}
}