From 96ca9092bad4fbd4632faf53ec283d4182ffa520 Mon Sep 17 00:00:00 2001 From: dzaima Date: Thu, 8 Sep 2022 00:37:44 +0300 Subject: [PATCH] attempt at better bit_sel dispatch --- src/builtins/arithd.c | 2 +- src/builtins/select.c | 2 +- src/core/stuff.h | 2 +- src/singeli/c/arithdDispatch.c | 4 +- src/utils/bits.c | 83 ++++++++++++++++++++++------------ 5 files changed, 59 insertions(+), 34 deletions(-) diff --git a/src/builtins/arithd.c b/src/builtins/arithd.c index 709a20bc..1aae1f2f 100644 --- a/src/builtins/arithd.c +++ b/src/builtins/arithd.c @@ -230,7 +230,7 @@ B floor_c2(B, B, B); bool h0=both || b0==0; if (h0) e0 = bitX? f(bi_N, inc(w), m_f64(0)) : f(bi_N, m_f64(0), inc(x)); bool h1=both || b0==1; if (h1) e1 = bitX? f(bi_N, w, m_f64(1)) : f(bi_N, m_f64(1), x); // non-bitarr arg has been consumed - B r = bit_sel(b, e0, h0, e1, h1); // and now the bitarr arg is consumed too + B r = bit_sel(b, e0, e1); // and now the bitarr arg is consumed too dec(e0); dec(e1); return r; } diff --git a/src/builtins/select.c b/src/builtins/select.c index f6f11984..5c3f48b6 100644 --- a/src/builtins/select.c +++ b/src/builtins/select.c @@ -136,7 +136,7 @@ B select_c2(B t, B w, B x) { } else { x1 = GetU(x,1); } - r = bit_sel(w, x0, true, x1, true); + r = bit_sel(w, x0, x1); decG(x); return withFill(r, xf); } diff --git a/src/core/stuff.h b/src/core/stuff.h index e8e8c9c1..d17685df 100644 --- a/src/core/stuff.h +++ b/src/core/stuff.h @@ -112,7 +112,7 @@ static bool eqShape(B w, B x) { assert(isArr(w)); assert(isArr(x)); return eqShPart(wsh, xsh, wr); } -B bit_sel(B b, B e0, bool h0, B e1, bool h1); // consumes b; h0/h1 represent whether the corresponding element _might_ be in the result (can be true if unknown) +B bit_sel(B b, B e0, B e1); // consumes b; b must be bitarr; b⊏e0‿e1 Arr* allZeroes(usz ia); Arr* allOnes(usz ia); B bit_negate(B x); // consumes diff --git a/src/singeli/c/arithdDispatch.c b/src/singeli/c/arithdDispatch.c index 651efd89..162ecec6 100644 --- a/src/singeli/c/arithdDispatch.c +++ b/src/singeli/c/arithdDispatch.c @@ -309,7 +309,7 @@ B dyArith_SA(DyTableSA* table, B w, B x) { bitsel: { B opts[2]; if (!table->ents[el_bit].bitsel(table, w, opts)) goto rec; - return bit_sel(x, opts[0], 1, opts[1], 1); + return bit_sel(x, opts[0], opts[1]); } } @@ -340,7 +340,7 @@ static NOINLINE B or_SA(B t, B w, B x) { if (LIKELY(TI(x,elType)==el_bit)) { bitsel: f64 wf = o2fG(w); - return bit_sel(x, m_f64(bqn_or(wf, 0)), 1, m_f64(bqn_or(wf, 1)), 1); + return bit_sel(x, m_f64(bqn_or(wf, 0)), m_f64(bqn_or(wf, 1))); } x = num_squeezeChk(x); if (TI(x,elType)==el_bit) goto bitsel; diff --git a/src/utils/bits.c b/src/utils/bits.c index e58b8d4e..2a0b2df7 100644 --- a/src/utils/bits.c +++ b/src/utils/bits.c @@ -3,44 +3,69 @@ NOINLINE Arr* allZeroes(usz ia) { u64* rp; Arr* r = m_bitarrp(&rp, ia); for (usz i = 0; i < BIT_N(ia); i++) rp[i] = 0; return r; } NOINLINE Arr* allOnes (usz ia) { u64* rp; Arr* r = m_bitarrp(&rp, ia); for (usz i = 0; i < BIT_N(ia); i++) rp[i] = ~0ULL; return r; } -NOINLINE B bit_sel(B b, B e0, bool h0, B e1, bool h1) { +NOINLINE B bit_sel(B b, B e0, B e1) { u8 t0 = selfElType(e0); - u8 t1 = selfElType(e1); - if (!h0) t0=t1; // TODO just do separate impls for !h0 and !h1 - if (!h1) t1=t0; u64* bp = bitarr_ptr(b); usz ia = IA(b); - if (elNum(t0) && elNum(t1)) { B r; - f64 f0 = o2fG(e0); i32 i0 = f0; - f64 f1 = o2fG(e1); i32 i1 = f1; - u8 tM = t0>t1? t0 : t1; - if (tM==el_bit) { - if (i0) { - if (i1) { Arr* a = allOnes(ia); arr_shCopy(a, b); r = taga(a); } - else return bit_negate(b); - } else { - if (i1) return b; - else { Arr* a = allZeroes(ia); arr_shCopy(a, b); r = taga(a); } + B r; + { + u8 type, width; + u32 e0i, e1i; + f64 e0f, e1f; + if (elNum(t0) && isF64(e1)) { + f64 f0 = o2fG(e0); + f64 f1 = o2fG(e1); + switch (t0) { default: UD; + case el_bit: if (f1==0||f1==1) goto t_bit; + case el_i8: if (q_fi8(f1)) goto t_i8; if (q_fi16(f1)) goto t_i16; if (q_fi32(f1)) goto t_i32; goto t_f64; // not using fallthrough to allow deduplicating float→int conversion + case el_i16: if (q_fi16(f1)) goto t_i16; if (q_fi32(f1)) goto t_i32; goto t_f64; + case el_i32: if (q_fi32(f1)) goto t_i32; goto t_f64; + case el_f64: goto t_f64; } + t_bit: + if (f0) { + if (f1) { Arr* a = allOnes(ia); arr_shCopy(a, b); r = taga(a); goto dec_ret; } + else return bit_negate(b); + } else { + if (f1) return b; + else { Arr* a = allZeroes(ia); arr_shCopy(a, b); r = taga(a); goto dec_ret; } + } + t_i8: type=t_i8arr; width=0; e0i=( u8)( i8)f0; e1i=( u8)( i8)f1; goto sel; + t_i16: type=t_i16arr; width=1; e0i=(u16)(i16)f0; e1i=(u16)(i16)f1; goto sel; + t_i32: type=t_i32arr; width=2; e0i=(u32)(i32)f0; e1i=(u32)(i32)f1; goto sel; + t_f64: type=t_f64arr; width=3; e0f= f0; e1f= f1; goto sel; + + } else if (elChr(t0) && isC32(e1)) { + u32 u0 = o2cG(e0); u32 u1 = o2cG(e1); + switch(t0) { default: UD; + case el_c8: if (u1==( u8)u1) { type=t_c8arr; width=0; e0i=u0; e1i=u1; goto sel; } // else fallthrough + case el_c16: if (u1==(u16)u1) { type=t_c16arr; width=1; e0i=u0; e1i=u1; goto sel; } // else fallthrough + case el_c32: { type=t_c32arr; width=2; e0i=u0; e1i=u1; goto sel; } + } + } else goto slow; + + sel: + void* rp = m_tyarrlc(&r, width, b, type); + switch(width) { + case 0: for (usz i=0; i