native path in select_ucw for high rank 𝕩

`𝔽⌾(a⊸⊏)𝕩` now does not need to go through the self-hosted runtime if
`1<=𝕩`. Instead the `select_replace` helper function is parametrized
over the length of `𝕩` (`xl`) and item amount of the cell of `𝕩` (`xcia`).

- The `EQ` macro is modified to not immediately mark the cell as
  populated, so that multiple replacements can be done on the cell
  on the first assignment to it.

- The `DONE_CW` macro is invoked to mark the current cell as populated
  when every element in it has been assigned.

- A loop over the cell contents is introduced to copy the elements in
  `𝕩`. This should be fine as it is an easily predictable jump,
  but a performance regression is possible and a separate code path
  could be introduced in the future.

The change introduces more extensive checking on the shape of `𝔽`'s
result, as for high rank `𝕩` the requirement should be `(≢𝔽a⊏𝕩)≡(≢a)∾1↓≢𝕩`.

The old behaviour of `select_replace` is recovered by passing `xl=xia`
and `xcia=1` in the implementation of `pick_ucw`.
This commit is contained in:
Andrea Piseri 2024-05-18 12:12:17 +02:00
parent 4f898f38d2
commit 7f28308e44
2 changed files with 55 additions and 35 deletions

View File

@ -372,24 +372,26 @@ B select_c2(B t, B w, B x) {
extern INIT_GLOBAL u8 reuseElType[t_COUNT]; extern INIT_GLOBAL u8 reuseElType[t_COUNT];
B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊) x, assumes w is a typed (elNum) list of valid indices, only el_f64 if strictly necessary B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia) { // rep⌾(w⊏⥊) x, assumes w is a typed (elNum) list of valid indices, only el_f64 if strictly necessary
#if CHECK_VALID #if CHECK_VALID
TALLOC(bool, set, xia); TALLOC(bool, set, xl);
bool sparse = wia < xia/64; bool sparse = wia < xl/64;
if (!sparse) for (i64 i = 0; i < xia; i++) set[i] = false; if (!sparse) for (i64 i = 0; i < xl; i++) set[i] = false;
#define SPARSE_INIT(WI) \ #define SPARSE_INIT(WI) \
if (sparse) for (usz i = 0; i < wia; i++) { \ if (sparse) for (usz i = 0; i < wia; i++) { \
i64 cw = WI; if (RARE(cw<0)) cw+= (i64)xia; set[cw] = false; \ i64 cw = WI; if (RARE(cw<0)) cw+= (i64)xl; set[cw] = false; \
} }
#define EQ(F) if (set[cw] && (F)) thrF("𝔽⌾(a⊸%c): Incompatible result elements", chr); set[cw] = true; #define EQ(F) if (set[cw] && (F)) thrF("𝔽⌾(a⊸%c): Incompatible result elements", chr);
#define DONE_CW set[cw] = true;
#define FREE_CHECK TFREE(set) #define FREE_CHECK TFREE(set)
#else #else
#define SPARSE_INIT(GET) #define SPARSE_INIT(GET)
#define EQ(F) #define EQ(F)
#define DONE_CW
#define FREE_CHECK #define FREE_CHECK
#endif #endif
#define READ_W(N,I) i64 N = (i64)wp[I]; if (RARE(N<0)) N+= (i64)xia #define READ_W(N,I) i64 N = (i64)wp[I]; if (RARE(N<0)) N+= (i64)xl
u8 we = TI(w,elType); assert(elNum(we)); u8 we = TI(w,elType); assert(elNum(we));
u8 xe = TI(x,elType); u8 xe = TI(x,elType);
u8 re = el_or(xe, TI(rep,elType)); u8 re = el_or(xe, TI(rep,elType));
@ -399,16 +401,19 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊
f64* wp = f64any_ptr(w); f64* wp = f64any_ptr(w);
SPARSE_INIT((i64)wp[i]) SPARSE_INIT((i64)wp[i])
MAKE_MUT(r, xia); MAKE_MUT(r, xl * xcia);
mut_init_copy(r, x, re); mut_init_copy(r, x, re);
NOGC_E; NOGC_E;
MUTG_INIT(r); SGet(rep) MUTG_INIT(r); SGet(rep)
for (usz i = 0; i < wia; i++) { for (usz i = 0; i < wia; i++) {
READ_W(cw, i); READ_W(cw, i);
B cn = Get(rep, i); for (usz j = 0; j < xcia; j++) {
EQ(!equal(mut_getU(r, cw), cn)); B cn = Get(rep, i * xcia + j);
mut_rm(r, cw); EQ(!equal(mut_getU(r, cw * xcia + j), cn));
mut_setG(r, cw, cn); mut_rm(r, cw * xcia + j);
mut_setG(r, cw * xcia + j, cn);
}
DONE_CW;
} }
ra = mut_fp(r); ra = mut_fp(r);
goto dec_ret_ra; goto dec_ret_ra;
@ -419,7 +424,7 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊
i32* wp = i32any_ptr(w); i32* wp = i32any_ptr(w);
SPARSE_INIT(wp[i]) SPARSE_INIT(wp[i])
bool reuse = reusable(x) && re==reuseElType[TY(x)]; bool reuse = reusable(x) && re==reuseElType[TY(x)];
SLOWIF(!reuse && xia>100 && wia<xia/50) SLOW2("⌾(𝕨⊸⊏)𝕩 or ⌾(𝕨⊸⊑)𝕩 because not reusable", w, x); SLOWIF(!reuse && xl>100 && wia<xl/50) SLOW2("⌾(𝕨⊸⊏)𝕩 or ⌾(𝕨⊸⊑)𝕩 because not reusable", w, x);
switch (re) { default: UD; switch (re) { default: UD;
case el_i8: rep = toI8Any(rep); ra = reuse? a(REUSE(x)) : cpyI8Arr(x); goto do_u8; case el_i8: rep = toI8Any(rep); ra = reuse? a(REUSE(x)) : cpyI8Arr(x); goto do_u8;
case el_c8: rep = toC8Any(rep); ra = reuse? a(REUSE(x)) : cpyC8Arr(x); goto do_u8; case el_c8: rep = toC8Any(rep); ra = reuse? a(REUSE(x)) : cpyC8Arr(x); goto do_u8;
@ -434,9 +439,12 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊
u64* rp = (void*)((TyArr*)ra)->a; u64* rp = (void*)((TyArr*)ra)->a;
for (usz i = 0; i < wia; i++) { for (usz i = 0; i < wia; i++) {
READ_W(cw, i); READ_W(cw, i);
bool cn = bitp_get(np, i); for (usz j = 0; j < xcia; j++) {
EQ(cn != bitp_get(rp, cw)); bool cn = bitp_get(np, i * xcia + j);
bitp_set(rp, cw, cn); EQ(cn != bitp_get(rp, cw * xcia + j));
bitp_set(rp, cw * xcia + j, cn);
}
DONE_CW;
} }
goto dec_ret_ra; goto dec_ret_ra;
} }
@ -446,24 +454,30 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊
SGet(rep) SGet(rep)
for (usz i = 0; i < wia; i++) { for (usz i = 0; i < wia; i++) {
READ_W(cw, i); READ_W(cw, i);
B cn = Get(rep, i); for (usz j = 0; j < xcia; j++) {
EQ(!equal(cn,rp[cw])); B cn = Get(rep, i * xcia + j);
dec(rp[cw]); EQ(!equal(cn,rp[cw * xcia + j]));
rp[cw] = cn; dec(rp[cw * xcia + j]);
rp[cw * xcia + j] = cn;
}
DONE_CW;
} }
goto dec_ret_ra; goto dec_ret_ra;
} }
} }
#define IMPL(T) do { \ #define IMPL(T) do { \
T* rp = (void*)((TyArr*)ra)->a; \ T* rp = (void*)((TyArr*)ra)->a; \
T* np = tyany_ptr(rep); \ T* np = tyany_ptr(rep); \
for (usz i = 0; i < wia; i++) { \ for (usz i = 0; i < wia; i++) { \
READ_W(cw, i); \ READ_W(cw, i); \
T cn = np[i]; \ for (usz j = 0; j < xcia; j++) { \
EQ(cn != rp[cw]); \ T cn = np[i * xcia + j]; \
rp[cw] = cn; \ EQ(cn != rp[cw * xcia + j]); \
} \ rp[cw * xcia + j] = cn; \
} \
DONE_CW; \
} \
goto dec_ret_ra; \ goto dec_ret_ra; \
} while(0) } while(0)
@ -482,11 +496,12 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊
#undef SPARSE_INIT #undef SPARSE_INIT
#undef EQ #undef EQ
#undef DONE_CW
#undef FREE_CHECK #undef FREE_CHECK
} }
B select_ucw(B t, B o, B w, B x) { B select_ucw(B t, B o, B w, B x) {
if (isAtm(x) || RNK(x)!=1 || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); } if (isAtm(x) || isAtm(w)) { def: return def_fn_ucw(t, o, w, x); }
usz xia = IA(x); usz xia = IA(x);
usz wia = IA(w); usz wia = IA(w);
u8 we = TI(w,elType); u8 we = TI(w,elType);
@ -502,6 +517,11 @@ B select_ucw(B t, B o, B w, B x) {
} else { } else {
rep = c1(o, C2(select, incG(w), incG(x))); rep = c1(o, C2(select, incG(w), incG(x)));
} }
if (isAtm(rep) || !eqShape(w, rep)) thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (expected %H, got %H)", w, rep); usz xr = RNK(x);
return select_replace(U'', w, x, rep, wia, xia); usz wr = RNK(w);
usz rr = RNK(rep);
bool ok = !isAtm(rep) && xr+wr==rr+1 && eqShPart(SH(w),SH(rep),wr) && eqShPart(SH(x)+1,SH(rep)+wr,xr-1);
if (!ok) thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (%H ≡ shape of a, %2H = shape of ⊏𝕩, %H ≡ shape of result of 𝔽)", w, xr-1, SH(x)+1, rep);
usz ia = shProd(SH(x), 1, RNK(x));
return select_replace(U'', w, x, rep, wia, SH(x)[0], ia);
} }

View File

@ -1302,7 +1302,7 @@ B pick_uc1(B t, B o, B x) {
B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia); B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xl, usz xcia);
B select_ucw(B t, B o, B w, B x); B select_ucw(B t, B o, B w, B x);
B select_c2(B,B,B); B select_c2(B,B,B);
B pick_ucw(B t, B o, B w, B x) { B pick_ucw(B t, B o, B w, B x) {
@ -1329,7 +1329,7 @@ B pick_ucw(B t, B o, B w, B x) {
w = num_squeeze(mut_fcd(r, w)); w = num_squeeze(mut_fcd(r, w));
B rep = isArr(o)? incG(o) : c1(o, C2(select, incG(w), C1(shape, incG(x)))); B rep = isArr(o)? incG(o) : c1(o, C2(select, incG(w), C1(shape, incG(x))));
if (isAtm(rep) || !eqShape(w, rep)) thrF("𝔽⌾(a⊸⊑)𝕩: 𝔽 must return an array with the same shape as its input (expected %H, got %H)", w, rep); if (isAtm(rep) || !eqShape(w, rep)) thrF("𝔽⌾(a⊸⊑)𝕩: 𝔽 must return an array with the same shape as its input (expected %H, got %H)", w, rep);
return select_replace(U'', w, x, rep, wia, xia); return select_replace(U'', w, x, rep, wia, xia, 1);
} }
decG(w); decG(w);
} }