From 41785cb4bf319f7a5c8e49dbe3a42d241881e218 Mon Sep 17 00:00:00 2001 From: dzaima Date: Tue, 10 Sep 2024 01:45:56 +0300 Subject: [PATCH] =?UTF-8?q?Singeli=20const=C2=A8=E2=8C=BE(m=E2=8A=B8/)b?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/builtins/slash.c | 95 ++++++++++++++++++++++++++++++++-- src/singeli/c/arithdDispatch.c | 2 + src/singeli/src/bits.singeli | 15 ++++++ src/utils/bits.c | 10 ++++ test/cases/under.bqn | 11 ++++ 5 files changed, 129 insertions(+), 4 deletions(-) diff --git a/src/builtins/slash.c b/src/builtins/slash.c index 7f6b42a8..914580e2 100644 --- a/src/builtins/slash.c +++ b/src/builtins/slash.c @@ -932,6 +932,40 @@ B slash_im(B t, B x) { decG(x); return r; } +#if SINGELI_SIMD + typedef void (*BlendArrScalarFn)(void* r, void* zero, u64 one, void* mask, u64 n); + extern INIT_GLOBAL BlendArrScalarFn* blendArrScalarFns; +#endif + +typedef struct { B b; void* data; } AnyArr; +AnyArr cpyElTypeArr(u8 re, B x) { // consumes x; returns new array with the given element type, with data & shape from x + Arr* a; + switch (re) { default: UD; + case el_bit: a = cpyBitArr(x); goto tyarr; + case el_i8: a = cpyI8Arr(x); goto tyarr; + case el_i16: a = cpyI16Arr(x); goto tyarr; + case el_i32: a = cpyI32Arr(x); goto tyarr; + case el_f64: a = cpyF64Arr(x); goto tyarr; + case el_c8: a = cpyC8Arr(x); goto tyarr; + case el_c16: a = cpyC16Arr(x); goto tyarr; + case el_c32: a = cpyC32Arr(x); goto tyarr; + case el_B: + a = cpyHArr(x); + return (AnyArr){taga(a), harrv_ptr(a)}; + } + tyarr: + return (AnyArr){taga(a), tyarrv_ptr((TyArr*)a)}; +} +AnyArr m_anyarrc(u8 re, B x) { // consumes x; returns new array with given element type; may start NOGC + if (re==el_B) { + HArr_p r = m_harrUc(x); + return (AnyArr) {r.b, r.a}; + } + B r; + void* rp = m_tyarrlbc(&r, elwBitLog(re), x, el2t(re)); + return (AnyArr) {r, rp}; +} + B slash_ucw(B t, B o, B w, B x) { if (isAtm(w) || isAtm(x) || RNK(w)!=1 || RNK(x)!=1 || IA(w)!=IA(x)) { base: @@ -939,10 +973,63 @@ B slash_ucw(B t, B o, B w, B x) { } usz ia = IA(x); SGetU(w) - if (TY(w) != t_bitarr) { + if (TI(w,elType) != el_bit) { w = num_squeezeChk(w); if (!elInt(TI(w,elType))) goto base; } + + // (c ; C˙)¨⌾(w⊸/) x + #if SINGELI_SIMD + if (isFun(o) && TY(o)==t_md1D) { + if (TI(w,elType) != el_bit) goto notConstEach; // TODO could do w↩w≠0 after range checking + u8 xe = TI(x,elType); + if (xe==el_B) goto notConstEach; + + Md1D* od = c(Md1D,o); + if (od->m1->flags-1 != n_each) goto notConstEach; + B f = od->f; + B c; + if (!toConstant(f, &c)) goto notConstEach; + + u8 ce = selfElType(c); + u8 re = el_or(ce,xe); // can be el_B + + B r; + if (isVal(c)) { + u64 sum = usum(w); + if (sum==0) decG(c); // TODO could return x; fills? + else incByG(c, sum-1); + } + + void* rp; + void* xp; + if (re != xe) { + AnyArr a = cpyElTypeArr(re, x); + x = r = incG(a.b); + xp = rp = a.data; + } else { + AnyArr a = m_anyarrc(re, x); + r = a.b; + rp = a.data; + xp = tyany_ptr(x); + } + + u64 cv; + if (elInt(re)) cv = o2iG(c); + else if (elChr(re)) cv = o2cG(c); + else { + assert(re==el_B || re==el_f64); + cv = re==el_f64? r_f64u(o2fG(c)) : r_Bu(c); + } + + blendArrScalarFns[elwBitLog(re)](rp, xp, cv, bitany_ptr(w), IA(x)); + NOGC_E; + decG(w); decG(x); + return r; + } + #endif + notConstEach:; + B arg = C2(slash, incG(w), incG(x)); usz argIA = IA(arg); B rep = c1(o, arg); @@ -950,7 +1037,7 @@ B slash_ucw(B t, B o, B w, B x) { u8 re = el_or(TI(x,elType), TI(rep,elType)); MAKE_MUT_INIT(r, ia, re? re : 1); usz repI = 0; - bool wb = TY(w) == t_bitarr; + bool wb = TI(w,elType) == el_bit; if (wb && re!=el_B) { u64* d = bitany_ptr(w); void* rp = r->a; @@ -973,7 +1060,7 @@ B slash_ucw(B t, B o, B w, B x) { ((T*)rp)[i] = *(v? np+repI : xp+i); \ repI+= v; \ } \ - goto dec_ret; \ + goto mut_dec_ret; \ } while(0) bit_u8: IMPL(u8); @@ -1000,7 +1087,7 @@ B slash_ucw(B t, B o, B w, B x) { } } } - dec_ret:; + mut_dec_ret:; decG(w); decG(rep); B rb = mut_fcd(r, x); return re==0? taga(cpyBitArr(rb)) : rb; diff --git a/src/singeli/c/arithdDispatch.c b/src/singeli/c/arithdDispatch.c index ea4a9f6b..7b255339 100644 --- a/src/singeli/c/arithdDispatch.c +++ b/src/singeli/c/arithdDispatch.c @@ -301,6 +301,8 @@ static NOINLINE B or_SA(B t, B w, B x) { return r; } +extern void (*const orAAu_bit_bit_bit)(void*,void*,void*,u64); // used in bits.c + #define SINGELI_FILE arTables #include "../../utils/includeSingeli.h" diff --git a/src/singeli/src/bits.singeli b/src/singeli/src/bits.singeli index f5be1e92..9234657c 100644 --- a/src/singeli/src/bits.singeli +++ b/src/singeli/src/bits.singeli @@ -22,6 +22,21 @@ def table{w} = each{bitsel_i{w, .}, tup{u8, u16, u32, u64}} exportT{'simd_bitsel', table{arch_defvw}} +fn blend_arr_scalar{E}(rp:*void, zero:*void, one0:u64, mask:*void, len:u64) : void = { + if (same{E,'!'}) { + fatal{'bad blend'} + } else if (E==u1) { + emit{void, 'blendArrScalarBits', rp, zero, cast_i{E, one0}, mask, len} + } else { + def bulk = arch_defvw / width{E} + def VT = [bulk]E + def one = VT**cast_i{E, one0} + @maskedLoop{bulk}(r in tup{VT,*E~~rp}, zero in tup{VT,*E~~zero}, mask in tup{'b',VT,mask} over i to len) r = homBlend{zero, one, mask} + } +} + +exportT{'si_blend_arr_scalar', each{blend_arr_scalar, tup{u1, '!', '!', u8, u16, u32, u64}}} + (if (has_sel) { fn bitwiden_n_8(src:*void, dst:*void, csz:ux, cam:ux) : void = { assert{cam>0} diff --git a/src/utils/bits.c b/src/utils/bits.c index 03e011ef..0d482ca9 100644 --- a/src/utils/bits.c +++ b/src/utils/bits.c @@ -4,8 +4,18 @@ #if SINGELI_SIMD + extern void (*const orAAu_bit_bit_bit)(void*,void*,void*,u64); + static void blendArrScalarBits(void* r, void* zero, bool one, void* mask, u64 n) { + if (one) orAAu_bit_bit_bit(r, zero, mask, n); + else CMP_AA_CALL(CMP_AA_FN(gt, el_bit), r, zero, mask, n); + } + #define SINGELI_FILE bits #include "../utils/includeSingeli.h" + + typedef void (*BlendArrScalarFn)(void* r, void* zero, u64 one, void* mask, u64 n); + INIT_GLOBAL BlendArrScalarFn* blendArrScalarFns = si_blend_arr_scalar; + INIT_GLOBAL BitSelFn* bitselFns = simd_bitsel; #else #define BITSEL_DEF(E) void bitsel_##E(void* rp, u64* bp, u64 e0i, u64 e1i, u64 ia) { for (usz i=0; i %% +‿÷‿×‿=‿+‿+ +!"/: Lengths of components of 𝕨 must match 𝕩 (6 ≠ 7)" % 0¨⌾(1‿0‿0‿0‿1‿1⊸/) 0‿1‿0‿1‿0‿1‿0 +4↑ (⋈⋈3)⌾(0‿1‿0⊸/) ↕⋈3 %% ⟨⋈0,⋈3,⋈2,0⟩ +4↑ ⊢⌾(0‿1‿0⊸/) ↕⋈3 %% ⟨⋈0,⋈1,⋈2,0⟩ +4↑ ⊢⌾(0‿0‿0⊸/) ↕⋈3 %% ⟨⋈0,⋈1,⋈2,0⟩ + # < ⊢⌾< 4 %% 4 (<5)⌾< 4 %% 5