#include "../core.h" #include "../utils/mut.h" #include "../utils/talloc.h" #include "../builtins.h" #ifdef __BMI2__ #include #if USE_VALGRIND #define DBG_VG_SLASH 0 u64 loadMask(u64* p, u64 unk, u64 exp, u64 i, u64 pos) { // #if DBG_VG_SLASH // if (pos==0) printf("index %2ld, got %016lx\n", i, p[i]); // #endif if (pos==0) return ~(p[i]^exp); u64 res = loadMask(p, unk, exp, i, pos<<1); if (unk&pos) res&= loadMask(p, unk, exp, i|pos, pos<<1); return res; } NOINLINE u64 vg_load64(u64* p, u64 i) { u64 unk = ~vg_getDefined_u64(i); u64 res = p[vg_withDefined_u64(i, ~0ULL)]; // result value will always be the proper indexing operation i32 undefCount = POPC(unk); if (undefCount>0) { if (undefCount>8) err("too many unknown bits in index of vg_load64"); res = vg_withDefined_u64(res, loadMask(p, unk, res, i & ~unk, 1)); } #if DBG_VG_SLASH vg_printDefined_u64("idx", i); vg_printDefined_u64("res", res); #endif return res; } NOINLINE u64 vg_pext_u64(u64 src, u64 mask) { u64 maskD = vg_getDefined_u64(mask); u64 r = vg_undef_u64(0); i32 ri = 0; u64 undefMask = 0; for (i32 i = 0; i < 64; i++) { u64 c = 1ull<>i)&1) { r|= (c&1) << i; c>>= 1; } } #if DBG_VG_SLASH printf("pdep:\n"); vg_printDefined_u64("src", src); vg_printDefined_u64("msk", mask); vg_printDefined_u64("res", r); vg_printDefined_u64("exp", _pdep_u64(src, mask)); #endif return r; } NOINLINE u64 rand_popc64(u64 x) { u64 def = vg_getDefined_u64(x); if (def==~0ULL) return POPC(x); i32 min = POPC(x & def); i32 diff = POPC(~def); i32 res = min + vgRand64Range(diff); #if DBG_VG_SLASH printf("popc:\n"); vg_printDefined_u64("x", x); printf("popc in %d-%d; res: %d\n", min, min+diff, res); #endif return res; } #define _pext_u32 vg_pext_u64 #define _pext_u64 vg_pext_u64 #define _pdep_u32 vg_pdep_u64 #define _pdep_u64 vg_pdep_u64 #else #define vg_load64(p, i) p[i] #define rand_popc64(X) POPC(X) #endif void storeu_u64(u64* p, u64 v) { memcpy(p, &v, 8); } u64 loadu_u64(u64* p) { u64 v; memcpy(&v, p, 8); return v; } #if SINGELI #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-variable" #include "../singeli/gen/slash.c" #pragma GCC diagnostic pop #endif #endif // Dense Where, still significantly worse than SIMD // Assumes modifiable DST #define WHERE_DENSE(SRC, DST, LEN, OFF) do { \ for (usz ii=0; ii<(LEN+7)/8; ii++) { \ u8 v = ((u8*)SRC)[ii]; \ for (usz k=0; k<8; k++) { *DST=OFF+8*ii+k; DST+=v&1; v>>=1; } \ } \ } while (0) // Sparse Where with branching #define WHERE_SPARSE(X,R,S,I0,COND) do { \ for (usz ii=I0, j=0; j>=24; p=u&bsp_mask; buf[j]+=(3*bsp_top)|p; j+=POPC(p); u>>=24; p=u ; buf[j]+=(3*bsp_top)|p; j+=POPC(p); } return j; } #define BSP_WRITE(BUF, DST, SUM, OFF, CLEAR) \ u64 t=((u64)OFF<<21)-2*bsp_top; \ for (usz j=0; j>24) + CTZ((u32)t); \ t &= t-1; \ } static void bsp_block_u32(u64* src, u32* dst, usz len, usz sum, usz off) { for (usz j=0; j len-i) b = len-i; usz bs = bsp_fill(src+i/64, buf, b); BSP_WRITE(buf, dst, bs, i, buf[j]=0;); buf[bs]=0; dst+= bs; } TFREE(buf); } static void where_block_u16(u64* src, u16* dst, usz len, usz sum) { assert(len <= bsp_max); #if SINGELI && defined(__BMI2__) if (sum >= len/8) bmipopc_1slash16(src, (i16*)dst, len); #else if (sum >= len/4+len/8) WHERE_DENSE(src, dst, len, 0); #endif else if (sum >= len/128) { u32* buf = (u32*)dst; assert(sum*2 <= len); for (usz j=0; j= xia/8) { i16* rp = m_tyarrvO(&r, 2, s, t_i16arr, 16); bmipopc_1slash16(xp, rp, xia); } #else if (s >= xia/4+xia/8) { i16* rp = m_tyarrvO(&r, 2, s, t_i16arr, 2); WHERE_DENSE(xp, rp, xia, 0); } #endif else { i16* rp; r=m_i16arrv(&rp,s); if (s >= xia/128) { bsp_u16(xp, (u16*)rp, xia, s); } else { WHERE_SPARSE(xp, rp, s, 0, RARE); } } } else if (xia <= (usz)I32_MAX+1) { #if SINGELI && defined(__BMI2__) i32* rp; r = m_i32arrv(&rp, s); #else i32* rp = m_tyarrvO(&r, 4, s, t_i32arr, 4); #endif usz b = bsp_max; TALLOC(i16, buf, b); i32* rq = rp; for (usz i=0; ixia-i) { b = xia-i; bs = s-(rq-rp); } else { bs = bit_sum(xp,b); } #if SINGELI && defined(__BMI2__) if (bs >= b/8+b/16) { bmipopc_1slash16(xp, buf, b); for (usz j=0; j= b/2) { i32* rs=rq; WHERE_DENSE(xp, rs, b, i); } #endif else if (bs >= b/256) { bsp_block_u32(xp, (u32*)rq, b, bs, i); } else { WHERE_SPARSE(xp-i/64, rq, bs, i/64, RARE); } rq+= bs; xp+= b/64; } TFREE(buf); } else { f64* rp; r = m_f64arrv(&rp, s); usz b = bsp_max; TALLOC(u16, buf, b); f64* rp0 = rp; for (usz i=0; ixia-i) { b=xia-i; bs=s-(rp-rp0); } else { bs=bit_sum(xp,b); } where_block_u16(xp, buf, b, bs); for (usz j=0; jia = wsum; u64 cw = 0; // current word u64 ro = 0; // offset in word where next bit should be written; never 64 for (usz i=0; i=64) { *(rp++) = cw; cw = ro? v>>(64-ro) : 0; } ro = ro2&63; } if (ro) *rp = cw; #else r = m_bitarrv(&rp,wsum); for (usz i=0, ri=0; iwia-i) { b=wia-i; bs=wsum-(rp-rp0); } \ else { bs=bit_sum(wp,b); } \ where_block_u16(wp, (u16*)buf, b, bs); \ for (usz j=0; j=wia/CUTOFF) { DENSE; } \ else { rp=m_tyarrv(&r,W/8,wsum,el2t(xe)); COMPRESS_BLOCK(i##W); } \ break; } #if SINGELI case el_i8: case el_c8: WITH_SPARSE( 8, 32, rp=m_tyarrvO(&r,1,wsum,el2t(xe), 8); bmipopc_2slash8 (wp, xp, rp, wia)) case el_i16:case el_c16: WITH_SPARSE(16, 16, rp=m_tyarrvO(&r,2,wsum,el2t(xe), 16); bmipopc_2slash16(wp, xp, rp, wia)) #else case el_i8: case el_c8: WITH_SPARSE( 8, 2, rp=m_tyarrv(&r,1,wsum,el2t(xe)); for (usz i=0; i=USZ_MAX) thrOOM(); if (s==0) { decG(x); return emptyIVec(); } usz xia = IA(x); B r; u8 xe = TI(x,elType); if (xe==el_bit) { r = where(x, xia, s); } else if (RARE(xia>=I32_MAX)) { SGetU(x) f64* rp; r = m_f64arrv(&rp, s); usz ri = 0; for (usz i = 0; i < xia; i++) { usz c = o2s(GetU(x, i)); for (usz j = 0; j < c; j++) rp[ri++] = i; } } else { i32* rp; r = m_i32arrv(&rp, s); if (xe==el_i8) { i8* xp = i8any_ptr(x); while (xia>0 && !xp[xia-1]) xia--; for (u64 i = 0; i < xia; i++) { i32 c = xp[i]; if (LIKELY(c==0 || c==1)) { *rp = i; rp+= c; } else { for (i32 j = 0; j < c; j++) *rp++ = i; } } } else if (xe==el_i32) { i32* xp = i32any_ptr(x); while (xia>0 && !xp[xia-1]) xia--; for (u64 i = 0; i < xia; i++) { i32 c = xp[i]; if (LIKELY(c==0 || c==1)) { *rp = i; rp+= c; } else { for (i32 j = 0; j < c; j++) *rp++ = i; } } } else { SLOW1("/𝕩", x); SGetU(x) for (u64 i = 0; i < xia; i++) { usz c = o2s(GetU(x, i)); for (u64 j = 0; j < c; j++) *rp++ = i; } } } decG(x); return r; } B slash_c2(B t, B w, B x) { if (isArr(x) && RNK(x)==1 && isArr(w) && RNK(w)==1 && depth(w)==1) { usz wia = IA(w); usz xia = IA(x); if (RARE(wia!=xia)) { if (wia==0) { decG(w); return x; } thrF("/: Lengths of components of 𝕨 must match 𝕩 (%s ≠ %s)", wia, xia); } B xf = getFillQ(x); if (TI(w,elType)==el_bit) { B r = compress(w, x, wia, xf); decG(w); decG(x); return r; } #define CASE(WT,XT) if (TI(x,elType)==el_##XT) { \ XT* xp = XT##any_ptr(x); \ XT* rp; B r = m_##XT##arrv(&rp, wsum); \ if (or<2) for (usz i = 0; i < wia; i++) { \ *rp = xp[i]; \ rp+= wp[i]; \ } else for (usz i = 0; i < wia; i++) { \ WT cw = wp[i]; XT cx = xp[i]; \ for (i64 j = 0; j < cw; j++) *rp++ = cx; \ } \ decG(w); decG(x); return r; \ } #define TYPED(WT,SIGN) { \ WT* wp = WT##any_ptr(w); \ while (wia>0 && !wp[wia-1]) wia--; \ i64 wsum = 0; \ u32 or = 0; \ for (usz i = 0; i < wia; i++) { \ wsum+= wp[i]; \ or|= (u32)wp[i]; \ } \ if (or>>SIGN) thrM("/: 𝕨 must consist of natural numbers"); \ if (TI(x,elType)==el_bit) { \ u64* xp = bitarr_ptr(x); u64 ri=0; \ u64* rp; B r = m_bitarrv(&rp, wsum); \ if (or<2) for (usz i = 0; i < wia; i++) { \ bitp_set(rp, ri, bitp_get(xp,i)); \ ri+= wp[i]; \ } else for (usz i = 0; i < wia; i++) { \ WT cw = wp[i]; bool cx = bitp_get(xp,i); \ for (i64 j = 0; j < cw; j++) bitp_set(rp, ri++, cx); \ } \ decG(w); decG(x); return r; \ } \ CASE(WT,i8) CASE(WT,i16) CASE(WT,i32) CASE(WT,f64) \ SLOW2("𝕨/𝕩", w, x); \ M_HARR(r, wsum) SGetU(x) \ for (usz i = 0; i < wia; i++) { \ i32 cw = wp[i]; if (cw==0) continue; \ B cx = incBy(GetU(x, i), cw); \ for (i64 j = 0; j < cw; j++) HARR_ADDA(r, cx);\ } \ decG(w); decG(x); \ return withFill(HARR_FV(r), xf); \ } if (TI(w,elType)==el_i8 ) TYPED(i8,7); if (TI(w,elType)==el_i32) TYPED(i32,31); #undef TYPED #undef CASE SLOW2("𝕨/𝕩", w, x); u64 ria = usum(w); if (ria>=USZ_MAX) thrOOM(); M_HARR(r, ria) SGetU(w) SGetU(x) for (usz i = 0; i < wia; i++) { usz c = o2s(GetU(w, i)); if (c) { B cx = incBy(GetU(x, i), c); for (usz j = 0; RARE(j < c); j++) HARR_ADDA(r, cx); } } decG(w); decG(x); return withFill(HARR_FV(r), xf); } if (isArr(x) && RNK(x)==1 && q_i32(w)) { usz xia = IA(x); i32 wv = o2i(w); if (wv<=0) { if (wv<0) thrM("/: 𝕨 cannot be negative"); return taga(arr_shVec(TI(x,slice)(x, 0, 0))); } if (TI(x,elType)==el_i32) { i32* xp = i32any_ptr(x); i32* rp; B r = m_i32arrv(&rp, xia*wv); for (usz i = 0; i < xia; i++) { for (i64 j = 0; j < wv; j++) *rp++ = xp[i]; } decG(x); return r; } else { SLOW2("𝕨/𝕩", w, x); B xf = getFillQ(x); HArr_p r = m_harrUv(xia*wv); SGetU(x) for (usz i = 0; i < xia; i++) { B cx = incBy(GetU(x, i), wv); for (i64 j = 0; j < wv; j++) *r.a++ = cx; } decG(x); return withFill(r.b, xf); } } return c2(rt_slash, w, x); } B slash_im(B t, B x) { if (!isArr(x) || RNK(x)!=1) thrM("/⁼: Argument must be an array"); u8 xe = TI(x,elType); usz xia = IA(x); if (xia==0) { decG(x); return emptyIVec(); } switch(xe) { default: UD; case el_bit: { usz sum = bit_sum(bitarr_ptr(x), xia); usz ria = 1 + (sum>0); f64* rp; B r = m_f64arrv(&rp, ria); rp[sum>0] = sum; rp[0] = xia - sum; decG(x); return num_squeeze(r); } #define CASE_SMALL(N) \ case el_i##N: { \ i##N* xp = i##N##any_ptr(x); \ usz m=1<xp[a-1]) a++; \ max=xp[a-1]; \ if (a==xia) { /* Sorted unique argument */ \ usz ria = max + 1; \ u64* rp; r = m_bitarrv(&rp, ria); \ for (usz i=0; imax) max=c; } \ if ((i##N)max<0) thrM("/⁼: Argument cannot contain negative numbers"); \ usz ria = max+1; \ i##N* rp; r = m_i##N##arrv(&rp, ria); for (usz i=0; im/2) thrM("/⁼: Argument cannot contain negative numbers"); \ i32* rp; r = m_i32arrv(&rp, ria); for (usz i=0; imax?c:max; if (c<0) thrM("/⁼: Argument cannot contain negative numbers"); } usz ria = max+1; if (i==xia) { u64* rp; r = m_bitarrv(&rp, ria); for (usz i=0; imax?c:max; if (c<0) thrM("/⁼: Argument cannot contain negative numbers"); } usz ria = max+1; if (ria==0) thrOOM(); if (i==xia) { u64* rp; r = m_bitarrv(&rp, ria); for (usz i=0; ia; } usz i,j; B r; i64 max=-1; for (i = 0; i < xia; i++) { i64 c=o2i64(xp[i]); if (c<=max) break; max=c; } for (j = i; j < xia; j++) { i64 c=o2i64(xp[j]); max=c>max?c:max; if (c<0) thrM("/⁼: Argument cannot contain negative numbers"); } if (max > USZ_MAX-1) thrOOM(); usz ria = max+1; if (i==xia) { u64* rp; r = m_bitarrv(&rp, ria); for (usz i=0; iim = slash_im; }