Separate code path for cxsz=1, fix out of bounds read for RNK(x)==0

This commit is contained in:
Andrea Piseri 2024-05-18 21:45:18 +02:00
parent 7f28308e44
commit 06808414da

View File

@ -372,7 +372,7 @@ B select_c2(B t, B w, B x) {
extern INIT_GLOBAL u8 reuseElType[t_COUNT]; extern INIT_GLOBAL u8 reuseElType[t_COUNT];
B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep⌾(w⊏⥊) x, assumes w is a typed (elNum) list of valid indices, only el_f64 if strictly necessary B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcsz) { // rep⌾(w⊏⥊) x, assumes w is a typed (elNum) list of valid indices, only el_f64 if strictly necessary
#if CHECK_VALID #if CHECK_VALID
TALLOC(bool, set, xl); TALLOC(bool, set, xl);
bool sparse = wia < xl/64; bool sparse = wia < xl/64;
@ -401,19 +401,30 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep
f64* wp = f64any_ptr(w); f64* wp = f64any_ptr(w);
SPARSE_INIT((i64)wp[i]) SPARSE_INIT((i64)wp[i])
MAKE_MUT(r, xl * xcia); MAKE_MUT(r, xl * xcsz);
mut_init_copy(r, x, re); mut_init_copy(r, x, re);
NOGC_E; NOGC_E;
MUTG_INIT(r); SGet(rep) MUTG_INIT(r); SGet(rep)
for (usz i = 0; i < wia; i++) { if (xcsz==1) {
READ_W(cw, i); for (usz i = 0; i < wia; i++) {
for (usz j = 0; j < xcia; j++) { READ_W(cw, i);
B cn = Get(rep, i * xcia + j); B cn = Get(rep, i);
EQ(!equal(mut_getU(r, cw * xcia + j), cn)); EQ(!equal(mut_getU(r, cw), cn));
mut_rm(r, cw * xcia + j); mut_rm(r, cw);
mut_setG(r, cw * xcia + j, cn); mut_setG(r, cw, cn);
DONE_CW;
}
} else {
for (usz i = 0; i < wia; i++) {
READ_W(cw, i);
for (usz j = 0; j < xcsz; j++) {
B cn = Get(rep, i * xcsz + j);
EQ(!equal(mut_getU(r, cw * xcsz + j), cn));
mut_rm(r, cw * xcsz + j);
mut_setG(r, cw * xcsz + j, cn);
}
DONE_CW;
} }
DONE_CW;
} }
ra = mut_fp(r); ra = mut_fp(r);
goto dec_ret_ra; goto dec_ret_ra;
@ -437,14 +448,24 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep
TyArr* na = toBitArr(rep); rep = taga(na); TyArr* na = toBitArr(rep); rep = taga(na);
u64* np = bitarrv_ptr(na); u64* np = bitarrv_ptr(na);
u64* rp = (void*)((TyArr*)ra)->a; u64* rp = (void*)((TyArr*)ra)->a;
for (usz i = 0; i < wia; i++) { if (xcsz==1) {
READ_W(cw, i); for (usz i = 0; i < wia; i++) {
for (usz j = 0; j < xcia; j++) { READ_W(cw, i);
bool cn = bitp_get(np, i * xcia + j); bool cn = bitp_get(np, i);
EQ(cn != bitp_get(rp, cw * xcia + j)); EQ(cn != bitp_get(rp, cw));
bitp_set(rp, cw * xcia + j, cn); bitp_set(rp, cw, cn);
DONE_CW;
}
} else {
for (usz i = 0; i < wia; i++) {
READ_W(cw, i);
for (usz j = 0; j < xcsz; j++) {
bool cn = bitp_get(np, i * xcsz + j);
EQ(cn != bitp_get(rp, cw * xcsz + j));
bitp_set(rp, cw * xcsz + j, cn);
}
DONE_CW;
} }
DONE_CW;
} }
goto dec_ret_ra; goto dec_ret_ra;
} }
@ -452,32 +473,54 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep
ra = reuse? a(REUSE(x)) : cpyHArr(x); ra = reuse? a(REUSE(x)) : cpyHArr(x);
B* rp = harrP_parts((HArr*)ra).a; B* rp = harrP_parts((HArr*)ra).a;
SGet(rep) SGet(rep)
for (usz i = 0; i < wia; i++) { if (xcsz==1)
READ_W(cw, i); {
for (usz j = 0; j < xcia; j++) { for (usz i = 0; i < wia; i++) {
B cn = Get(rep, i * xcia + j); READ_W(cw, i);
EQ(!equal(cn,rp[cw * xcia + j])); B cn = Get(rep, i);
dec(rp[cw * xcia + j]); EQ(!equal(cn,rp[cw]));
rp[cw * xcia + j] = cn; dec(rp[cw]);
rp[cw] = cn;
DONE_CW;
}
} else {
for (usz i = 0; i < wia; i++) {
READ_W(cw, i);
for (usz j = 0; j < xcsz; j++) {
B cn = Get(rep, i * xcsz + j);
EQ(!equal(cn,rp[cw * xcsz + j]));
dec(rp[cw * xcsz + j]);
rp[cw * xcsz + j] = cn;
}
DONE_CW;
} }
DONE_CW;
} }
goto dec_ret_ra; goto dec_ret_ra;
} }
} }
#define IMPL(T) do { \ #define IMPL(T) do { \
T* rp = (void*)((TyArr*)ra)->a; \ T* rp = (void*)((TyArr*)ra)->a; \
T* np = tyany_ptr(rep); \ T* np = tyany_ptr(rep); \
for (usz i = 0; i < wia; i++) { \ if (xcsz==1) { \
READ_W(cw, i); \ for (usz i = 0; i < wia; i++) { \
for (usz j = 0; j < xcia; j++) { \ READ_W(cw, i); \
T cn = np[i * xcia + j]; \ T cn = np[i]; \
EQ(cn != rp[cw * xcia + j]); \ EQ(cn != rp[cw]); \
rp[cw * xcia + j] = cn; \ rp[cw] = cn; \
} \ DONE_CW; \
DONE_CW; \ } \
} \ } else { \
for (usz i = 0; i < wia; i++) { \
READ_W(cw, i); \
for (usz j = 0; j < xcsz; j++) { \
T cn = np[i * xcsz + j]; \
EQ(cn != rp[cw * xcsz + j]); \
rp[cw * xcsz + j] = cn; \
} \
DONE_CW; \
} \
} \
goto dec_ret_ra; \ goto dec_ret_ra; \
} while(0) } while(0)
@ -501,7 +544,7 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep
} }
B select_ucw(B t, B o, B w, B x) { B select_ucw(B t, B o, B w, B x) {
if (isAtm(x) || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); } if (isAtm(x) || RNK(x)==0 || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); }
usz xia = IA(x); usz xia = IA(x);
usz wia = IA(w); usz wia = IA(w);
u8 we = TI(w,elType); u8 we = TI(w,elType);
@ -522,6 +565,6 @@ B select_ucw(B t, B o, B w, B x) {
usz rr = RNK(rep); usz rr = RNK(rep);
bool ok = !isAtm(rep) && xr+wr==rr+1 && eqShPart(SH(w),SH(rep),wr) && eqShPart(SH(x)+1,SH(rep)+wr,xr-1); bool ok = !isAtm(rep) && xr+wr==rr+1 && eqShPart(SH(w),SH(rep),wr) && eqShPart(SH(x)+1,SH(rep)+wr,xr-1);
if (!ok) thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (%H ≡ shape of a, %2H = shape of ⊏𝕩, %H ≡ shape of result of 𝔽)", w, xr-1, SH(x)+1, rep); if (!ok) thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (%H ≡ shape of a, %2H = shape of ⊏𝕩, %H ≡ shape of result of 𝔽)", w, xr-1, SH(x)+1, rep);
usz ia = shProd(SH(x), 1, RNK(x)); usz xcsz = arr_csz(x);
return select_replace(U'', w, x, rep, wia, SH(x)[0], ia); return select_replace(U'', w, x, rep, wia, SH(x)[0], xcsz);
} }