native F⌾(list⊸⊑) & ⌾(listOfLists⊸⊑)

This commit is contained in:
dzaima 2023-05-18 01:30:16 +03:00
parent d310669ae8
commit 0572fcc5b2
2 changed files with 97 additions and 33 deletions

View File

@ -362,17 +362,7 @@ B select_c2(B t, B w, B x) {
B select_ucw(B t, B o, B w, B x) {
if (isAtm(x) || RNK(x)!=1 || isAtm(w)) return def_fn_ucw(t, o, w, x);
usz xia = IA(x);
usz wia = IA(w);
SGetU(w)
if (TI(w,elType)!=el_i32) for (usz i = 0; i < wia; i++) if (!q_i64(GetU(w,i))) return def_fn_ucw(t, o, w, x);
B arg = select_c2(t, incG(w), incG(x));
B rep = c1(o, arg);
if (isAtm(rep) || !eqShape(w, rep)) thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (expected %H, got %H)", w, rep);
B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia) { // rep⌾(w⊏⥊) x
#if CHECK_VALID
TALLOC(bool, set, xia);
bool sparse = wia < xia/64;
@ -381,9 +371,9 @@ B select_ucw(B t, B o, B w, B x) {
if (sparse) for (usz i = 0; i < wia; i++) { \
i64 cw = WI; if (RARE(cw<0)) cw+= (i64)xia; set[cw] = false; \
}
#define EQ(F) if (set[cw] && (F)) thrM("𝔽⌾(a⊸⊏): Incompatible result elements"); set[cw] = true;
#define EQ(F) if (set[cw] && (F)) thrF("𝔽⌾(a⊸%c): Incompatible result elements", chr); set[cw] = true;
#define FREE_CHECK TFREE(set)
SLOWIF(xia>100 && wia<xia/20) SLOW2("⌾(𝕨⊸⊏)𝕩 because CHECK_VALID", w, x);
SLOWIF(xia>100 && wia<xia/20) SLOW2("⌾(𝕨⊸⊏)𝕩 or ⌾(𝕨⊸⊑)𝕩 because CHECK_VALID", w, x);
#else
#define SPARSE_INIT(GET)
#define EQ(F)
@ -393,7 +383,7 @@ B select_ucw(B t, B o, B w, B x) {
u8 we = TI(w,elType);
u8 xe = TI(x,elType);
u8 re = TI(rep,elType);
SLOWIF(!reusable(x) && xia>100 && wia<xia/50) SLOW2("⌾(𝕨⊸⊏)𝕩 because not reusable", w, x);
SLOWIF(!reusable(x) && xia>100 && wia<xia/50) SLOW2("⌾(𝕨⊸⊏)𝕩 or ⌾(𝕨⊸⊑)𝕩 because not reusable", w, x);
if (elInt(we)) {
w = toI32Any(w);
i32* wp = i32any_ptr(w);
@ -490,7 +480,7 @@ B select_ucw(B t, B o, B w, B x) {
MAKE_MUT_INIT(r, xia, el_or(xe, re));
MUTG_INIT(r);
mut_copyG(r, 0, x, 0, xia);
SGet(rep)
SGet(rep) SGetU(w)
SPARSE_INIT(o2i64G(GetU(w, i)))
for (usz i = 0; i < wia; i++) {
i64 cw = o2i64G(GetU(w, i)); if (RARE(cw<0)) cw+= (i64)xia;
@ -505,3 +495,14 @@ B select_ucw(B t, B o, B w, B x) {
#undef EQ
#undef FREE_CHECK
}
B select_ucw(B t, B o, B w, B x) {
if (isAtm(x) || RNK(x)!=1 || isAtm(w)) return def_fn_ucw(t, o, w, x);
usz xia = IA(x);
usz wia = IA(w);
SGetU(w)
if (!elInt(TI(w,elType))) for (usz i = 0; i < wia; i++) if (!q_i64(GetU(w,i))) return def_fn_ucw(t, o, w, x);
B rep = c1(o, C2(select, incG(w), incG(x)));
if (isAtm(rep) || !eqShape(w, rep)) thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (expected %H, got %H)", w, rep);
return select_replace(U'', w, x, rep, wia, xia);
}

View File

@ -339,33 +339,44 @@ static NOINLINE void checkIndexList(B w, ur xr) {
thrF("⊑: Leaf array in 𝕨 too large (has shape %H)", w);
}
}
// calculate index
#define PICK_IDX(RES, GET, IA, OOB) \
usz RES = 0; \
for (usz i=0, ia_=(IA); i < ia_; i++) { \
c = c*xsh[i] + WRAP(GET, xsh[i], OOB); \
}
static i64 pick_convFloat(f64 f) {
if (LIKELY(q_fi64(f))) return (i64)f;
thrM("⊑: 𝕨 contained a non-integer");
}
static B recPick(B w, B x) { // doesn't consume
assert(isArr(w) && isArr(x));
usz ia = IA(w);
ur xr = RNK(x);
usz* xsh = SH(x);
switch(TI(w,elType)) { default: UD;
case el_i8: { i8* wp = i8any_ptr (w); if(RNK(w)!=1)goto wrr; if (ia!=xr)goto wrl; usz c=0; for (usz i = 0; i < ia; i++) { c = c*xsh[i] + WRAP(wp[i], xsh[i], goto oob); }; return IGet(x,c); }
case el_i16: { i16* wp = i16any_ptr(w); if(RNK(w)!=1)goto wrr; if (ia!=xr)goto wrl; usz c=0; for (usz i = 0; i < ia; i++) { c = c*xsh[i] + WRAP(wp[i], xsh[i], goto oob); }; return IGet(x,c); }
case el_i32: { i32* wp = i32any_ptr(w); if(RNK(w)!=1)goto wrr; if (ia!=xr)goto wrl; usz c=0; for (usz i = 0; i < ia; i++) { c = c*xsh[i] + WRAP(wp[i], xsh[i], goto oob); }; return IGet(x,c); }
case el_f64: { f64* wp = f64any_ptr(w); if(RNK(w)!=1)goto wrr; if (ia!=xr)goto wrl; usz c=0; for (usz i = 0; i < ia; i++) { i64 ws = (i64)wp[i]; if (wp[i]!=ws) thrM(ws==I64_MIN? "⊑: 𝕨 contained value too large" : "⊑: 𝕨 contained a non-integer");
c = c*xsh[i] + WRAP(ws, xsh[i], goto oob); }; return IGet(x,c); }
case el_i8: { i8* wp = i8any_ptr (w); if(RNK(w)!=1)goto wrr; if (ia!=xr)goto wrl; PICK_IDX(c, wp[i], ia, goto oob) return IGet(x,c); }
case el_i16: { i16* wp = i16any_ptr(w); if(RNK(w)!=1)goto wrr; if (ia!=xr)goto wrl; PICK_IDX(c, wp[i], ia, goto oob) return IGet(x,c); }
case el_i32: { i32* wp = i32any_ptr(w); if(RNK(w)!=1)goto wrr; if (ia!=xr)goto wrl; PICK_IDX(c, wp[i], ia, goto oob) return IGet(x,c); }
case el_f64: { f64* wp = f64any_ptr(w); if(RNK(w)!=1)goto wrr; if (ia!=xr)goto wrl; PICK_IDX(c, pick_convFloat(wp[i]), ia, goto oob) return IGet(x,c); }
case el_c8: case el_c16: case el_c32: case el_bit:
case el_B: {
if (ia==0) {
if (xr!=0) thrM("⊑: Empty array in 𝕨 must correspond to unit in 𝕩");
if (xr!=0) thrM("⊑: 𝕩 must be a unit if 𝕨 contains an empty array");
return IGet(x,0);
}
SGetU(w)
if (isNum(GetU(w,0))) {
if(RNK(w)!=1) goto wrr;
if (ia!=xr) goto wrl;
usz c=0;
for (usz i = 0; i < ia; i++) {
PICK_IDX(c, ({
B cw = GetU(w,i);
if (!isNum(cw)) thrM("⊑: 𝕨 contained list with mixed-type elements");
c = c*xsh[i] + WRAP(o2i64(cw), xsh[i], goto oob);
}
o2i64(cw);
}), ia, goto oob);
return IGet(x,c);
} else {
M_HARR(r, ia);
@ -1216,7 +1227,31 @@ B reverse_c2(B t, B w, B x) {
return withFill(mut_fcd(r, x), xf);
}
static B replaceOne(B fn, usz pos, B x, usz xia) {
static usz pick_oneIndex(B w, usz xr, usz* xsh) { // throws if guaranteed bad; returns USZ_MAX if not a plain index
assert(xr!=0);
if (RARE(isAtm(w) || IA(w)!=xr || RNK(w)!=1)) return USZ_MAX;
switch(TI(w,elType)) { default: UD;
case el_i8: { i8* wp = i8any_ptr (w); PICK_IDX(c, wp[i], xr, goto oob) return c; }
case el_i16: { i16* wp = i16any_ptr(w); PICK_IDX(c, wp[i], xr, goto oob) return c; }
case el_i32: { i32* wp = i32any_ptr(w); PICK_IDX(c, wp[i], xr, goto oob) return c; }
case el_f64: { f64* wp = f64any_ptr(w); PICK_IDX(c, pick_convFloat(wp[i]), xr, goto oob) return c; }
case el_c8: case el_c16: case el_c32: case el_bit:
case el_B: {
SGetU(w)
if (!isNum(GetU(w,0))) return USZ_MAX;
PICK_IDX(c, ({
B cw = GetU(w,i);
if (!isNum(cw)) thrM("⊑: 𝕨 contained list with mixed-type elements");
o2i64(cw);
}), xr, goto oob);
return c;
}
}
oob: checkIndexList(w, xr); thrF("⊑: Indexing out-of-bounds (index %B in array of shape %2H)", w, xr, xsh);
}
static B pick_replaceOne(B fn, usz pos, B x, usz xia) {
if (TI(x,elType)==el_B) {
B* xp;
if (TY(x)==t_harr || TY(x)==t_hslice) {
@ -1260,16 +1295,46 @@ static B replaceOne(B fn, usz pos, B x, usz xia) {
return qWithFill(mut_fcd(r, x), xf);
}
B pick_uc1(B t, B o, B x) { // TODO do in-place like pick_ucw; maybe just call it?
B pick_uc1(B t, B o, B x) {
if (isAtm(x) || IA(x)==0) return def_fn_uc1(t, o, x);
return replaceOne(o, 0, x, IA(x));
return pick_replaceOne(o, 0, x, IA(x));
}
B select_replace(u32 chr, B w, B x, B rep, usz wia, usz xia);
B select_ucw(B t, B o, B w, B x);
B select_c2(B,B,B);
B pick_ucw(B t, B o, B w, B x) {
if (isArr(w) || isAtm(x) || RNK(x)!=1) return def_fn_ucw(t, o, w, x);
if (RARE(isAtm(x))) def: return def_fn_ucw(t, o, w, x);
usz xia = IA(x);
usz wi = WRAP(o2i64(w), xia, thrF("𝔽⌾(n⊸⊑)𝕩: reading out-of-bounds (n≡%R, %s≡≠𝕩)", w, xia));
return replaceOne(o, wi, x, xia);
usz pos;
if (isAtm(w)) {
if (RARE(RNK(x)!=1)) goto def;
pos = WRAP(o2i64(w), xia, thrF("𝔽⌾(n⊸⊑)𝕩: reading out-of-bounds (n≡%R, %s≡≠𝕩)", w, xia));
} else {
usz wia = IA(w);
usz* xsh = SH(x);
ur xr = RNK(x);
if (xr==0) goto def;
pos = pick_oneIndex(w, xr, xsh);
if (pos == USZ_MAX) {
MAKE_MUT_INIT(r, wia, xia<I8_MAX? el_i8 : xia<I16_MAX? el_i16 : xia<I32_MAX? el_i32 : el_f64); MUTG_INIT(r);
SGetU(w)
for (usz i = 0; i < wia; i++) {
usz c = pick_oneIndex(GetU(w,i), xr, xsh);
if (RARE(c==USZ_MAX)) { mut_pfree(r, i); return def_fn_ucw(t, o, w, x); }
mut_setG(r, i, m_usz(c));
}
decG(w);
w = mut_fv(r);
B rep = c1(o, C2(select, incG(w), C1(shape, incG(x))));
if (isAtm(rep) || !eqShape(w, rep)) thrF("𝔽⌾(a⊸⊑)𝕩: 𝔽 must return an array with the same shape as its input (expected %H, got %H)", w, rep);
return select_replace(U'', w, x, rep, wia, xia);
}
decG(w);
}
return pick_replaceOne(o, pos, x, xia);
}
static B takedrop_ucw(i64 wi, B o, u64 am, B x, ux xr) {
@ -1338,8 +1403,6 @@ B shape_uc1(B t, B o, B x) {
return truncReshape(shape_uc1_t(c1(o, shape_c1(t, x)), xia), xia, xia, xr, sh);
}
B select_ucw(B t, B o, B w, B x);
B reverse_uc1(B t, B o, B x) { return reverse_c1(m_f64(0), c1(o, reverse_c1(t, x))); }
B reverse_ix(B t, B w, B x) {