// Fold (´) and Insert (˝) // Fold optimized operands: // ⊣⊢ on all types // +-∧∨=≠ and synonyms on booleans // ≤<>≥ on booleans, monadic only, with a search // +⌈⌊× on numbers // Integer +: sum blocks associatively as long as sum can't exceed +-2⋆53 // COULD implement fast numeric -´ // ∨ on boolean-valued integers, stopping at 1 // •math.Sum: +´ with faster and more precise SIMD code for i32, f64 // Insert with rank (˝˘ or ˝⎉k), or fold on flat array // SHOULD optimize dyadic insert with rank // Length 1, ⊣⊢: implemented as ⊏˘ // ∾˝ with rank: reshape argument // Arithmetic on empty: reshape identity // Boolean operand, cell size 1: // +: popcount // Rows length ≤64: extract rows, popcount each // TRIED scan-based version, faster for width 3 only // COULD have bit-twiddling versions of +˝˘ on very short rows // >64: popcount, boundary corrections (clang auto-vectorizes) // ∨∧≠= and synonyms, rows <64: extract from scan or windowed op // Dedicated auto-vectorizing code for sizes 2, 4 // Extract is generic with multiplies or BMI2 (SHOULD add SIMD) // COULD use CLMUL for faster windowed ≠ // ∨∧ and synonyms: // Multiples of 8 <256: comparison function, recurse if needed // Rows 6424: read word, mask, last bit of popcount // Rows l<64: word-based scan, correct with SWAR // ≥64: xor words (auto-vectorizes), correct boundary, popcount // COULD implement boolean -˝˘ with xor, +˝˘, offset // Arithmetic on rank 2, short rows: transpose then insert, blocked // SHOULD extend transpose-insert code to any frame and cell rank // SHOULD have dedicated +⌊⌈ insert with rank #include "../core.h" #include "../builtins.h" #include "../utils/mut.h" #include "../utils/calls.h" #if SINGELI extern uint64_t* const si_spaced_masks; #define get_spaced_mask(i) si_spaced_masks[i-1] #define SINGELI_FILE fold #include "../utils/includeSingeli.h" #endif static u64 xor_words(u64* x, u64 l) { u64 r = 0; for (u64 i = 0; i < l; i++) r^= x[i]; return r; } static bool fold_ne(u64* x, u64 am) { u64 r = xor_words(x, am>>6); if (am&63) r^= x[am>>6]<<(64-am & 63); return POPC(r) & 1; } static i64 bit_diff(u64* x, u64 am) { i64 r = 0; u64 a = 0xAAAAAAAAAAAAAAAA; for (u64 i = 0; i < (am>>6); i++) r+= POPC(x[i]^a); if (am&63) r+= POPC((x[am>>6]^a)<<(64-am & 63)); return r - (i64)(am/2); } // It's safe to sum a block of integers as long as the current total // is far enough from +-1ull<<53 (and integer, in dyadic fold). static const usz sum_small_max = 1<<16; #define DEF_INT_SUM(T,W,M,A) \ static i64 sum_small_##T(void* xv, usz ia) { \ i##A s=0; for (usz i=0; i0 && -lim<=c && c<=lim) { \ usz e=i0; i0=(i0-1)&~(b-1); \ c+=sum_small_##T(xp+i0, e-i0); \ } \ r = c; \ } \ while (i0--) r+=xp[i0]; \ return r; \ } DEF_INT_SUM(i8 ,8 ,32,32) DEF_INT_SUM(i16,16,32,32) DEF_INT_SUM(i32,32,52,64) #undef DEF_SUM static f64 sum_f64(void* xv, usz i, f64 r) { while (i--) r += ((f64*)xv)[i]; return r; } static i64 (*const sum_small_fns[])(void*, usz) = { sum_small_i8, sum_small_i16, sum_small_i32 }; static f64 (*const sum_fns[])(void*, usz, f64) = { sum_i8, sum_i16, sum_i32, sum_f64 }; B sum_c1(B t, B x) { if (isAtm(x) || RNK(x)!=1) thrF("•math.Sum𝕩: 𝕩 must be a list (%H ≡ ≢𝕩)", x); usz ia = IA(x); if (ia==0) { decG(x); return m_f64(0); } u8 xe = TI(x,elType); if (!elNum(xe)) { x = any_squeeze(x); xe = TI(x,elType); if (!elNum(xe)) thrF("•math.Sum𝕩: 𝕩 elements must be numbers", x); } f64 r; void* xv = tyany_ptr(x); if (xe == el_bit) { r = bit_sum(xv, ia); } else if (xe <= el_i32) { u8 sel = xe - el_i8; i64 s = 0; r = 0; i64 m = 1ull<<48; usz b = sum_small_max; for (usz i=0; i= m) { r+=m; s-=m; } if (s <= -m) { r-=m; s+=m; } } r += s; } else { #if SINGELI r = si_sum_f64(xv, ia); #else r=0; for (usz i=0; i) } DEF_MIN_MAX(i8) DEF_MIN_MAX(i16) DEF_MIN_MAX(i32) #if SINGELI static f64 min_f64(void* xv, usz ia) { return si_fold_min_f64(xv,ia); } static f64 max_f64(void* xv, usz ia) { return si_fold_max_f64(xv,ia); } #else DEF_MIN_MAX(f64) #endif #undef DEF_MIN_MAX #undef MIN_MAX static f64 (*const min_fns[])(void*, usz) = { min_i8, min_i16, min_i32, min_f64 }; static f64 (*const max_fns[])(void*, usz) = { max_i8, max_i16, max_i32, max_f64 }; B fold_c1(Md1D* d, B x) { B f = d->f; if (isAtm(x) || RNK(x)!=1) thrF("𝔽´𝕩: 𝕩 must be a list (%H ≡ ≢𝕩)", x); usz ia = IA(x); if (ia<=2) { if (ia==2) { SGet(x) B x0 = Get(x,0); B x1 = Get(x,1); decG(x); return c2(f, x0, x1); } else if (ia==1) { return TO_GET(x,0); } else { decG(x); if (isFun(f)) { B r = TI(f,identity)(f); if (!q_N(r)) return r; } thrM("𝔽´𝕩: Identity not found"); } } if (RARE(!isFun(f))) { decG(x); if (isMd(f)) thrM("Calling a modifier"); return inc(f); } u8 xe = TI(x,elType); if (RTID(f) != RTID_NONE) { u8 rtid = RTID(f); if (rtid==n_ltack) return TO_GET(x, 0); if (rtid==n_rtack) return TO_GET(x, ia-1); if (xe>el_f64) goto base; if (xe==el_bit) { u64* xp = bitany_ptr(x); f64 r; switch (rtid) { default: goto base; case n_mul: case n_and:case n_floor: r = bit_has (xp, ia, 0) ^ 1; break; case n_or: case n_ceil: r = bit_has (xp, ia, 1); break; case n_add: r = bit_sum (xp, ia); break; case n_sub: r = bit_diff(xp, ia); break; case n_ne: r = fold_ne (xp, ia); break; case n_eq: r = fold_ne (xp, ia) ^ (1&~ia); break; case n_lt: r = bit_find(xp, ia, 1) == ia-1; break; case n_le: r = bit_find(xp, ia, 0) != ia-1; break; case n_gt: r = bit_find(xp, ia, 0) & 1; break; case n_ge: r =~bit_find(xp, ia, 1) & 1; break; } decG(x); return m_f64(r); } if (rtid==n_add) { // + void *xv = tyany_ptr(x); bool small = xe<=el_i32 & ia<=sum_small_max; u8 sel = xe - el_i8; f64 r = small ? sum_small_fns[sel](xv, ia) : sum_fns[sel](xv, ia, 0); decG(x); return m_f64(r); } if (rtid==n_floor) { f64 r=min_fns[xe-el_i8](tyany_ptr(x), ia); decG(x); return m_f64(r); } // ⌊ if (rtid==n_ceil ) { f64 r=max_fns[xe-el_i8](tyany_ptr(x), ia); decG(x); return m_f64(r); } // ⌈ if (rtid==n_mul | rtid==n_and) { // ×/∧ void *xv = tyany_ptr(x); assert(xe >= el_i8); u8 sel = xe - el_i8; f64 r = xe<=el_i32 ? prod_int_fns[sel](xv, ia, 1) : prod_fns[sel](xv, ia, 1); decG(x); return m_f64(r); } if (rtid==n_or) { // ∨ if (xe==el_i8 ) { i8* xp = i8any_ptr (x); usz i=ia; while (i--) { i8 c=xp[i]; if (c==1) break; if (c!=0) goto base; } decG(x); return m_i32(i+1 > 0); } if (xe==el_i16) { i16* xp = i16any_ptr(x); usz i=ia; while (i--) { i16 c=xp[i]; if (c==1) break; if (c!=0) goto base; } decG(x); return m_i32(i+1 > 0); } if (xe==el_i32) { i32* xp = i32any_ptr(x); usz i=ia; while (i--) { i32 c=xp[i]; if (c==1) break; if (c!=0) goto base; } decG(x); return m_i32(i+1 > 0); } } } base:; SLOW2("𝕎´ 𝕩", f, x); SGet(x) FC2 fc2 = c2fn(f); B c; if (TI(x,elType)==el_i32) { i32* xp = i32any_ptr(x); c = m_i32(xp[ia-1]); for (usz i = ia-1; i>0; i--) c = fc2(f, m_i32(xp[i-1]), c); } else { c = Get(x, ia-1); for (usz i = ia-1; i>0; i--) c = fc2(f, Get(x, i-1), c); } decG(x); return c; } B fold_c2(Md1D* d, B w, B x) { B f = d->f; if (isAtm(x) || RNK(x)!=1) thrF("𝕨𝔽´𝕩: 𝕩 must be a list (%H ≡ ≢𝕩)", x); usz ia = IA(x); if (RARE(ia==0)) { decG(x); return w; } if (RARE(!isFun(f))) { dec(w); decG(x); if (isMd(f)) thrM("Calling a modifier"); return inc(f); } u8 xe = TI(x,elType); if (RTID(f) != RTID_NONE) { u8 rtid = RTID(f); if (rtid==n_ltack) { dec(w); return TO_GET(x, 0); } if (rtid==n_rtack) { decG(x); return w; } if (!isF64(w) || xe>el_f64) goto base; f64 wf = o2fG(w); if (xe==el_bit) { i32 wi = wf; if (wi!=wf) goto base; u64* xp = bitany_ptr(x); if (rtid==n_add) { B r = m_f64(wi + bit_sum (xp, ia)); decG(x); return r; } if (rtid==n_sub) { B r = m_f64((ia&1?-wi:wi) + bit_diff(xp, ia)); decG(x); return r; } if (wi!=(wi&1)) goto base; if (rtid==n_and | rtid==n_mul | rtid==n_floor) { B r = m_i32(wi && !bit_has(xp, ia, 0)); decG(x); return r; } if (rtid==n_or | rtid==n_ceil ) { B r = m_i32(wi || bit_has(xp, ia, 1)); decG(x); return r; } if (rtid==n_ne) { bool r=wi^fold_ne(xp, ia) ; decG(x); return m_i32(r); } if (rtid==n_eq) { bool r=wi^fold_ne(xp, ia) ^ (1&ia); decG(x); return m_i32(r); } goto base; } if (rtid==n_add) { // + u8 sel = xe - el_i8; f64 r = sum_fns[sel](tyany_ptr(x), ia, wf); decG(x); return m_f64(r); } if (rtid==n_floor) { f64 r=wf; f64 m=min_fns[xe-el_i8](tyany_ptr(x), ia); if (mr) r=m; decG(x); return m_f64(r); } // ⌈ i32 wi = wf; if (rtid==n_mul | rtid==n_and) { // ×/∧ void *xv = tyany_ptr(x); bool isint = xe<=el_i32 && wi==wf; u8 sel = xe - el_i8; f64 r = isint ? prod_int_fns[sel](xv, ia, wi) : prod_fns[sel](xv, ia, wf); decG(x); return m_f64(r); } if (rtid==n_or && (wi&1)==wf) { // ∨ if (xe==el_i8 ) { i8* xp = i8any_ptr (x); usz i=ia; if (!wi) while (i--) { i8 c=xp[i]; if (c==1) break; if (c!=0) goto base; } decG(x); return m_i32(i+1 > 0); } if (xe==el_i16) { i16* xp = i16any_ptr(x); usz i=ia; if (!wi) while (i--) { i16 c=xp[i]; if (c==1) break; if (c!=0) goto base; } decG(x); return m_i32(i+1 > 0); } if (xe==el_i32) { i32* xp = i32any_ptr(x); usz i=ia; if (!wi) while (i--) { i32 c=xp[i]; if (c==1) break; if (c!=0) goto base; } decG(x); return m_i32(i+1 > 0); } } } base:; SLOW3("𝕨 F´ 𝕩", w, x, f); B c = w; SGet(x) FC2 fc2 = c2fn(f); for (usz i = ia; i>0; i--) c = fc2(f, Get(x, i-1), c); decG(x); return c; } NOINLINE i64 bit_sum(u64* x, u64 am) { i64 r = 0; for (u64 i = 0; i < (am>>6); i++) r+= POPC(x[i]); if (am&63) r+= POPC(x[am>>6]<<(64-am & 63)); return r; } NOINLINE static u64 usum_generic(B x, usz xia) { SGetU(x) u64 r = 0; for (usz i = 0; i < xia; i++) { u64 c = o2u64(GetU(x,i)); if (addOn(r,c)) thrM("Sum too big"); } return r; } u64 usum(B x) { // doesn't consume; will error on non-integers, or elements <0, or if sum overflows u64 assert(isArr(x)); u64 r = 0; usz xia = IA(x); u8 xe = TI(x,elType); if (xe==el_bit) return bit_sum(bitany_ptr(x), xia); else if (xe==el_i8 ) { i8* p = i8any_ptr (x); i8 m=0; for (usz i = 0; i < xia; ) { usz b=1<< 8; i16 s=0; for (usz e = xia-if; ur xr; if (isAtm(x) || (xr=RNK(x))==0) thrM("𝔽˝𝕩: 𝕩 must have rank at least 1"); usz len = *SH(x); if (len==0) { if (isFun(f)) { B id = TI(f,identity)(f); if (!q_N(id)) { if (isArr(id)) { decG(x); return id; } // arithmetic with table usz* xsh = SH(x); Arr* r = reshape_one(shProd(xsh, 1, xr), id); if (xr == 2) arr_shVec(r); else { usz* rsh = arr_shAlloc(r, xr-1); shcpy(rsh, xsh+1, xr-1); } decG(x); return taga(r); } else if (RTID(f) == n_join) { if (xr <= 1) thrM("𝔽˝𝕩: Identity does not exist"); goto join; } } thrM("𝔽˝𝕩: Identity not found"); } if (len==1) return C1(select, x); if (RARE(!isFun(f))) { decG(x); if (isMd(f)) thrM("Calling a modifier"); return inc(f); } if (isPervasiveDyExt(f)) { if (xr==1) return m_unit(fold_c1(d, x)); if (len==IA(x)) { B r = m_vec1(fold_c1(d, C1(shape, x))); ur rr = xr - 1; if (rr > 1) { ShArr* rsh = m_shArr(rr); PLAINLOOP for (ur i=0; ia[i] = 1; arr_shReplace(a(r), rr, rsh); } return r; } } if (RTID(f) != RTID_NONE) { u8 rtid = RTID(f); if (rtid==n_ltack) return C1(select, x); if (rtid==n_rtack) return C2(select, m_f64(-1), x); if (rtid==n_join) { join:; if (xr==1) return x; ShArr* rsh; if (xr>2) { rsh = m_shArr(xr-1); usz* xsh = SH(x); shcpy(rsh->a+1, xsh+2, xr-2); rsh->a[0] = xsh[0] * xsh[1]; } Arr* r = TI(x,slice)(x, 0, IA(x)); if (xr>2) arr_shSetUG(r, xr-1, rsh); else arr_shVec(r); return taga(r); } } return insert_base(f, x, 0, m_f64(0)); } B insert_c2(Md1D* d, B w, B x) { B f = d->f; ur xr; if (isAtm(x) || (xr=RNK(x))==0) thrM("𝕨˝𝕩: 𝕩 must have rank at least 1"); usz len = *SH(x); if (len==0) { decG(x); return w; } if (RARE(!isFun(f))) { dec(w); decG(x); if (isMd(f)) thrM("Calling a modifier"); return inc(f); } if (isPervasiveDyExt(f) && len==IA(x)) { // 1-element arrays are always conformable // final rank is higher of w, cell rank of x ur rr = xr - 1; if (isArr(w)) { if (IA(w) != 1) goto skip; ur wr = RNK(w); if (wr>rr) rr = wr; w = TO_GET(w, 0); } if (xr > 1) x = C1(shape, x); B r = m_unit(fold_c2(d, w, x)); if (rr > 0) { if (rr == 1) arr_shVec(a(r)); else { ShArr* rsh = m_shArr(rr); PLAINLOOP for (ur i=0; ia[i] = 1; arr_shReplace(a(r), rr, rsh); } } return r; skip:; } if (RTID(f) != RTID_NONE) { u8 rtid = RTID(f); if (rtid==n_ltack) { dec(w); return C1(select, x); } if (rtid==n_rtack) { decG(x); return w; } } return insert_base(f, x, 1, w); } B insert_cells_join(B x, usz* xsh, ur cr, ur k) { assert(k > 0); if (cr <= 1) { if (xsh[k]==0) thrM("˝: Identity does not exist"); return x; } ur rr = k+cr-1; // result rank (>1) ShArr* rsh; rsh = m_shArr(rr); shcpy(rsh->a, xsh, k); rsh->a[k] = xsh[k] * xsh[k+1]; shcpy(rsh->a+k+1, xsh+k+2, cr-2); Arr* r = TI(x,slice)(x, 0, IA(x)); arr_shSetUG(r, rr, rsh); return taga(r); } B insert_cells_identity(B x, B f, usz* xsh, ur xr, ur k, u8 rtid) { B id; if (!isFun(f) || q_N(id = TI(f,identity)(f))) thrF("%c: Identity not found", rtid==n_fold? U'´' : U'˝'); bool cs = !isArr(id); // if x cell shape is used (not table) usz ria = shProd(xsh, 0, k); if (cs) ria*= shProd(xsh, k+1, xr); Arr* r = reshape_one(ria, id); ur rr = cs? xr-1 : k; if (rr == 1) arr_shVec(r); else { usz* rsh = arr_shAlloc(r, rr); shcpy(rsh, xsh, k); if (cs) shcpy(rsh+k, xsh+k+1, rr-k); } decG(x); return taga(r); } // Arithmetic fold/insert on -k-cells of flat array x // Return a vector regardless of argument shape B transp_c1(B, B); B fold_rows(Md1D* fd, B x, usz n, usz m) { assert(isArr(x) && IA(x)==n*m); // Target block size trying to avoid power-of-two lengths, from: // {𝕩/˜⌊´⊸= +˝˘ +˝¬∨`2|>⌊∘÷⟜2⍟(↕12) ⌊0.5+32÷˜𝕩÷⌜1+↕64} +⟜↕2⋆16 u64 block = (116053*8) >> arrTypeBitsLog(TY(x)); if (TI(x,elType)==el_bit || IA(x)/2 <= block) { if (RNK(x) > 2) { Arr* xc = cpyWithShape(x); ShArr* sh = m_shArr(2); sh->a[0] = n; sh->a[1] = m; x = taga(arr_shReplace(xc, 2, sh)); } x = C1(transp, x); return insert_c1(fd, x); } else { usz b = (block + m - 1) / m; // Normal block length usz b_max = b + b/4; // Last block max length MAKE_MUT(r, n); MUT_APPEND_INIT(r); BSS2A slice = TI(x,slice); for (usz i=0, im=0; i b_max) { incG(x); l = b; } usz sia = l * m; Arr* sl = slice(x, im, sia); usz* ssh = arr_shAlloc(sl, 2); ssh[0] = l; ssh[1] = m; B sr = insert_c1(fd, C1(transp, taga(sl))); MUT_APPEND(r, sr, 0, l); decG(sr); i += l; im += sia; } return mut_fv(r); } } B sum_rows_bit(B x, usz n, usz m) { u64* xp = bitany_ptr(x); if (m < 128) { if (m == 2) return bi_N; // Transpose is faster i8* rp; B r = m_i8arrv(&rp, n); if (m <= 64) { if (m%8 == 0) { usz k = m/8; u64 b = (m==64? 0 : 1ull<> (j%8); rp[i] = POPC(b & xw); } } else { // Row may not fit in an aligned word // Read a word containing the last bit, combine with saved bits u64 b = ~(~(u64)0 >> m); u64 prev = 0; for (usz i=0, j=m; i> (m-sh); } } } } else { // 64> (in%64)); rp[i] = s - o; j = jn+1; } } decG(x); return r; } else if (m < 1<<15) { i16* rp; B r = m_i16arrv(&rp, n); usz l = m/64; if (m%64==0) { for (usz i=0; i= j+l)); o = POPC(e >> (in%64)); rp[i] = s - o; j = jn+1; } } decG(x); return r; } else { return bi_N; } } // Fold n cells of size m, stride 1 // Return a vector regardless of argument shape, or bi_N if not handled B fold_rows_bit(Md1D* fd, B x, usz n, usz m) { assert(isArr(x) && TI(x,elType)==el_bit && IA(x)==n*m); if (RTID(fd->f) == RTID_NONE) return bi_N; u8 rtid = RTID(fd->f); if (rtid==n_add) return sum_rows_bit(x, n, m); #if SINGELI bool is_or = rtid==n_or |rtid==n_ceil; bool andor = rtid==n_and|rtid==n_floor|rtid==n_mul|is_or; if (rtid==n_ne|rtid==n_eq|andor) { if (n==0) { decG(x); return taga(arr_shVec(allZeroes(0))); } if (andor && m < 256) while (m%8 == 0) { usz f = CTZ(m|32); m >>= f; usz c = m*n; u64* yp; B y = m_bitarrv(&yp, c); u8 e = el_i8 + f-3; CmpASFn cmp = is_or ? CMP_AS_FN(ne, e) : CMP_AS_FN(eq, e); CMP_AS_CALL(cmp, yp, bitany_ptr(x), m_f64(is_or-1), c); decG(x); if (m==1) return y; x = y; } u64* xp = bitany_ptr(x); u64* rp; B r = m_bitarrv(&rp, n); if (andor) si_or_rows_bit(xp, rp, n, m, !is_or); else si_xor_rows_bit(xp, rp, n, m, rtid==n_eq); decG(x); return r; } #endif return bi_N; }