From 04d0191d1f42193f421b90caa3bef24c8e3d96c3 Mon Sep 17 00:00:00 2001 From: dzaima Date: Wed, 21 May 2025 02:09:11 +0300 Subject: [PATCH] =?UTF-8?q?handle=20v=C2=A8=E2=8C=BE(l=E2=8A=B8/)x=20with?= =?UTF-8?q?=20non-boolean=20l=20with=20fast=20path?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/builtins/slash.c | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/builtins/slash.c b/src/builtins/slash.c index 4430b6d6..2e35a6ea 100644 --- a/src/builtins/slash.c +++ b/src/builtins/slash.c @@ -1061,6 +1061,7 @@ AnyArr m_anyarrc(u8 re, B x) { // consumes x; returns new array with given eleme return (AnyArr) {r, rp}; } +B ne_c2(B,B,B); B slash_ucw(B t, B o, B w, B x) { if (isAtm(w) || isAtm(x) || RNK(w)!=1 || RNK(x)!=1 || IA(w)!=IA(x)) { base: @@ -1071,13 +1072,19 @@ B slash_ucw(B t, B o, B w, B x) { u8 we = TI(w,elType); if (we != el_bit) { w = squeeze_numTry(w, &we); - if (!elInt(we)) goto base; + if (we != el_bit) { + if (!elNum(we)) goto base; + i64 bounds[2]; + if (!getRange_fns[we](tyany_ptr(w), bounds, IA(w)) || bounds[0]<0) { + usum(w); + thrOOM(); + } + } } // (c ; C˙)¨⌾(w⊸/) x #if SINGELI_SIMD if (isFun(o) && TY(o)==t_md1D) { - if (we != el_bit) goto notConstEach; // TODO could do w↩w≠0 after range checking u8 xe = TI(x,elType); if (xe==el_B) goto notConstEach; @@ -1087,6 +1094,13 @@ B slash_ucw(B t, B o, B w, B x) { B c; if (!toConstant(f, &c)) goto notConstEach; + if (we != el_bit) { + // relies + w = C2(ne, w, m_f64(0)); + assert(TI(w,elType)==el_bit); + we = el_bit; + } + u8 ce = selfElType(c); u8 re = el_or(ce,xe); // can be el_B @@ -1132,7 +1146,7 @@ B slash_ucw(B t, B o, B w, B x) { if (isAtm(rep) || RNK(rep)!=1 || IA(rep) != argIA) thrF("𝔽⌾(a⊸/)𝕩: 𝔽 must return an array with the same shape as its input (expected ⟨%s⟩, got %H)", argIA, rep); u8 re = el_or(TI(x,elType), TI(rep,elType)); MAKE_MUT_INIT(r, ia, re? re : 1); - usz repI = 0; + ux repI = 0; if (we==el_bit && re!=el_B) { u64* d = bitany_ptr(w); void* rp = r->a; @@ -1176,10 +1190,10 @@ B slash_ucw(B t, B o, B w, B x) { } SGetU(rep) for (usz i = 0; i < ia; i++) { - i32 cw = o2iG(GetU(w, i)); + ux cw = o2u64G(GetU(w, i)); if (cw) { B cr = Get(rep,repI); - if (CHECK_VALID) for (i32 j = 1; j < cw; j++) if (!compatible(GetU(rep,repI+j), cr)) { mut_pfree(r,i); thrM("𝔽⌾(a⊸/): Incompatible result elements"); } + if (CHECK_VALID) for (ux j = 1; j < cw; j++) if (!compatible(GetU(rep,repI+j), cr)) { mut_pfree(r,i); thrM("𝔽⌾(a⊸/): Incompatible result elements"); } mut_setG(r, i, cr); repI+= cw; } else mut_setG(r, i, Get(x,i));