From 98f6944440c2aada2ec43e80492ddbc5b783520e Mon Sep 17 00:00:00 2001 From: dzaima Date: Fri, 3 Feb 2023 18:55:14 +0200 Subject: [PATCH] =?UTF-8?q?unify=20rank=201=20and=20high-rank=20=E2=8A=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/builtins/select.c | 281 ++++++++++++++++++++++-------------------- src/builtins/slash.c | 2 +- 2 files changed, 146 insertions(+), 137 deletions(-) diff --git a/src/builtins/select.c b/src/builtins/select.c index 5fe20df1..d9902ede 100644 --- a/src/builtins/select.c +++ b/src/builtins/select.c @@ -56,9 +56,10 @@ B select_c2(B t, B w, B x) { if (isAtm(x)) thrM("โŠ: ๐•ฉ cannot be an atom"); ur xr = RNK(x); if (isAtm(w)) { + watom:; if (xr==0) thrM("โŠ: ๐•ฉ cannot be a unit"); - usz cam = SH(x)[0]; - usz wi = WRAP(o2i64(w), cam, thrF("โŠ: Indexing out-of-bounds (๐•จโ‰ก%R, %sโ‰กโ‰ ๐•ฉ)", w, cam)); + usz xn = *SH(x); + usz wi = WRAP(o2i64(w), xn, thrF("โŠ: Indexing out-of-bounds (%RโˆŠ๐•จ, %Hโ‰กโ‰ข๐•ฉ)", w, x)); if (xr==1) { B xf = getFillR(x); B xv = IGet(x, wi); @@ -86,174 +87,182 @@ B select_c2(B t, B w, B x) { } usz wia = IA(w); - B r; + Arr* r; + ur wr = RNK(w); + if (wr==0) { + B w0 = IGetU(w, 0); + if (isAtm(w0)) { + decG(w); + w = inc(w0); + goto watom; + } + } + i32 rr = xr+wr-1; if (wia==0) { - ur wr = RNK(w); + emptyRes: if (0 == *SH(x) && wr==1) { - r = incG(x); - goto dec_ret; + decG(w); + return x; } - ur rr = xr+wr-1; - Arr* ra = emptyArr(x, rr); - if (rr>1) { - ShArr* sh = m_shArr(rr); - shcpy(sh->a, SH(w), wr); - shcpy(sh->a+wr, SH(x)+1, xr-1); - arr_shSetU(ra, rr, sh); - } - r = taga(ra); - goto dec_ret; + r = emptyArr(x, rr); + if (rr<=1) goto dec_ret; + goto setsh; } B xf = getFillQ(x); + usz xn = *SH(x); + if (xn==0) goto base; + usz ria = wia * arr_csz(x); - if (xr==1) { - usz xia = IA(x); - if (xia==0) goto base; // can't just error immediately because depth 2 ๐•จ - u8 xe = TI(x,elType); - u8 we = TI(w,elType); - #if SINGELI_X86_64 - #define CPUSEL(W, NEXT) \ - if (!avx2_select_tab[4*(we-el_i8)+CTZ(xw)](wp, xp, rp, wia, xia)) thrM("โŠ: Indexing out-of-bounds"); - #define BOOL_USE_SIMD (xia<=128) - #define BOOL_SPECIAL(W) \ - if (sizeof(W)==1 && BOOL_USE_SIMD) { \ - if (!avx2_select_bool128(wp, xp, rp, wia, xia)) thrM("โŠ: Indexing out-of-bounds"); \ - goto dec_ret; \ - } - #else - #define CPUSEL(W, NEXT) \ - if (sizeof(W) >= 4) { \ - switch(xw) { default:UD; CASEW(1,u8); CASEW(2,u16); CASEW(4,u32); CASEW(8,f64); } \ - } else { \ - W* wt = NULL; \ - for (usz bl=(1<<14)/sizeof(W), i0=0, i1=0; i0wia) i1=wia; \ - W min=wp[i0], max=min; for (usz i=i0+1; imax) max=e; if (e=(i64)xia) thrF("โŠ: Indexing out-of-bounds (%iโˆŠ๐•จ, %sโ‰กโ‰ ๐•ฉ)", max, xia); \ - W* ip=wp; usz off=xia; \ - if (max>=0) { off=0; if (RARE(min<0)) { \ - if (RARE(xia > (1ULL<<(sizeof(W)*8-1)))) { w=taga(NEXT(w)); mm_free(v(r)); return select_c2(m_f64(0), w, x); } \ - if (!wt) {wt=TALLOCP(W,i1-i0);} ip=wt-i0; \ - for (usz i=i0; i> (n%8)) & 1); \ - if (i%64 == 0) { rp[i/64]=b; if (!i) break; } \ - } \ - goto dec_ret; \ - } \ - if (xe!=el_B) { \ - usz xw = elWidth(xe); \ - void* rp = m_tyarrc(&r, xw, w, el2t(xe)); \ - void* xp = tyany_ptr(x); \ - CPUSEL(W, NEXT) \ - goto dec_ret; \ - } \ - M_HARR(r, wia); \ - if (TY(x)==t_harr || TY(x)==t_hslice) { \ - B* xp = hany_ptr(x); \ - for (usz i=0; i < wia; i++) HARR_ADD(r, i, inc(xp[WRAP(wp[i], xia, thrF("โŠ: Indexing out-of-bounds (%iโˆŠ๐•จ, %sโ‰กโ‰ ๐•ฉ)", wp[i], xia))])); \ - decG(x); return HARR_FCD(r, w); \ - } SLOW2("๐•จโŠ๐•ฉ", w, x); \ - for (usz i=0; i < wia; i++) HARR_ADD(r, i, Get(x, WRAP(wp[i], xia, thrF("โŠ: Indexing out-of-bounds (%iโˆŠ๐•จ, %sโ‰กโ‰ ๐•ฉ)", wp[i], xia)))); \ - decG(x); return withFill(HARR_FCD(r,w),xf); \ - } - if (xe==el_bit && wia>=256 && !BOOL_USE_SIMD && wia/4>=xia && we!=el_bit) { - return taga(cpyBitArr(select_c2(m_f64(0), w, taga(cpyI8Arr(x))))); - } - SGet(x) - if (we==el_bit) { + #define CASEW(S, E) case S: for (usz i=0; i= 4) { \ + switch(xl) { default:UD; CASEW(3,u8); CASEW(4,u16); CASEW(5,u32); CASEW(6,u64); } \ + } else { \ + W* wt = NULL; \ + for (usz bl=(1<<14)/sizeof(W), i0=0, i1=0; i0wia) i1=wia; \ + W min=wp[i0], max=min; for (usz i=i0+1; imax) max=e; if (e=(i64)xn) thrF("โŠ: Indexing out-of-bounds (%iโˆŠ๐•จ, %sโ‰กโ‰ ๐•ฉ)", max, xn); \ + W* ip=wp; usz off=xn; \ + if (max>=0) { off=0; if (RARE(min<0)) { \ + if (RARE(xn > (1ULL<<(sizeof(W)*8-1)))) { w=taga(NEXT(w)); mm_free((Value*)r); return select_c2(m_f64(0), w, x); } \ + if (!wt) {wt=TALLOCP(W,i1-i0);} ip=wt-i0;\ + for (usz i=i0; i=256 && xl<3 && wia/4>=xia && we!=el_bit) { + return taga(cpyBitArr(select_c2(m_f64(0), w, taga(cpyI8Arr(x))))); + } + + + #define TYPE(W, NEXT) { W* wp = W##any_ptr(w); \ + if (xl==0) { u64* xp=bitarr_ptr(x); \ + u64* rp; r = m_bitarrp(&rp, ria); \ + BOOL_SPECIAL(W) \ + u64 b=0; \ + for (usz i = wia; ; ) { \ + i--; \ + usz n = WRAP(wp[i], xn, thrF("โŠ: Indexing out-of-bounds (%iโˆŠ๐•จ, %sโ‰กโ‰ ๐•ฉ)", wp[i], xn)); \ + b = 2*b + ((((u8*)xp)[n/8] >> (n%8)) & 1); \ + if (i%64 == 0) { rp[i/64]=b; if (!i) break; } \ + } \ + goto setsh; \ + } \ + if (xe!=el_B) { \ + if (xl<3 || xl==7) goto generic_l; \ + void* rp = m_tyarrlp(&r, xl-3, ria, arrNewType(TY(x))); \ + void* xp = tyany_ptr(x); \ + CPUSEL(W, NEXT) \ + goto setsh; \ + } \ + if (xl!=6) goto generic_l; \ + M_HARR(ra, wia); B* xp = arr_bptr(x); \ + SLOWIF(xp==NULL) SLOW2("๐•จโŠ๐•ฉ", w, x); \ + if (xp!=NULL) { for (usz i=0; iUR_MAX) thrF("โŠ: Result rank too large (%iโ‰ก=๐•จ, %iโ‰ก=๐•ฉ)", wr, xr); - usz csz = arr_csz(x); - usz cam = SH(x)[0]; - MAKE_MUT(r, wia*csz); mut_init(r, TI(x,elType)); - MUTG_INIT(r); - for (usz i = 0; i < wia; i++) { - B cw = GetU(w, i); // assumed number from previous squeeze - usz c = WRAP(o2i64(cw), cam, { mut_pfree(r, i*csz); thrF("โŠ: Indexing out-of-bounds (%RโˆŠ๐•จ, %sโ‰กโ‰ ๐•ฉ)", cw, cam); }); - mut_copyG(r, i*csz, x, csz*c, csz); - } - Arr* ra = mut_fp(r); - usz* rsh = arr_shAlloc(ra, rr); - if (rsh) { - shcpy(rsh , SH(w) , wr ); - shcpy(rsh+wr, SH(x)+1, xr-1); - } - decG(w); decG(x); - return withFill(taga(ra),xf); } + #undef CASE + #undef CASEW + base:; dec(xf); return c2rt(select, w, x); + generic_l: { + if (xia==0) goto emptyRes; + SLOW2("๐•จโŠ๐•ฉ", w, x); + SGetU(w) + usz csz = arr_csz(x); + MAKE_MUT(rm, ria); mut_init(rm, TI(x,elType)); + MUTG_INIT(rm); + for (usz i = 0; i < wia; i++) { + B cw = GetU(w, i); // assumed number from previous squeeze + usz c = WRAP(o2i64(cw), xn, { mut_pfree(rm, i*csz); thrF("โŠ: Indexing out-of-bounds (%RโˆŠ๐•จ, %Hโ‰กโ‰ข๐•ฉ)", cw, x); }); + mut_copyG(rm, i*csz, x, csz*c, csz); + } + r = a(withFill(mut_fv(rm), xf)); + goto setsh; + } + + + + setsh: + if (rr>1) { + if (rr > UR_MAX) thrF("โŠ: Result rank too large (%iโ‰ก=๐•จ, %iโ‰ก=๐•ฉ)", wr, xr); + ShArr* sh = m_shArr(rr); + shcpy(sh->a, SH(w), wr); + shcpy(sh->a+wr, SH(x)+1, xr-1); + arr_shSetU(r, rr, sh); + } else { + arr_shVec(r); + } + dec_ret:; - decG(w); decG(x); return r; + decG(w); decG(x); return taga(r); } + B select_ucw(B t, B o, B w, B x) { if (isAtm(x) || RNK(x)!=1 || isAtm(w)) return def_fn_ucw(t, o, w, x); usz xia = IA(x); diff --git a/src/builtins/slash.c b/src/builtins/slash.c index 9a90c709..7f789b65 100644 --- a/src/builtins/slash.c +++ b/src/builtins/slash.c @@ -721,7 +721,7 @@ B slash_c2(B t, B w, B x) { } } else { u8 xk = xl-3; - void* rv = m_tyarrv(&r, 1<sh = rsh; ra->ia = s*arr_csz(x); } void* xv = tyany_ptr(x); if ((xk<3? s/64 : s/32) <= wia) { // Sparse case: use both types