pad input cells when needed for for ∊𝕩, ⊐𝕩, ⊒𝕩

This commit is contained in:
dzaima 2023-02-13 00:17:19 +02:00
parent fd1178bc0e
commit f700a3e150
3 changed files with 35 additions and 11 deletions

View File

@ -91,6 +91,13 @@ static bool canCompare64_norm(B x, usz n) {
}
return 1;
}
static bool shouldWidenBitarr(B x, usz csz) { // assumes cells won't anymore have sizes of 0, 8, or 16 bits
u8 xe = TI(x,elType);
ux bcsz = ((ux)csz)<<elWidthLogBits(xe);
assert(csz!=0 && bcsz!=8 && bcsz!=16);
if (bcsz<64 && bcsz!=32) { assert(xe!=el_B); return true; } // not el_B because csz>0 → csz*sizeof(B) >= 64
return false;
}
#define GRADE_UD(U,D) U
#include "radix.h"
@ -213,10 +220,11 @@ B memberOf_c1(B t, B x) {
u64 n = *SH(x);
if (n<=1) { decG(x); return n ? taga(arr_shVec(allOnes(1))) : emptyIVec(); }
usz csz = arr_csz(x);
u8 lw = cellWidthLog(x);
void* xv = tyany_ptr(x);
if (lw == 0) {
usz i = bit_find(xv, n, 1 &~ *(u64*)xv); decG(x);
if (lw==0 || csz==0) {
usz i = csz==0? n : bit_find(xv, n, 1 &~ *(u64*)xv); decG(x);
B r = taga(arr_shVec(allZeroes(n)));
u64* rp = tyany_ptr(r);
rp[0]=1; if (i<n) bitp_set(rp, i, 1);
@ -287,7 +295,10 @@ B memberOf_c1(B t, B x) {
#undef HASHTAB
#undef BRUTE
if (RNK(x)>1) x = toCells(x);
if (RNK(x)>1) {
if (shouldWidenBitarr(x, csz)) return C1(memberOf, widenBitArr(x, 1));
x = toCells(x);
}
u64* rp; B r = m_bitarrv(&rp, n);
H_Sb* set = m_Sb(64);
SGetU(x)
@ -302,6 +313,8 @@ B count_c1(B t, B x) {
if (n<=1) { decG(x); return n ? taga(arr_shVec(allZeroes(1))) : emptyIVec(); }
if (n>(usz)I32_MAX+1) thrM("⊒: Argument length >2⋆31 not supported");
usz csz = arr_csz(x);
if (csz==0) { decG(x); return C1(ud, m_f64(n)); }
u8 lw = cellWidthLog(x);
if (lw==0) {
u64* xp = bitarr_ptr(x);
@ -394,7 +407,10 @@ B count_c1(B t, B x) {
#undef HASHTAB
#undef BRUTE
if (RNK(x)>1) x = toCells(x);
if (RNK(x)>1) {
if (shouldWidenBitarr(x, csz)) return C1(count, widenBitArr(x, 1));
x = toCells(x);
}
i32* rp; B r = m_i32arrv(&rp, n);
H_b2i* map = m_b2i(64);
SGetU(x)
@ -413,9 +429,11 @@ static B reduceI32WidthBelow(B r, usz after) {
B indexOf_c1(B t, B x) {
if (isAtm(x) || RNK(x)==0) thrM("⊐: 𝕩 cannot have rank 0");
u64 n = *SH(x);
if (n<=1) { decG(x); return n ? taga(arr_shVec(allZeroes(1))) : emptyIVec(); }
if (n<=1) { zeroRes: decG(x); return n? taga(arr_shVec(allZeroes(n))) : emptyIVec(); }
if (n>(usz)I32_MAX+1) thrM("⊐: Argument length >2⋆31 not supported");
usz csz = arr_csz(x);
if (csz==0) goto zeroRes;
u8 lw = cellWidthLog(x);
void* xv = tyany_ptr(x);
if (lw == 0) {
@ -496,7 +514,10 @@ B indexOf_c1(B t, B x) {
#undef BRUTE
#undef DOTAB
if (RNK(x)>1) x = toCells(x);
if (RNK(x)>1) {
if (shouldWidenBitarr(x, csz)) return C1(indexOf, widenBitArr(x, 1));
x = toCells(x);
}
i32* rp; B r = m_i32arrv(&rp, n);
H_b2i* map = m_b2i(64);
SGetU(x)

View File

@ -12,13 +12,15 @@ u8 elTypeWidth[] = {
[el_i8 ] = 1, [el_c8 ] = 1,
[el_i16] = 2, [el_c16] = 2,
[el_i32] = 4, [el_c32] = 4,
[el_bit] = 0, [el_f64] = 8
[el_bit] = 0, [el_f64] = 8,
[el_B] = 8
};
u8 elTypeWidthLogBits[] = {
[el_i8 ] = 3, [el_c8 ] = 3,
[el_i16] = 4, [el_c16] = 4,
[el_i32] = 5, [el_c32] = 5,
[el_bit] = 0, [el_f64] = 6
[el_bit] = 0, [el_f64] = 6,
[el_B] = 6
};
u8 arrTypeWidthLog[] = {
[t_bitarr]=99,

View File

@ -182,8 +182,9 @@ static NOINLINE B zeroPadToCellBits0(B x, usz lr, usz cam, usz pcsz, usz ncsz) {
return taga(r);
}
NOINLINE B widenBitArr(B x, ur axis) {
assert(isArr(x) && TI(x,elType)==el_bit && axis>=1 && RNK(x)>=axis);
usz pcsz = shProd(SH(x), axis, RNK(x));
assert(isArr(x) && TI(x,elType)!=el_B && axis>=1 && RNK(x)>=axis);
usz pcsz = shProd(SH(x), axis, RNK(x))<<elWidthLogBits(TI(x,elType));
assert(pcsz!=0);
usz ncsz;
if (pcsz<=8) ncsz = 8;
else if (pcsz<=16) ncsz = 16;
@ -202,7 +203,7 @@ B narrowWidenedBitArr(B x, ur axis, ur cr, usz* csh) { // for now assumes the bi
usz xcsz = shProd(SH(x), axis, RNK(x));
usz ocsz = shProd(csh, 0, cr);
// printf("narrowWidenedBitArr ia=%d axis=%d cr=%d ocsz=%d xcsz=%d\n", IA(x), axis, cr, ocsz, xcsz);
assert((xcsz&7) == 0 && ocsz<xcsz);
assert((xcsz&7) == 0 && ocsz<xcsz && ocsz!=0);
if (xcsz==ocsz) {
if (RNK(x)-axis == cr && eqShPart(SH(x)+axis, csh, cr)) return x;
Arr* r = cpyWithShape(x);