/*
 * Implementation of the `/` builtin.
 *
 * Entry points visible in this file:
 *   slash_c1(t, x)    - requires a rank-1 x; for each position i emits i repeated x[i]
 *                       times (total length usum(x)). Bit-typed x is handled by where().
 *   slash_c2(t, w, x) - w is either a rank<=1 array or an i32 atom. Bit-typed w goes to
 *                       compress(); an atom w repeats each major cell of x w times
 *                       (result length xlen*wv); cases not handled here jump to the
 *                       `base` label, which calls c2(rt_slash, w, x).
 *   slash_im(t, x)    - installed as the `im` field of bi_slash; for bit-typed x it
 *                       returns the counts {xia-sum, sum} (second entry only if sum>0).
 *   slash_init()      - stores slash_im and slash_ucw into bi_slash.
 *
 * NOTE(review): this copy looks damaged - text between '<' and '>' seems to have been
 * dropped (e.g. the bare `#include`, loop headers reading `for (usz i=k; i>=1; }`, and
 * several function headers), and many newlines are gone, so `//` comments now swallow
 * the code that follows them on the same physical line. Non-ASCII string literals also
 * look mis-decoded. Restore from a pristine copy before compiling - TODO confirm.
 */
/* Includes; USE_VALGRIND-only helpers loadMask, vg_load64, vg_pext_u64 (body truncated; its tail prints "pdep"), rand_popc64; non-Valgrind fallbacks; storeu_u64/loadu_u64 copy 8 bytes with memcpy. */
#include "../core.h" #include "../utils/mut.h" #include "../utils/talloc.h" #include "../builtins.h" #ifdef __BMI2__ #include #if USE_VALGRIND #define DBG_VG_SLASH 0 u64 loadMask(u64* p, u64 unk, u64 exp, u64 i, u64 pos) { // #if DBG_VG_SLASH // if (pos==0) printf("index %2ld, got %016lx\n", i, p[i]); // #endif if (pos==0) return ~(p[i]^exp); u64 res = loadMask(p, unk, exp, i, pos<<1); if (unk&pos) res&= loadMask(p, unk, exp, i|pos, pos<<1); return res; } NOINLINE u64 vg_load64(u64* p, u64 i) { u64 unk = ~vg_getDefined_u64(i); u64 res = p[vg_withDefined_u64(i, ~0ULL)]; // result value will always be the proper indexing operation i32 undefCount = POPC(unk); if (undefCount>0) { if (undefCount>8) err("too many unknown bits in index of vg_load64"); res = vg_withDefined_u64(res, loadMask(p, unk, res, i & ~unk, 1)); } #if DBG_VG_SLASH vg_printDefined_u64("idx", i); vg_printDefined_u64("res", res); #endif return res; } NOINLINE u64 vg_pext_u64(u64 src, u64 mask) { u64 maskD = vg_getDefined_u64(mask); u64 r = vg_undef_u64(0); i32 ri = 0; u64 undefMask = 0; for (i32 i = 0; i < 64; i++) { u64 c = 1ull<>i)&1) { r|= (c&1) << i; c>>= 1; } } #if DBG_VG_SLASH printf("pdep:\n"); vg_printDefined_u64("src", src); vg_printDefined_u64("msk", mask); vg_printDefined_u64("res", r); vg_printDefined_u64("exp", _pdep_u64(src, mask)); #endif return r; } NOINLINE u64 rand_popc64(u64 x) { u64 def = vg_getDefined_u64(x); if (def==~0ULL) return POPC(x); i32 min = POPC(x & def); i32 diff = POPC(~def); i32 res = min + vgRand64Range(diff); #if DBG_VG_SLASH printf("popc:\n"); vg_printDefined_u64("x", x); printf("popc in %d-%d; res: %d\n", min, min+diff, res); #endif return res; } #define _pext_u32 vg_pext_u64 #define _pext_u64 vg_pext_u64 #define _pdep_u32 vg_pdep_u64 #define _pdep_u64 vg_pdep_u64 #else #define vg_load64(p, i) p[i] #define rand_popc64(X) POPC(X) #endif void storeu_u64(u64* p, u64 v) { memcpy(p, &v, 8); } u64 loadu_u64(u64* p) { u64 v; memcpy(&v, p, 8); return v; } #if SINGELI 
/* Generated Singeli include and extern avx2_scan_pluswrap_* pointers; WHERE_SPARSE and BSP_WRITE macros; tail of a fill routine returning j; bsp_block_u32 and where_block_u16 (bodies truncated); then the interior of what appears to be compress_grouped (header not visible). */
#pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-variable" #include "../singeli/gen/slash.c" #pragma GCC diagnostic pop #endif #endif #if SINGELI extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3); extern void (*const avx2_scan_pluswrap_u16)(uint16_t* v0,uint16_t* v1,uint64_t v2,uint16_t v3); extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3); #define avx2_scan_pluswrap_u64(V0,V1,V2,V3) for (usz i=k; i>=1; } \ } \ } while (0) // Sparse Where with branching #define WHERE_SPARSE(X,R,S,I0,COND) do { \ for (usz ii=I0, j=0; j>=24; p=u&bsp_mask; buf[j]+=(3*bsp_top)|p; j+=POPC(p); u>>=24; p=u ; buf[j]+=(3*bsp_top)|p; j+=POPC(p); } return j; } #define BSP_WRITE(BUF, DST, SUM, OFF, CLEAR) \ u64 t=((u64)OFF<<21)-2*bsp_top; \ for (usz j=0; j>24) + CTZ((u32)t); \ t &= t-1; \ } static void bsp_block_u32(u64* src, u32* dst, usz len, usz sum, usz off) { for (usz j=0; j len-i) b = len-i; usz bs = bsp_fill(src+i/64, buf, b); BSP_WRITE(buf, dst, bs, i, buf[j]=0;); buf[bs]=0; dst+= bs; } TFREE(buf); } static void where_block_u16(u64* src, u16* dst, usz len, usz sum) { assert(len <= bsp_max); #if SINGELI && defined(__BMI2__) if (sum >= len/8) bmipopc_1slash16(src, (i16*)dst, len); #else if (sum >= len/4+len/8) WHERE_DENSE(src, dst, len, 0); #endif else if (sum >= len/128) { u32* buf = (u32*)dst; assert(sum*2 <= len); for (usz j=0; j0 || csz%8==0) { // Full bytes u64 width = xl==0 ? 
/* Continues the grouped compress: whole-byte cells are copied with memcpy via COMPRESS_GROUP(MEM_CPY), bit cells with COMPRESS_GROUP(bit_cpy); then where(x, xia, s) starts, choosing i16 / i32 / f64 result storage by the size of xia. */
csz/8 : csz << (xl-3); u8* xp; u8* rp; bool is_B = TI(x,elType) == el_B; HArr_p rh; if (!is_B) { xp = tyany_ptr(x); rp = m_tyarrv(&r,width,wsum,xt); } else { xp = (u8*)arr_bptr(x); usz ria = wsum*csz; if (xp != NULL) { rh = m_harrUv(ria); rp = (u8*)rh.a; } else { SLOW2("š•Ø/š•©", w, x); M_HARR(rp, ria) SGet(x) for (usz i = 0; i < wia; i++) if (bitp_get(wp,i)) { for (usz j = 0; j < csz; j++) HARR_ADDA(rp, Get(x,i*csz+j)); } return withFill(HARR_FV(rp), getFillQ(x)); } } #define MEM_CPY(R,RI,X,XI,L) memcpy(R+RI, X+XI, L) COMPRESS_GROUP(MEM_CPY) #undef MEM_CPY if (is_B) { for (usz i = 0; i < wsum*csz; i++) inc(((B*)rp)[i]); r = withFill(rh.b, getFillQ(x)); IA(r) = wsum; // Shape-setting code at end of compress expects this } } else { // Bits usz width = csz; u64* xp = tyany_ptr(x); u64* rp; r = m_bitarrv(&rp,wsum*width); IA(r) = wsum; COMPRESS_GROUP(bit_cpy) } return r; } static B where(B x, usz xia, u64 s) { B r; u64* xp = bitarr_ptr(x); usz q=xia%64; if (q) xp[xia/64] &= ((u64)1<= xia/8) { i16* rp = m_tyarrvO(&r, 2, s, t_i16arr, 16); bmipopc_1slash16(xp, rp, xia); } #else if (s >= xia/4+xia/8) { i16* rp = m_tyarrvO(&r, 2, s, t_i16arr, 2); WHERE_DENSE(xp, rp, xia, 0); } #endif else { i16* rp; r=m_i16arrv(&rp,s); if (s >= xia/128) { bsp_u16(xp, (u16*)rp, xia, s); } else { WHERE_SPARSE(xp, rp, s, 0, RARE); } } } else if (xia <= (usz)I32_MAX+1) { #if SINGELI && defined(__BMI2__) i32* rp; r = m_i32arrv(&rp, s); #else i32* rp = m_tyarrvO(&r, 4, s, t_i32arr, 4); #endif usz b = bsp_max; TALLOC(i16, buf, b); i32* rq = rp; for (usz i=0; ixia-i) { b = xia-i; bs = s-(rq-rp); } else { bs = bit_sum(xp,b); } #if SINGELI && defined(__BMI2__) if (bs >= b/8+b/16) { bmipopc_1slash16(xp, buf, b); for (usz j=0; j= b/2) { i32* rs=rq; WHERE_DENSE(xp, rs, b, i); } #endif else if (bs >= b/256) { bsp_block_u32(xp, (u32*)rq, b, bs, i); } else { WHERE_SPARSE(xp-i/64, rq, bs, i/64, RARE); } rq+= bs; xp+= b/64; } TFREE(buf); } else { f64* rp; r = m_f64arrv(&rp, s); usz b = bsp_max; TALLOC(u16, 
/* Finishes where() (f64 path fills blocks through where_block_u16); tail of a predicate returning 0 or 1 (presumably groups_lt, which compress() calls - header not visible); then compress(w, x, wia, xl, xt) starts: returns inc(x) when wsum equals the original wia, otherwise dispatches on xl. */
buf, b); f64* rp0 = rp; for (usz i=0; ixia-i) { b=xia-i; bs=s-(rp-rp0); } else { bs=bit_sum(xp,b); } where_block_u16(xp, buf, b, bs); for (usz j=0; j>63; } if (r > max) return 0; } return 1; } extern B take_c2(B, B, B); static B compress(B w, B x, usz wia, u8 xl, u8 xt) { u64* wp = bitarr_ptr(w); u64 we = 0; usz ie = wia/64; usz q=wia%64; if (q) we = wp[ie] &= ((u64)1<1) return take_c2(m_f64(0), m_f64(0), inc(x)); u8 xe = TI(x,elType); if (xe != el_B) return elNum(xe)? emptyIVec() : emptyCVec(); B xf = getFillQ(x); return q_N(xf)? emptyHVec() : m_emptyFVec(xf); } we = wp[--ie]; } usz wia0 = wia; wia = 64*(ie+1) - CLZ(we); usz wsum = bit_sum(wp, wia); if (wsum == wia0) return inc(x); B r; switch(xl) { default: r = compress_grouped(wp, x, wia, wsum, xt); break; case 0: { u64* xp = bitarr_ptr(x); u64* rp; #if SINGELI && defined(__BMI2__) r = m_bitarrv(&rp,wsum+128); a(r)->ia = wsum; u64 cw = 0; // current word u64 ro = 0; // offset in word where next bit should be written; never 64 for (usz i=0; i=64) { *(rp++) = cw; cw = ro? 
/* Rest of compress(): per-width cases built from COMPRESS_BLOCK / BLOCK_OR_GROUPED, then the result shape is set for rank>1 x; macros SCAN_CORE, a xor-scan boolean replicate, and BOOL_REP_OVER; then slash_c1 starts (rank check, s = usum(x), empty result when s==0). */
v>>(64-ro) : 0; } ro = ro2&63; } if (ro) *rp = cw; #else r = m_bitarrv(&rp,wsum); for (usz i=0, ri=0; iwia-i) { b=wia-i; bs=wsum-(rp-rp0); } \ else { bs=bit_sum(wp,b); } \ where_block_u16(wp, (u16*)buf, b, bs); \ for (usz j=0; j=wia/8 && groups_lt(wp,wia, wia/16)) r = compress_grouped(wp, x, wia, wsum, xt); \ else { T* xp=tyany_ptr(x); T* rp=m_tyarrv(&r,sizeof(T),wsum,xt); COMPRESS_BLOCK(T); } case 5: BLOCK_OR_GROUPED(i32) break; case 6: if (TI(x,elType)!=el_B) { BLOCK_OR_GROUPED(u64) } else { B xf = getFillQ(x); B* xp = arr_bptr(x); if (xp!=NULL) { HArr_p rh = m_harrUv(wsum); B *rp = rh.a; COMPRESS_BLOCK(B); for (usz i=0; i 1) { Arr* ra=a(r); SPRNK(ra,xr); usz* sh = PSH(ra) = m_shArr(xr)->a; sh[0] = PIA(ra); PIA(ra) *= arr_csz(x); shcpy(sh+1, SH(x)+1, xr-1); } return r; } // Replicate using plus/max/xor-scan #define SCAN_CORE(WV, UPD, SET, SCAN) \ usz b = 1<<10; \ for (usz k=0, j=0, ij=WV; ; ) { \ usz e = b>63, js=-(xx&1); xx^=xx<<1; \ for (usz k=0, j=0, ij=WV; ; ) { \ usz e = b>=1; j++; if (j%64==0) { u64 v=xp[j/64]; xx=v^(v<<1)^xs; xs=v>>63; } \ rp[ij/64]^=(-(xx&1))<<(ij%64); ij+=WV; \ } \ for (usz i=k/64; i>63); \ if (e==s) {break;} k=e; \ } // Basic boolean loop with overwriting #define BOOL_REP_OVER(WV, LEN) \ u64 ri=0, rc=0, xc=0; usz j=0; \ for (usz i = 0; i < LEN; i++) { \ u64 v = -(u64)bitp_get(xp,i); \ rc ^= (v^xc) << (ri%64); \ xc = v; \ ri += WV; usz e = ri/64; \ if (j < e) { \ rp[j++] = rc; \ while (j < e) rp[j++] = v; \ rc = v; \ } \ } \ if (ri%64) rp[j] = rc; extern B rt_slash; B slash_c1(B t, B x) { if (RARE(isAtm(x)) || RARE(RNK(x)!=1)) thrF("/: Argument must have rank 1 (%H ≔ ā‰¢š•©)", x); u64 s = usum(x); if (s>=USZ_MAX) thrOOM(); if (s==0) { decG(x); return emptyIVec(); } usz xia = IA(x); B r; u8 xe = TI(x,elType); if (xe!=el_bit && s<=xia) { x = num_squeezeChk(x); xe = TI(x,elType); } if (xe==el_bit) { r = where(x, xia, s); } else if (RARE(xia > (usz)I32_MAX+1)) { SGetU(x) f64* rp; r = m_f64arrv(&rp, s); usz ri = 0; for (usz i = 0; i < xia; 
/* Rest of slash_c1: generic element loop for wide types, a scan-based path when s/32 <= xia, else DENSE_IND writing i8/i16/i32 results by trimmed xia; then slash_c2 starts by classifying w as array (wv stays -1) or i32 atom. */
i++) { usz c = o2s(GetU(x, i)); for (usz j = 0; j < c; j++) rp[ri++] = i; } } else if (RARE(xe > el_i32)) { i32* rp; r = m_i32arrv(&rp, s); SLOW1("/š•©", x); SGetU(x) for (u64 i = 0; i < xia; i++) { usz c = o2s(GetU(x, i)); for (u64 j = 0; j < c; j++) *rp++ = i; } } else { if (s/32 <= xia) { // Sparse case: type of x matters #define SPARSE_IND(T) T* xp = T##any_ptr(x); IND_BY_SCAN i32* rp; r = m_i32arrv(&rp, s); if (xe == el_i8 ) { SPARSE_IND(i8 ); } else if (xe == el_i16) { SPARSE_IND(i16); } else { SPARSE_IND(i32); } #undef SPARSE_IND } else { // Dense case: only result type matters #define DENSE_IND(T) \ T* rp; r = m_##T##arrv(&rp, s); \ for (u64 i = 0; i < xia; i++) { \ i32 c = xp[i]; \ for (i32 j = 0; j < c; j++) *rp++ = i; \ } if (xe < el_i32) x = taga(cpyI32Arr(x)); i32* xp = i32any_ptr(x); while (xia>0 && !xp[xia-1]) xia--; if (xia <= 128) { DENSE_IND(i8 ); } else if (xia <= 32768) { DENSE_IND(i16); } else { DENSE_IND(i32); } #undef DENSE_IND } } decG(x); return r; } B slash_c2(B t, B w, B x) { i32 wv = -1; usz wia; if (isArr(w)) { if (depth(w)>1) goto base; ur wr = RNK(w); if (wr>1) thrF("/: Simple š•Ø must have rank 0 or 1 (%i≔=š•Ø)", wr); if (wr<1) { B v=IGet(w, 0); decG(w); w=v; goto atom; } wia = IA(w); if (wia==0) { decG(w); return isArr(x)? 
/* slash_c2, array-w branch: length check against xlen, squeeze w to an integer type, bit-typed w jumps to compress(); element-wise slow path for el_B x; otherwise boolean or typed replicate. The atom-w branch begins at the end of this line. */
x : m_atomUnit(x); } } else { atom: if (!q_i32(w)) goto base; wv = o2i(w); } if (isAtm(x) || RNK(x)==0) thrM("/: š•© must have rank at least 1 for simple š•Ø"); ur xr = RNK(x); usz xlen = *SH(x); u8 xl = cellWidthLog(x); u8 xt = arrNewType(TY(x)); B r; if (wv < 0) { // Array w if (RARE(wia!=xlen)) thrF("/: Lengths of components of š•Ø must match š•© (%s ≠ %s)", wia, xlen); u8 we = TI(w,elType); if (!elInt(we)) { w=any_squeeze(w); we=TI(w,elType); if (!elInt(we)) goto slow; } if (we==el_bit) { wbool: r = compress(w, x, wia, xl, xt); goto decWX_ret; } if (xl>6 || (xl<3 && xl!=0)) goto base; u64 s = usum(w); if (s<=wia) { w=num_squeezeChk(w); we=TI(w,elType); if (we==el_bit) goto wbool; } if (RARE(TI(x,elType)==el_B)) { // Slow case slow: if (xr > 1) goto base; SLOW2("š•Ø/š•©", w, x); B xf = getFillQ(x); MAKE_MUT(r0, s) mut_init(r0, el_B); MUTG_INIT(r0); SGetU(w) SGetU(x) usz ri = 0; for (usz i = 0; i < wia; i++) { usz c = o2s(GetU(w, i)); if (c) { mut_fillG(r0, ri, GetU(x, i), c); ri+= c; } } r = withFill(mut_fv(r0), xf); decWX_ret: decG(w); decX_ret: decG(x); return r; } // Make shape if needed; all cases below use it usz* rsh = NULL; if (xr > 1) { rsh = m_shArr(xr)->a; rsh[0] = s; shcpy(rsh+1, SH(x)+1, xr-1); } if (xl == 0) { u64* xp = bitarr_ptr(x); u64* rp; r = m_bitarrv(&rp, s); if (rsh) { SPRNK(a(r),xr); SH(r) = rsh; } if (s/256 <= wia) { #define SPARSE_REP(T) T* wp=T##any_ptr(w); BOOL_REP_XOR_SCAN(wp[j]) if (we==el_i8 ) { SPARSE_REP(i8 ); } else if (we==el_i16) { SPARSE_REP(i16); } else { SPARSE_REP(i32); } #undef SPARSE_REP } else { if (we < el_i32) w = taga(cpyI32Arr(w)); i32* wp = i32any_ptr(w); BOOL_REP_OVER(wp[i], wia) } } else { u8 xk = xl-3; void* rv = m_tyarrv(&r, 1<0 && !wp[wia-1]) wia--; switch (xk) { default: UD; CASE(0,u8) CASE(1,u16) CASE(2,u32) CASE(3,u64) } #undef CASE } } goto decWX_ret; } else { if (wv <= 1) { if (wv < 0) thrM("/: š•Ø cannot be negative"); return wv ? 
/* slash_c2, atom-w branch: wv==1 returns x, wv==0 returns an empty slice, otherwise each cell is repeated wv times (s = xlen*wv); label base falls back to c2(rt_slash, w, x). Then slash_im starts: bit case plus CASE_SMALL-generated integer cases (partly truncated). */
x : taga(arr_shVec(TI(x,slice)(x, 0, 0))); } if (xlen == 0) return x; usz s = xlen * wv; if (xl>6 || (xl<3 && xl!=0) || TI(x,elType)==el_B) { if (xr != 1) goto base; SLOW2("š•Ø/š•©", w, x); B xf = getFillQ(x); HArr_p r0 = m_harrUv(s); SGetU(x) for (usz i = 0; i < xlen; i++) { B cx = incBy(GetU(x, i), wv); for (i64 j = 0; j < wv; j++) *r0.a++ = cx; } r = withFill(r0.b, xf); goto decX_ret; } if (xl == 0) { u64* xp = bitarr_ptr(x); u64* rp; r = m_bitarrv(&rp, s); if (wv <= 128) { BOOL_REP_XOR_SCAN(wv) } else { BOOL_REP_OVER(wv, xlen) } goto decX_ret; } else { u8 xk = xl-3; void* rv = m_tyarrv(&r, 1< 1) { usz* rsh = m_shArr(xr)->a; rsh[0] = s; shcpy(rsh+1, SH(x)+1, xr-1); Arr* ra=a(r); SPRNK(ra,xr); PSH(ra)=rsh; PIA(ra)=s*arr_csz(x); } goto decX_ret; } base: return c2(rt_slash, w, x); } B slash_im(B t, B x) { if (!isArr(x) || RNK(x)!=1) thrM("/⁼: Argument must be an array"); u8 xe = TI(x,elType); usz xia = IA(x); if (xia==0) { decG(x); return emptyIVec(); } switch(xe) { default: UD; case el_bit: { usz sum = bit_sum(bitarr_ptr(x), xia); usz ria = 1 + (sum>0); f64* rp; B r = m_f64arrv(&rp, ria); rp[sum>0] = sum; rp[0] = xia - sum; decG(x); return num_squeezeChk(r); } #define CASE_SMALL(N) \ case el_i##N: { \ i##N* xp = i##N##any_ptr(x); \ usz m=1<xp[a-1]) a++; \ max=xp[a-1]; \ if (a==xia) { /* Sorted unique argument */ \ usz ria = max + 1; \ u64* rp; r = m_bitarrv(&rp, ria); \ for (usz i=0; imax) max=c; } \ if ((i##N)max<0) thrM("/⁼: Argument cannot contain negative numbers"); \ usz ria = max+1; \ i##N* rp; r = m_i##N##arrv(&rp, ria); for (usz i=0; im/2) thrM("/⁼: Argument cannot contain negative numbers"); \ i32* rp; r = m_i32arrv(&rp, ria); for (usz i=0; imax?c:max; if (c<0) thrM("/⁼: Argument cannot contain negative numbers"); } usz ria = max+1; if (i==xia) { u64* rp; r = m_bitarrv(&rp, ria); for (usz i=0; imax?c:max; if (c<0) thrM("/⁼: Argument cannot contain negative numbers"); } usz ria = max+1; if (ria==0) thrOOM(); if (i==xia) { u64* rp; r = m_bitarrv(&rp, 
/* Tail of slash_im (generic-element case rejects negatives and throws OOM when max > USZ_MAX-1); then the tail of a function ending in mut_fcd(r, x) - presumably slash_ucw, whose header is not visible here; finally slash_init. */
ria); for (usz i=0; ia; } usz i,j; B r; i64 max=-1; for (i = 0; i < xia; i++) { i64 c=o2i64(xp[i]); if (c<=max) break; max=c; } for (j = i; j < xia; j++) { i64 c=o2i64(xp[j]); max=c>max?c:max; if (c<0) thrM("/⁼: Argument cannot contain negative numbers"); } if (max > USZ_MAX-1) thrOOM(); usz ria = max+1; if (i==xia) { u64* rp; r = m_bitarrv(&rp, ria); for (usz i=0; ifns->elType!=el_i32) mut_to(r, el_i32); i32* rp = r->ai32; x = toI32Any(x); i32* xp = i32any_ptr(x); rep = toI32Any(rep); i32* np = i32any_ptr(rep); for (usz i = 0; i < ia; i++) { bool v = bitp_get(d, i); i32 nc = np[repI]; i32 xc = xp[i]; rp[i] = v? nc : xc; repI+= v; } } else { MUTG_INIT(r); for (usz i = 0; i < ia; i++) mut_setG(r, i, bitp_get(d, i)? Get(rep,repI++) : Get(x,i)); } } else { SGetU(rep) MUTG_INIT(r); for (usz i = 0; i < ia; i++) { i32 cw = o2iG(GetU(w, i)); if (cw) { B cr = Get(rep,repI); if (CHECK_VALID) for (i32 j = 1; j < cw; j++) if (!equal(GetU(rep,repI+j), cr)) { mut_pfree(r,i); thrM("š”½āŒ¾(a⊸/): Incompatible result elements"); } mut_setG(r, i, cr); repI+= cw; } else mut_setG(r, i, Get(x,i)); } } decG(w); decG(rep); return mut_fcd(r, x); } void slash_init() { c(BFn,bi_slash)->im = slash_im; c(BFn,bi_slash)->ucw = slash_ucw; }