/* ⊏ (Select) primitive: monadic select_c1, dyadic select_c2, and select_ucw,
   which handles 𝔽⌾(a⊸⊏)𝕩.

   NOTE(review): this chunk appears to have been damaged before review. Text
   between a '<' directly followed by a letter and the next '>' seems to have
   been stripped, e.g. "for (usz i=i0; i> 63;" in select_c2 and
   "SLOWIF(xia>100 && wia100 && wiare?xe:re;" in select_ucw. Those spans are
   marked below and reproduced verbatim; they do not compile as written and
   should be restored from upstream rather than patched by hand. */
#include "../core.h"
#include "../utils/talloc.h"
#include "../utils/mut.h"
#include "../builtins.h"

// #if SINGELI
// #pragma GCC diagnostic push
// #pragma GCC diagnostic ignored "-Wunused-variable"
// #include "../singeli/gen/select.c"
// #pragma GCC diagnostic pop
// #endif

// Fallback called by select_c2 (at its `base` label) for inputs that the
// native paths below do not handle.
extern B rt_select;

// Monadic ⊏𝕩: returns the first major cell of 𝕩, i.e. the slice at offset 0
// whose length is the product of 𝕩's trailing axes, shaped by those axes.
// Errors if 𝕩 is an atom, has rank 0, or has a leading axis of length 0.
// Releases the caller's reference to x (decG) before returning.
B select_c1(B t, B x) {
  if (isAtm(x)) thrM("⊏: Argument cannot be an atom");
  ur xr = RNK(x);
  if (xr==0) thrM("⊏: Argument cannot be rank 0");
  if (SH(x)[0]==0) thrF("⊏: Argument shape cannot start with 0 (%H ≡ ≢𝕩)", x);
  usz ia = shProd(SH(x), 1, xr); // element count of one major cell
  Arr* r = TI(x,slice)(incG(x), 0, ia); // the slice gets its own reference to x (incG)
  usz* sh = arr_shAlloc(r, xr-1);
  // NOTE(review): arr_shAlloc presumably returns NULL when it needs no
  // separate shape buffer (low result rank) — confirm against its definition.
  if (sh) shcpy(sh, SH(x)+1, xr-1);
  decG(x);
  return taga(r);
}

// Dyadic 𝕨⊏𝕩: selects major cells of 𝕩 at the indices given by 𝕨.
// Typed fast paths jump to `dec_ret`, which releases w and x (decG) and
// returns r. Unhandled cases jump to `base`, which releases the fill xf and
// delegates to rt_select.
B select_c2(B t, B w, B x) {
  if (isAtm(x)) thrM("⊏: 𝕩 cannot be an atom");
  ur xr = RNK(x);
  if (isAtm(w)) {
    // Atom 𝕨: the result is the single major cell of 𝕩 at index 𝕨.
    if (xr==0) thrM("⊏: 𝕩 cannot be a unit");
    usz csz = arr_csz(x); // presumably the element count of one major cell
    usz cam = SH(x)[0];   // number of major cells (length of the leading axis)
    // NOTE(review): WRAP presumably maps a negative index to one counted from
    // the end and evaluates its third argument when the index is out of
    // range — confirm against its definition.
    usz wi = WRAP(o2i64(w), cam, thrF("⊏: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, cam));
    Arr* r = TI(x,slice)(incG(x), wi*csz, csz);
    usz* sh = arr_shAlloc(r, xr-1);
    if (sh) shcpy(sh, SH(x)+1, xr-1);
    decG(x);
    return taga(r);
  }
  B xf = getFillQ(x); // fill of 𝕩: picks the empty-result type below and is attached to the general result
  SGet(x)
  usz wia = IA(w);
  B r;
  if (xr==1) {
    if (wia==0) {
      // Empty 𝕨: the result has 𝕨's shape and no elements. Its array type is
      // chosen from 𝕩's fill (number, character, or anything else).
      decG(x);
      if (RNK(w)==1) {
        if (isNum(xf)) { r = emptyIVec(); goto ret; }
        if (isC32(xf)) { r = emptyCVec(); goto ret; }
      }
      Arr* ra;
      if (isNum(xf) || isC32(xf)) {
        ra = m_arr(sizeof(TyArr), isNum(xf)? t_i8arr : t_c8arr, 0);
      } else {
        ra = m_fillarrp(0);
        fillarr_setFill(ra, xf);
      }
      arr_shCopy(ra, w);
      r = taga(ra);
      ret:
      decG(w);
      return r;
    }
    usz xia = IA(x);
    if (xia==0) goto base; // can't just error immediately because depth 2 𝕨
    u8 xe = TI(x,elType);
    u8 we = TI(w,elType);
    // Disabled fast paths that would call the avx2_select_* kernels (kept commented out).
    #if SINGELI
    // if (we==el_i8 && xe==el_i32) { i32* rp; r = m_i32arrc(&rp, w); if (!avx2_select_i8_32 ((u8*)i8any_ptr (w), (u8*)i32any_ptr(x), (u8*)rp, wia, xia)) thrM("⊏: Indexing out-of-bounds"); goto dec_ret; }
    // if (we==el_i16 && xe==el_i32) { i32* rp; r = m_i32arrc(&rp, w); if (!avx2_select_i16_32((u8*)i16any_ptr(w), (u8*)i32any_ptr(x), (u8*)rp, wia, xia)) thrM("⊏: Indexing out-of-bounds"); goto dec_ret; }
    // if (we==el_i32 && xe==el_i8 ) { i8* rp; r = m_i8arrc (&rp, w); if (!avx2_select_i32_8 ((u8*)i32any_ptr(w), (u8*)i8any_ptr (x), (u8*)rp, wia, xia)) thrM("⊏: Indexing out-of-bounds"); goto dec_ret; }
    // if (we==el_i32 && xe==el_i32) { i32* rp; r = m_i32arrc(&rp, w); if (!avx2_select_i32_32((u8*)i32any_ptr(w), (u8*)i32any_ptr(x), (u8*)rp, wia, xia)) thrM("⊏: Indexing out-of-bounds"); goto dec_ret; }
    // if (we==el_i32 && xe==el_f64) { f64* rp; r = m_f64arrc(&rp, w); if (!avx2_select_i32_64((u8*)i32any_ptr(w), (u8*)f64any_ptr(x), (u8*)rp, wia, xia)) thrM("⊏: Indexing out-of-bounds"); goto dec_ret; }
    #endif
    /* NOTE(review): damaged span. From the "#define CASE" line below through
       the "we!=el_bit)" condition, text has been lost wherever a '<' was
       directly followed by a letter. This removed most of the CASE/CASEW macro
       bodies, several loop conditions (e.g. "for (usz i=i0; i> 63;") and the
       start of the condition guarding the i8 re-selection of a bit-array 𝕩.
       The span is reproduced verbatim, including its line continuations, and
       will not compile as written. */
    #define CASE(S,E) case S: for (usz i=i0; i> 63; \
        if (i%64 == 0) { rp[i/64]=b; if (!i) break; } \
      } \
      goto dec_ret; \
    } \
    if (xe!=el_B) { \
      usz xw = elWidth(xe); \
      void* rp = m_tyarrc(&r, xw, w, el2t(xe)); \
      void* xp = tyany_ptr(x); \
      if (sizeof(W) >= 4) { \
        switch(xw) { default:UD; CASEW(1,u8); CASEW(2,u16); CASEW(4,u32); CASEW(8,f64); } \
      } else { \
        W* wt = NULL; \
        for (usz bl=(1<<14)/sizeof(W), i0=0, i1=0; i0wia) i1=wia; \
          W min=wp[i0], max=min; for (usz i=i0+1; imax) max=e; if (e=(i64)xia) thrF("⊏: Indexing out-of-bounds (%i∊𝕨, %s≡≠𝕩)", max, xia); \
          W* ip=wp; usz off=xia; \
          if (max>=0) { off=0; if (RARE(min<0)) { \
            if (!wt) {wt=TALLOCP(W,i1-i0);} ip=wt-i0; \
            for (usz i=i0; i=256 && wia/4>=xia && we!=el_bit) { return taga(cpyBitArr(select_c2(m_f64(0), w, taga(cpyI8Arr(x))))); }
    if (we==el_bit) {
      SGetU(x)
      B x0 = GetU(x, 0);
      B x1;
      /* NOTE(review): damaged span. Text is missing within the next line,
         between "for (i=0; i" and "UR_MAX)". The lost part appears to finish
         the bit-index path and begin the general path: wr and rr are used
         below without a visible declaration. */
      if (xia<2) { u64* wp=bitarr_ptr(w); usz i; for (i=0; iUR_MAX) thrF("⊏: Result rank too large (%i≡=𝕨, %i≡=𝕩)", wr, xr);
  // General path: copy whole major cells of 𝕩 into a mutable result. The
  // result shape is ≢𝕨 followed by the trailing axes of 𝕩.
  usz csz = arr_csz(x);
  usz cam = SH(x)[0];
  MAKE_MUT(r, wia*csz);
  mut_init(r, TI(x,elType));
  MUTG_INIT(r);
  for (usz i = 0; i < wia; i++) {
    B cw = GetU(w, i); // assumed number from previous squeeze
    // On a bad index, mut_pfree is called on the i*csz elements written so
    // far (presumably freeing the partial result) before the error is raised.
    usz c = WRAP(o2i64(cw), cam, {
      mut_pfree(r, i*csz);
      thrF("⊏: Indexing out-of-bounds (%R∊𝕨, %s≡≠𝕩)", cw, cam);
    });
    mut_copyG(r, i*csz, x, csz*c, csz);
  }
  Arr* ra = mut_fp(r);
  usz* rsh = arr_shAlloc(ra, rr);
  if (rsh) {
    shcpy(rsh , SH(w) , wr );
    shcpy(rsh+wr, SH(x)+1, xr-1);
  }
  decG(w); decG(x);
  return withFill(taga(ra),xf);
  }

  // Unhandled cases: release the fill and delegate to the fallback implementation.
  base:;
  dec(xf);
  return c2(rt_select, w, x);

  // Exit for typed fast paths that have stored their result in r.
  dec_ret:;
  decG(w); decG(x);
  return r;
}

// Handles 𝔽⌾(𝕨⊸⊏)𝕩, with o as 𝔽. It applies o to 𝕨⊏𝕩 and then writes each
// element of that result back into 𝕩, at the position given by the
// corresponding element of 𝕨. Only rank-1 𝕩 with an array 𝕨 of integer
// indices is handled here; everything else goes to def_fn_ucw.
B select_ucw(B t, B o, B w, B x) {
  if (isAtm(x) || RNK(x)!=1 || isAtm(w)) return def_fn_ucw(t, o, w, x);
  usz xia = IA(x);
  usz wia = IA(w);
  SGetU(w)
  // Unless 𝕨 is an i32 array, every element must pass q_i64 (presumably
  // "is an integer representable as i64"); otherwise use the default handler.
  if (TI(w,elType)!=el_i32) for (usz i = 0; i < wia; i++) if (!q_i64(GetU(w,i))) return def_fn_ucw(t, o, w, x);
  B arg = select_c2(t, inc(w), inc(x)); // 𝕨⊏𝕩; select_c2 raises on out-of-bounds indices
  B rep = c1(o, arg);                   // 𝔽 applied to the selection
  if (isAtm(rep) || !eqShape(w, rep)) thrF("𝔽⌾(a⊸⊏)𝕩: Result of 𝔽 must have the same shape as 'a' (expected %H, got %H)", w, rep);
  #if CHECK_VALID
    // set[k] marks index k as already written. EQ raises an error when a
    // repeated index in 𝕨 would receive a value different from the one
    // already stored.
    TALLOC(bool, set, xia);
    for (i64 i = 0; i < xia; i++) set[i] = false;
    #define EQ(F) if (set[cw] && (F)) thrM("𝔽⌾(a⊸⊏): Incompatible result elements"); set[cw] = true;
    #define FREE_CHECK TFREE(set)
    /* NOTE(review): damaged span. Text is missing within the next line
       (after each "wia"). The lost part appears to close the CHECK_VALID
       section and to declare wp, xe, re and me, which are used below without
       a visible declaration; me looks like the larger of xe and re. */
    SLOWIF(xia>100 && wia100 && wiare?xe:re;
      bool reuse = reusable(x);
      // Typed update, on 𝕩 itself when reusable, otherwise on a copy.
      if (me==el_i32) {
        I32Arr* xn = reuse? toI32Arr(REUSE(x)) : cpyI32Arr(x);
        i32* xp = i32arrv_ptr(xn);
        rep = toI32Any(rep);
        i32* rp = i32any_ptr(rep);
        for (usz i = 0; i < wia; i++) {
          i64 cw = wp[i];
          if (RARE(cw<0)) cw+= (i64)xia; // negative index counts from the end; we're free to assume w is valid
          i32 cr = rp[i];
          EQ(cr != xp[cw]);
          xp[cw] = cr;
        }
        decG(w); decG(rep); FREE_CHECK;
        return taga(xn);
      } else if (me==el_i8) {
        I8Arr* xn = reuse? toI8Arr(REUSE(x)) : cpyI8Arr(x);
        i8* xp = i8arrv_ptr(xn);
        rep = toI8Any(rep);
        i8* rp = i8any_ptr(rep);
        for (usz i = 0; i < wia; i++) {
          i64 cw = wp[i];
          if (RARE(cw<0)) cw+= (i64)xia;
          i8 cr = rp[i];
          EQ(cr != xp[cw]);
          xp[cw] = cr;
        }
        decG(w); decG(rep); FREE_CHECK;
        return taga(xn);
      } else if (me==el_i16) {
        I16Arr* xn = reuse? toI16Arr(REUSE(x)) : cpyI16Arr(x);
        i16* xp = i16arrv_ptr(xn);
        rep = toI16Any(rep);
        i16* rp = i16any_ptr(rep);
        for (usz i = 0; i < wia; i++) {
          i64 cw = wp[i];
          if (RARE(cw<0)) cw+= (i64)xia;
          i16 cr = rp[i];
          EQ(cr != xp[cw]);
          xp[cw] = cr;
        }
        decG(w); decG(rep); FREE_CHECK;
        return taga(xn);
      } else if (me==el_bit) {
        BitArr* xn = reuse? toBitArr(REUSE(x)) : cpyBitArr(x);
        u64* xp = bitarrv_ptr(xn);
        rep = taga(toBitArr(rep));
        u64* rp = bitarr_ptr(rep);
        for (usz i = 0; i < wia; i++) {
          i64 cw = wp[i];
          if (RARE(cw<0)) cw+= (i64)xia;
          bool cr = bitp_get(rp, i);
          EQ(cr != bitp_get(xp,cw));
          bitp_set(xp,cw,cr);
        }
        decG(w); decG(rep); FREE_CHECK;
        return taga(xn);
      } else if (me==el_f64) {
        F64Arr* xn = reuse? toF64Arr(REUSE(x)) : cpyF64Arr(x);
        f64* xp = f64arrv_ptr(xn);
        rep = toF64Any(rep);
        f64* rp = f64any_ptr(rep);
        for (usz i = 0; i < wia; i++) {
          i64 cw = wp[i];
          if (RARE(cw<0)) cw+= (i64)xia;
          f64 cr = rp[i];
          EQ(cr != xp[cw]);
          xp[cw] = cr;
        }
        decG(w); decG(rep); FREE_CHECK;
        return taga(xn);
      } else UD;
    }
    // Reusable generic (t_harr) 𝕩 with the same element type as rep:
    // overwrite in place, releasing each replaced element.
    if (reusable(x) && xe==re) {
      if (TY(x)==t_harr) {
        B* xp = harr_ptr(REUSE(x));
        SGet(rep)
        for (usz i = 0; i < wia; i++) {
          i64 cw = wp[i];
          if (RARE(cw<0)) cw+= (i64)xia;
          B cr = Get(rep, i);
          EQ(!equal(cr,xp[cw]));
          dec(xp[cw]);
          xp[cw] = cr;
        }
        decG(w); decG(rep); FREE_CHECK;
        return x;
      }
    }
    // Otherwise copy 𝕩 into a mutable array of element type el_or(xe, re)
    // (presumably one that can hold both) and write the replacements there.
    MAKE_MUT(r, xia);
    mut_init(r, el_or(xe, re));
    MUTG_INIT(r);
    mut_copyG(r, 0, x, 0, xia);
    SGet(rep)
    for (usz i = 0; i < wia; i++) {
      i64 cw = wp[i];
      if (RARE(cw<0)) cw+= (i64)xia;
      B cr = Get(rep, i);
      EQ(!equal(mut_getU(r, cw), cr));
      mut_rm(r, cw);
      mut_setG(r, cw, cr);
    }
    decG(w); decG(rep); FREE_CHECK;
    return mut_fcd(r, x);
  }
  // 𝕨 not handled by the typed path above: convert each index with o2i64G
  // and update a mutable copy of 𝕩.
  MAKE_MUT(r, xia);
  mut_init(r, el_or(xe, re));
  MUTG_INIT(r);
  mut_copyG(r, 0, x, 0, xia);
  SGet(rep)
  for (usz i = 0; i < wia; i++) {
    i64 cw = o2i64G(GetU(w, i));
    if (RARE(cw<0)) cw+= (i64)xia;
    B cr = Get(rep, i);
    EQ(!equal(mut_getU(r, cw), cr));
    mut_rm(r, cw);
    mut_setG(r, cw, cr);
  }
  decG(w); decG(rep); FREE_CHECK;
  return mut_fcd(r, x);
  #undef EQ
  #undef FREE_CHECK
}