Separate code path for cxsz=1, fix out of bounds read for RNK(x)==0

This commit is contained in:
Andrea Piseri 2024-05-18 21:45:18 +02:00
parent 7f28308e44
commit 06808414da

View File

@ -372,7 +372,7 @@ B select_c2(B t, B w, B x) {
extern INIT_GLOBAL u8 reuseElType[t_COUNT];
B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep⌾(w⊏⥊) x, assumes w is a typed (elNum) list of valid indices, only el_f64 if strictly necessary
B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcsz) { // rep⌾(w⊏⥊) x, assumes w is a typed (elNum) list of valid indices, only el_f64 if strictly necessary
#if CHECK_VALID
TALLOC(bool, set, xl);
bool sparse = wia < xl/64;
@ -401,19 +401,30 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep
f64* wp = f64any_ptr(w);
SPARSE_INIT((i64)wp[i])
MAKE_MUT(r, xl * xcia);
MAKE_MUT(r, xl * xcsz);
mut_init_copy(r, x, re);
NOGC_E;
MUTG_INIT(r); SGet(rep)
for (usz i = 0; i < wia; i++) {
READ_W(cw, i);
for (usz j = 0; j < xcia; j++) {
B cn = Get(rep, i * xcia + j);
EQ(!equal(mut_getU(r, cw * xcia + j), cn));
mut_rm(r, cw * xcia + j);
mut_setG(r, cw * xcia + j, cn);
if (xcsz==1) {
for (usz i = 0; i < wia; i++) {
READ_W(cw, i);
B cn = Get(rep, i);
EQ(!equal(mut_getU(r, cw), cn));
mut_rm(r, cw);
mut_setG(r, cw, cn);
DONE_CW;
}
} else {
for (usz i = 0; i < wia; i++) {
READ_W(cw, i);
for (usz j = 0; j < xcsz; j++) {
B cn = Get(rep, i * xcsz + j);
EQ(!equal(mut_getU(r, cw * xcsz + j), cn));
mut_rm(r, cw * xcsz + j);
mut_setG(r, cw * xcsz + j, cn);
}
DONE_CW;
}
DONE_CW;
}
ra = mut_fp(r);
goto dec_ret_ra;
@ -437,14 +448,24 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep
TyArr* na = toBitArr(rep); rep = taga(na);
u64* np = bitarrv_ptr(na);
u64* rp = (void*)((TyArr*)ra)->a;
for (usz i = 0; i < wia; i++) {
READ_W(cw, i);
for (usz j = 0; j < xcia; j++) {
bool cn = bitp_get(np, i * xcia + j);
EQ(cn != bitp_get(rp, cw * xcia + j));
bitp_set(rp, cw * xcia + j, cn);
if (xcsz==1) {
for (usz i = 0; i < wia; i++) {
READ_W(cw, i);
bool cn = bitp_get(np, i);
EQ(cn != bitp_get(rp, cw));
bitp_set(rp, cw, cn);
DONE_CW;
}
} else {
for (usz i = 0; i < wia; i++) {
READ_W(cw, i);
for (usz j = 0; j < xcsz; j++) {
bool cn = bitp_get(np, i * xcsz + j);
EQ(cn != bitp_get(rp, cw * xcsz + j));
bitp_set(rp, cw * xcsz + j, cn);
}
DONE_CW;
}
DONE_CW;
}
goto dec_ret_ra;
}
@ -452,32 +473,54 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep
ra = reuse? a(REUSE(x)) : cpyHArr(x);
B* rp = harrP_parts((HArr*)ra).a;
SGet(rep)
for (usz i = 0; i < wia; i++) {
READ_W(cw, i);
for (usz j = 0; j < xcia; j++) {
B cn = Get(rep, i * xcia + j);
EQ(!equal(cn,rp[cw * xcia + j]));
dec(rp[cw * xcia + j]);
rp[cw * xcia + j] = cn;
if (xcsz==1)
{
for (usz i = 0; i < wia; i++) {
READ_W(cw, i);
B cn = Get(rep, i);
EQ(!equal(cn,rp[cw]));
dec(rp[cw]);
rp[cw] = cn;
DONE_CW;
}
} else {
for (usz i = 0; i < wia; i++) {
READ_W(cw, i);
for (usz j = 0; j < xcsz; j++) {
B cn = Get(rep, i * xcsz + j);
EQ(!equal(cn,rp[cw * xcsz + j]));
dec(rp[cw * xcsz + j]);
rp[cw * xcsz + j] = cn;
}
DONE_CW;
}
DONE_CW;
}
goto dec_ret_ra;
}
}
#define IMPL(T) do { \
T* rp = (void*)((TyArr*)ra)->a; \
T* np = tyany_ptr(rep); \
for (usz i = 0; i < wia; i++) { \
READ_W(cw, i); \
for (usz j = 0; j < xcia; j++) { \
T cn = np[i * xcia + j]; \
EQ(cn != rp[cw * xcia + j]); \
rp[cw * xcia + j] = cn; \
} \
DONE_CW; \
} \
#define IMPL(T) do { \
T* rp = (void*)((TyArr*)ra)->a; \
T* np = tyany_ptr(rep); \
if (xcsz==1) { \
for (usz i = 0; i < wia; i++) { \
READ_W(cw, i); \
T cn = np[i]; \
EQ(cn != rp[cw]); \
rp[cw] = cn; \
DONE_CW; \
} \
} else { \
for (usz i = 0; i < wia; i++) { \
READ_W(cw, i); \
for (usz j = 0; j < xcsz; j++) { \
T cn = np[i * xcsz + j]; \
EQ(cn != rp[cw * xcsz + j]); \
rp[cw * xcsz + j] = cn; \
} \
DONE_CW; \
} \
} \
goto dec_ret_ra; \
} while(0)
@ -501,7 +544,7 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep
}
B select_ucw(B t, B o, B w, B x) {
if (isAtm(x) || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); }
if (isAtm(x) || RNK(x)==0 || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); }
usz xia = IA(x);
usz wia = IA(w);
u8 we = TI(w,elType);
@ -522,6 +565,6 @@ B select_ucw(B t, B o, B w, B x) {
usz rr = RNK(rep);
bool ok = !isAtm(rep) && xr+wr==rr+1 && eqShPart(SH(w),SH(rep),wr) && eqShPart(SH(x)+1,SH(rep)+wr,xr-1);
if (!ok) thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (%H ≡ shape of a, %2H = shape of ⊏𝕩, %H ≡ shape of result of 𝔽)", w, xr-1, SH(x)+1, rep);
usz ia = shProd(SH(x), 1, RNK(x));
return select_replace(U'', w, x, rep, wia, SH(x)[0], ia);
usz xcsz = arr_csz(x);
return select_replace(U'', w, x, rep, wia, SH(x)[0], xcsz);
}