Separate where into its own function and split different types completely

This commit is contained in:
Marshall Lochbaum 2022-09-11 08:31:20 -04:00
parent 638121c054
commit 1f40d36712

View File

@ -109,6 +109,95 @@
for (u64 v=X[i]; v; v&=v-1) R[j++] = i*64 + CTZ(v); \
} while (0)
static B where(B x, usz xia, u64 s) {
B r;
u64* xp = bitarr_ptr(x);
usz q=xia%64; if (q) xp[xia/64] &= ((u64)1<<q) - 1;
if (xia <= 128) {
#if SINGELI && defined(__BMI2__)
i8* rp = m_tyarrvO(&r, 1, s, t_i8arr, 8);
bmipopc_1slash8(xp, rp, xia);
#else
i8* rp; r=m_i8arrv(&rp,s); WHERE_SPARSE(xp,rp,s);
#endif
} else if (xia <= 32768) {
#if SINGELI && defined(__BMI2__)
if (s >= xia/16) {
i16* rp = m_tyarrvO(&r, 2, s, t_i16arr, 16);
bmipopc_1slash16(xp, rp, xia);
}
#else
if (s >= xia/2) {
i16* rp = m_tyarrvO(&r, 2, s, t_i16arr, 2);
for (usz i=0; i<(xia+7)/8; i++) {
u8 v = ((u8*)xp)[i];
for (usz k=0; k<8; k++) { *rp=8*i+k; rp+=v&1; v>>=1; }
}
}
#endif
else {
i16* rp; r=m_i16arrv(&rp,s); WHERE_SPARSE(xp,rp,s);
}
} else {
assert(xia <= (usz)I32_MAX+1);
#if SINGELI && defined(__BMI2__)
i32* rp; r = m_i32arrv(&rp, s);
#else
i32* rp = m_tyarrvO(&r, 4, s, t_i32arr, 4);
#endif
usz b = 1<<11; // Maximum allowed for branchless sparse method
TALLOC(i16, buf, b);
i32* rq=rp; usz i=0;
for (; i<xia; i+=b) {
usz bs;
if (b>xia-i) {
b = xia-i;
bs = s-(rq-rp);
} else {
bs = bit_sum(xp,b);
}
#if SINGELI && defined(__BMI2__)
if (bs >= b/8+b/16) {
bmipopc_1slash16(xp, buf, b);
for (usz j=0; j<bs; j++) rq[j] = i+buf[j];
}
#else
if (bs >= b/2) {
for (usz ii=0; ii<(b+7)/8; ii++) {
u8 v = ((u8*)xp)[ii];
i32* rs=rq;
for (usz k=0; k<8; k++) { *rs=i+8*ii+k; rs+=v&1; v>>=1; }
}
}
#endif
else if (bs >= b/256) { // Branchless sparse
for (usz j=0; j<bs; j++) rq[j]=0;
u32 top = 1<<24;
for (usz i=0, j=0; i<(b+63)/64; i++) {
u64 u=xp[i], p;
p=(u32)u&(top-1); rq[j]+=(2*top)|p; j+=POPC(p); u>>=24;
p=(u32)u&(top-1); rq[j]+=(3*top)|p; j+=POPC(p); u>>=24;
p=(u32)u ; rq[j]+=(3*top)|p; j+=POPC(p);
}
u64 t=((u64)i<<21)-2*top;
for (usz j=0; j<bs; j++) {
t += (u32)rq[j];
rq[j] = 8*(t>>24) + CTZ((u32)t);
t &= t-1;
}
} else { // Branched very sparse
for (usz ii=i/64, j=0; j<bs; ii++) {
for (u64 v=xp[ii-i/64]; RARE(v); v&=v-1) rq[j++] = ii*64 + CTZ(v);
}
}
rq+= bs;
xp+= b/64;
}
TFREE(buf);
}
return r;
}
extern B rt_slash;
B slash_c1(B t, B x) {
if (RARE(isAtm(x)) || RARE(RNK(x)!=1)) thrF("/: Argument must have rank 1 (%H ≡ ≢𝕩)", x);
@ -130,87 +219,7 @@ B slash_c1(B t, B x) {
B r;
u8 xe = TI(x,elType);
if (xe==el_bit) {
u64* xp = bitarr_ptr(x);
if (xia > 32768) {
usz q=xia%64; if (q) xp[xia/64] &= ((u64)1<<q) - 1;
#if SINGELI && defined(__BMI2__)
i32* rp; r = m_i32arrv(&rp, s);
#else
i32* rp = m_tyarrvO(&r, 4, s, t_i32arr, 4);
#endif
usz b = 1<<11; // Maximum allowed for branchless sparse method
TALLOC(i16, buf, b);
i32* rq=rp; usz i=0;
for (; i<xia; i+=b) {
usz bs;
if (b>xia-i) {
b = xia-i;
bs = s-(rq-rp);
} else {
bs = bit_sum(xp,b);
}
#if SINGELI && defined(__BMI2__)
if (bs >= b/8+b/16) {
bmipopc_1slash16(xp, buf, b);
for (usz j=0; j<bs; j++) rq[j] = i+buf[j];
#else
if (bs >= b/2) {
for (usz ii=0; ii<(b+7)/8; ii++) {
u8 v = ((u8*)xp)[ii];
i32* rs=rq;
for (usz k=0; k<8; k++) { *rs=i+8*ii+k; rs+=v&1; v>>=1; }
}
#endif
} else if (bs >= b/256) { // Branchless sparse
for (usz j=0; j<bs; j++) rq[j]=0;
u32 top = 1<<24;
for (usz i=0, j=0; i<(b+63)/64; i++) {
u64 u=xp[i], p;
p=(u32)u&(top-1); rq[j]+=(2*top)|p; j+=POPC(p); u>>=24;
p=(u32)u&(top-1); rq[j]+=(3*top)|p; j+=POPC(p); u>>=24;
p=(u32)u ; rq[j]+=(3*top)|p; j+=POPC(p);
}
u64 t=((u64)i<<21)-2*top;
for (usz j=0; j<bs; j++) {
t += (u32)rq[j];
rq[j] = 8*(t>>24) + CTZ((u32)t);
t &= t-1;
}
} else { // Branched very sparse
for (usz ii=i/64, j=0; j<bs; ii++) {
for (u64 v=xp[ii-i/64]; RARE(v); v&=v-1) rq[j++] = ii*64 + CTZ(v);
}
}
rq+= bs;
xp+= b/64;
}
TFREE(buf);
} else
// Sparse method with CTZ
#if SINGELI && defined(__BMI2__)
if (xia>128 && s < xia/16) {
#else
if (xia<=128 || s < xia/2) {
#endif
usz q=xia%64; if (q) xp[xia/64] &= ((u64)1<<q) - 1;
#define WHERE(T) T* rp; r=m_##T##arrv(&rp,s); WHERE_SPARSE(xp,rp,s);
if (0&&xia<=32768) { WHERE(i16) } else { WHERE(i32) }
#undef WHERE
} else
#if SINGELI && defined(__BMI2__)
if (xia<=128) { i8* rp = m_tyarrvO(&r, 1, s, t_i8arr , 8); bmipopc_1slash8 (xp, rp, xia); }
else { i16* rp = m_tyarrvO(&r, 2, s, t_i16arr, 16); bmipopc_1slash16(xp, rp, xia); }
#else
{
i32* rp = m_tyarrvO(&r, 4, s, t_i32arr, 4);
u8* x8 = (u8*)xp;
u8 q=xia%8; if (q) x8[xia/8] &= (1<<q)-1;
for (usz i=0; i<(xia+7)/8; i++) {
u8 v = x8[i];
for (usz k=0; k<8; k++) { *rp=8*i+k; rp+= v&1; v>>=1; }
}
}
#endif
r = where(x, xia, s);
} else {
i32* rp; r = m_i32arrv(&rp, s);
if (xe==el_i8) {