diff --git a/src/builtins/slash.c b/src/builtins/slash.c index 45b73c12..49f2f8c9 100644 --- a/src/builtins/slash.c +++ b/src/builtins/slash.c @@ -1003,43 +1003,65 @@ B slash_ucw(B t, B o, B w, B x) { usz argIA = IA(arg); B rep = c1(o, arg); if (isAtm(rep) || RNK(rep)!=1 || IA(rep) != argIA) thrF("𝔽⌾(a⊸/)𝕩: 𝔽 must return an array with the same shape as its input (expected ⟨%s⟩, got %H)", argIA, rep); - MAKE_MUT_INIT(r, ia, el_or(TI(x,elType), TI(rep,elType))); - SGet(x) - SGet(rep) + u8 re = el_or(TI(x,elType), TI(rep,elType)); + MAKE_MUT_INIT(r, ia, re? re : 1); usz repI = 0; - if (TY(w) == t_bitarr) { + bool wb = TY(w) == t_bitarr; + if (wb && re!=el_B) { u64* d = bitarr_ptr(w); - if (elInt(TI(x,elType)) && elInt(TI(rep,elType))) { - if (r->fns->elType!=el_i32) mut_to(r, el_i32); - i32* rp = r->a; - x = toI32Any(x); i32* xp = i32any_ptr(x); - rep = toI32Any(rep); i32* np = i32any_ptr(rep); - for (usz i = 0; i < ia; i++) { - bool v = bitp_get(d, i); - i32 nc = np[repI]; - i32 xc = xp[i]; - rp[i] = v? nc : xc; - repI+= v; - } - } else { - MUTG_INIT(r); - for (usz i = 0; i < ia; i++) mut_setG(r, i, bitp_get(d, i)? Get(rep,repI++) : Get(x,i)); + void* rp = r->a; + + #define IMPL(T) do { \ + T* xp = tyany_ptr(x); \ + T* np = tyany_ptr(rep); \ + NOUNROLL for (usz i = 0; i < ia; i++) { \ + bool v = bitp_get(d, i); \ + ((T*)rp)[i] = *(v? np+repI : xp+i); \ + repI+= v; \ + } \ + goto dec_ret; \ + } while(0) + + #define WIDEN_GO(UT, T, B) x = to##UT##Any(x); rep = to##UT##Any(rep); goto bit_##B + + switch (re) { + case el_bit: + case el_i8: x = toI8Any(x); rep = toI8Any(rep); goto bit_u8; + case el_c8: x = toC8Any(x); rep = toC8Any(rep); goto bit_u8; + case el_i16: x = toI16Any(x); rep = toI16Any(rep); goto bit_u16; + case el_c16: x = toC16Any(x); rep = toC16Any(rep); goto bit_u16; + case el_i32: x = toI32Any(x); rep = toI32Any(rep); goto bit_u32; + case el_c32: x = toC32Any(x); rep = toC32Any(rep); goto bit_u32; + case el_f64: x = toF64Any(x); rep = toF64Any(rep); goto bit_u64; } + bit_u8: IMPL(u8); + bit_u16: IMPL(u16); + bit_u32: IMPL(u32); + bit_u64: IMPL(u64); + #undef IMPL + #undef WIDEN_GO } else { - SGetU(rep) - MUTG_INIT(r); - for (usz i = 0; i < ia; i++) { - i32 cw = o2iG(GetU(w, i)); - if (cw) { - B cr = Get(rep,repI); - if (CHECK_VALID) for (i32 j = 1; j < cw; j++) if (!equal(GetU(rep,repI+j), cr)) { mut_pfree(r,i); thrM("𝔽⌾(a⊸/): Incompatible result elements"); } - mut_setG(r, i, cr); - repI+= cw; - } else mut_setG(r, i, Get(x,i)); + SGet(x) SGet(rep) MUTG_INIT(r); + if (wb) { + u64* d = bitarr_ptr(w); + for (usz i = 0; i < ia; i++) mut_setG(r, i, bitp_get(d, i)? Get(rep,repI++) : Get(x,i)); + } else { + SGetU(rep) + for (usz i = 0; i < ia; i++) { + i32 cw = o2iG(GetU(w, i)); + if (cw) { + B cr = Get(rep,repI); + if (CHECK_VALID) for (i32 j = 1; j < cw; j++) if (!equal(GetU(rep,repI+j), cr)) { mut_pfree(r,i); thrM("𝔽⌾(a⊸/): Incompatible result elements"); } + mut_setG(r, i, cr); + repI+= cw; + } else mut_setG(r, i, Get(x,i)); + } } } + dec_ret:; decG(w); decG(rep); - return mut_fcd(r, x); + B rb = mut_fcd(r, x); + return re==0? taga(cpyBitArr(rb)) : rb; } void slash_init(void) {