uCBQN/src/builtins/select.c
2025-07-04 21:01:17 +03:00

1329 lines
45 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// First Cell and Select (⊏)
// First Cell is just a slice
// Select - 𝕨⊏𝕩
// Complications in Select mostly come from range checks and negative 𝕨
// Atom or enclosed atom 𝕨 and rank-1 𝕩: make new array
// Atom or enclosed atom 𝕨 and high-rank 𝕩: slice
// Empty 𝕨: no selection
// Float or generic 𝕨: attempt to squeeze, go generic cell size path if stays float
// High-rank 𝕩 & boolean 𝕨: either widens 𝕨 to i8, or goes generic cell path
// SHOULD go a bit select path for small cells
// SHOULD reshape for 1=≠𝕩
// Boolean 𝕩 (cell size = 1 bit):
// 𝕨 larger than 𝕩: convert 𝕩 to i8, select, convert back
// Otherwise: select/shift bytes, reversed for fast writing
// TRIED pext, doesn't seem faster (mask built with shifts anyway)
// 𝕩 with cell sizes of 1, 2, 4, or 8 bytes:
// Small 𝕩 and i8 𝕨 with Singeli: use shuffles
// COULD try to squeeze 𝕨 to i8 if small enough 𝕩
// Boolean 𝕨: use bit_sel for blend or similar
// Integer 𝕨 with Singeli: fused wrap, range-check, and gather
// COULD try selecting from boolean with gather
// COULD detect <Skylake where gather is slow
// i32 𝕨: wrap, check, select one index at a time
// i8 and i16 𝕨: separate range check in blocks to vectorize it
// COULD wrap 𝕨 to a temp buffer
// COULD copy 𝕩 to a buffer indexable directly by positive and negative indices
// Generic cell size 𝕩:
// Computes a function that copies the necessary amount of bytes/bits
// Specializes over i8/i16/i32 𝕨
// Nested 𝕨:
// Recognizes a trailing element of a+↕b
// Converts remaining indices to single select indices via +⌜
// COULD have specialized select that skips OOB/negative checks
// Under Select - F⌾(i⊸⊏) 𝕩
// Specialized for rank-1 numeric 𝕩
// SHOULD apply to characters as well
// No longer needs to range-check but indices can be negative
// COULD convert negative indices before selection
// Must check collisions if CHECK_VALID; uses a byte set
// Sparse initialization if 𝕨 is much smaller than 𝕩
// COULD call Mark Firsts (∊) for very short 𝕨 to avoid allocation
// Select Cells - inds⊸⊏⎉1 𝕩
// Squeeze indices if too wide for given 𝕩
// Single index: (also used for monadic ⊏˘ ⊣˝˘ ⊢˝˘)
// Selecting a column of bits:
// Row size <64: extract as with fold-cells
// Selecting a column of 1, 2, 4, or 8-byte elements:
// Short cells: pack vectors from 𝕩, or blend and permute
// Long cells: dedicated scalar loop for each type
// Otherwise, loop with mutable copy
// Boolean indices:
// Short inds and short cells: Widen to i8
// Otherwise: bitsel call per cell
// SHOULD specialize wider input/output:
// AVX2 extract 16 bits from 128: vbroadcasti128+vpshufb+vpmullw
// AVX-512 extract 64 bits from 64: vpmultishiftqb/vpshufbitqmb
// potentially better options via transposing to allow a shuffle to reorder multiple rows
// 1, 2, 4 or 8-byte data elements & short cells & short index list:
// Split indices to available native shuffle width (e.g. 2‿1⊸⊏˘ n‿5⥊i16 → 2‿3‿0‿1⊸⊏˘ n‿10⥊i8)
// Repeat indices if using ≤0.5x of shuffle width (e.g. 0‿0‿2⊸⊏˘ n‿3⥊i8 → 0‿0‿2‿3‿3‿5⊸⊏˘ n‿6⥊i8)
// SHOULD disregard actual cell width if index range is small
// COULD merge to wider elements if indices are in runs (e.g. 0‿1‿6‿7⊸⊏˘ n‿10⥊i16 → 0‿3⊸⊏˘ n‿5⥊i32)
// COULD split into multiple index blocks
// Long inds / long cells:
// Direct call to select function per cell
// COULD have a more direct call that avoids overflow checking & wrapping
// COULD generate full list of indices via arith
// 1-element cells: use (≠inds)/⥊x after checking ∧´inds∊0‿¯1
// Used for ⌽⎉1
// SHOULD use for /⎉k, ⌽⎉k, ↑⎉k, ↓⎉k, ↕⎉k, ⍉⎉k, probably more
#include "../core.h"
#include "../utils/talloc.h"
#include "../utils/mut.h"
#include "../utils/calls.h"
#include "../builtins.h"
#if SINGELI
// Pulls in the Singeli-generated kernels for this file (SINGELI_FILE names which set to include)
#define SINGELI_FILE select
#include "../utils/includeSingeli.h"
// Type of the entries of si_select_tab. In this file they are called as (indices, source, result, index count, source length)
// and a false result is treated as "some index was invalid" (callers then produce the proper error).
typedef bool (*SimdSelectFn)(void* w0, void* x0, void* r0, u64 wl, u64 xl);
// Looks up the select kernel for index element type WE (el_i8..el_i32) and log2 cell width in bits XL (3..6);
// the table is laid out as 4 entries per index type. Arguments are evaluated once.
#define SIMD_SELECT(WE, XL) ({ AUTO we_=(WE); AUTO xl_=(XL); assert(we_>=el_i8 && we_<=el_i32 && xl_>=3 && xl_<=6); si_select_tab[4*(we_-el_i8)+xl_-3]; })
#endif
// Cell-copy function: copies one run of data from x (at offset xs) to r (at offset rs).
// Offsets are in bytes for the cf_* implementations and in bits for the cfb_* ones; `data` is CFRes.data.
typedef void (*CFn)(void* r, ux rs, void* x, ux xs, ux data);
// A copy function chosen by cf_get, together with what is needed to call it
typedef struct {
CFn fn; // the copy function
ux data; // passed through as the last argument of fn (a size or second-copy offset, depending on fn)
ux mul; // size of one cell in the unit fn's offsets use (bytes or bits); callers scale cell indices by this
} CFRes;
static void cf_0(void* r, ux rs, void* x, ux xs, ux d) { }
static void cf_1(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 1); }
static void cf_2(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 2); }
static void cf_3(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 3); }
static void cf_4(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 4); }
static CFn const cfs_0_4[] = {cf_0, cf_1, cf_2, cf_3, cf_4};
static void cf_8(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 8); }
static void cf_16(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 16); }
static void cf_5_7 (void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 4); memcpy(r+d, x+d, 4); }
static void cf_9_16 (void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 8); memcpy(r+d, x+d, 8); }
static void cf_17_24(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 16); memcpy(r+d, x+d, 8); }
static void cf_25_32(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 24); memcpy(r+d, x+d, 8); }
static void cf_arb(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, d); }
static void cfb_1(void* r, ux rs, void* x, ux xs, ux d) { bitp_set(r, rs, bitp_get(x, xs)); }
static void cfb_arb(void* r, ux rs, void* x, ux xs, ux d) { bit_cpy(r, rs, x, xs, d); }
// Chooses a copy function for runs of `count` cells of `cszBits` bits each.
// If cszBits is a whole number of bytes, the result works in byte offsets (mul = bytes per cell);
// otherwise it works in bit offsets (mul = bits per cell).
NOINLINE CFRes cf_get(usz count, usz cszBits) {
  if ((cszBits & 7) != 0) {
    // not byte-aligned: bit-granular copy
    ux totalBits = count * (ux)cszBits;
    if (totalBits == 1) return (CFRes){.mul=cszBits, .fn = cfb_1};
    return (CFRes){.mul=cszBits, .fn = cfb_arb, .data=totalBits};
  }

  ux cellBytes = cszBits / 8;
  ux total = cellBytes * (ux)count;
  CFRes res = {.mul = cellBytes}; // fn is always assigned below; data stays 0 unless the chosen function needs it

  if (total < 5) {
    res.fn = cfs_0_4[total];
  } else if (total < 8) {
    res.fn = cf_5_7;
    res.data = total - 4;
  } else if (total == 8) {
    res.fn = cf_8;
  } else if (total == 16) {
    res.fn = cf_16;
  } else if (total <= 16) {
    res.fn = cf_9_16;
    res.data = total - 8;
  } else if (total <= 24) {
    res.fn = cf_17_24;
    res.data = total - 8;
  } else if (total <= 32) {
    res.fn = cf_25_32;
    res.data = total - 8;
  } else {
    res.fn = cf_arb;
    res.data = total;
  }
  return res;
}
// Invokes a copy function obtained from cf_get; rs and xs are offsets in the unit given by f.mul
FORCE_INLINE void cf_call(CFRes f, void* r, ux rs, void* x, ux xs) {
  CFn copy = f.fn;
  copy(r, rs, x, xs, f.data);
}
extern GLOBAL B rt_select;
// First Cell (⊏𝕩): returns the first major cell of x as a slice. Consumes x.
// Errors if x is an atom, has rank 0, or has a leading axis of length 0.
B select_c1(B t, B x) {
  if (isAtm(x)) thrM("⊏𝕩: 𝕩 cannot be an atom");
  ur xr = RNK(x);
  if (xr==0) thrM("⊏𝕩: 𝕩 cannot be rank 0");
  usz* xsh = SH(x);
  if (xsh[0]==0) thrF("⊏𝕩: 𝕩 shape cannot start with 0 (%H ≡ ≢𝕩)", x);

  ur cellRank = xr-1;
  usz cellIA = shProd(xsh, 1, xr); // element count of one major cell
  Arr* cell = TI(x,slice)(incG(x), 0, cellIA);
  usz* cellSh = arr_shAlloc(cell, cellRank);
  if (cellSh != NULL) shcpy(cellSh, xsh+1, cellRank);
  decG(x);
  return taga(cell);
}
B select_c2(B t, B w, B x);
// Called when a fast path has detected an invalid index: re-runs Select on a generic (HArr) copy of x,
// which is expected to throw; returning from that call is treated as a fatal internal error.
static NOINLINE NORETURN void select_properError(B w, B x) {
  B xGeneric = taga(cpyHArr(x));
  select_c2(w, w, xGeneric);
  fatal("select_properError");
}
// Picks element wi of array x and wraps it in a unit array carrying x's fill. Consumes x.
static NOINLINE B select_list_cell(usz wi, B x) { // guarantees returning new array
  assert(isArr(x));
  B fill = getFillR(x);
  B el = IGet(x, wi);
  B res;
  bool simpleFill = numFill(fill) || chrFill(fill);
  if (simpleFill) {
    res = m_unit(el);
  } else if (noFill(fill)) {
    res = m_hunit(el);
  } else {
    // other fill: build a one-element fill array with unit shape by hand
    Arr* unit = m_fillarrp(1);
    arr_shAtm(unit);
    fillarrv_ptr(unit)[0] = el;
    fillarr_setFill(unit, fill);
    NOGC_E;
    res = taga(unit);
  }
  decG(x);
  return res;
}
// Error reporting for an invalid compound 𝕨: throws one of the more specific messages below if it applies,
// and otherwise returns so the caller can throw its own message.
static NOINLINE void select_depth2_bad(B w, B x) {
  usz wia = IA(w);
  if (wia>0 && IA(x)==0) {
    // empty 𝕩 with a numeric 𝕨: report as an ordinary out-of-bounds index
    u8 we;
    w = squeeze_numTry(w, &we, SQ_NUM);
    if (elNum(we)) {
      thrF("𝕨⊏𝕩: Indexing out-of-bounds (%B∊𝕨, %s≡≠𝕩)", IGetU(w,0), *SH(x));
    }
  }
  if (RNK(w) > 1) thrF("𝕨⊏𝕩: Compound 𝕨 must have rank at most 1 (%H ≡ ≢𝕨)", w);

  SGetU(w)
  if (depth(w)==1) {
    // flat 𝕨: any non-number atom is the culprit
    for (ux i = 0; i < wia; i++) {
      B el = GetU(w,i);
      if (isAtm(el) && !isNum(el)) thrF("𝕨⊏𝕩: 𝕨 must be an array of numbers or list of such (𝕨 contained %S)", genericDesc(el));
    }
  } else {
    // nested 𝕨: any atom at the top level is a mix of arrays and atoms
    for (ux i = 0; i < wia; i++) {
      B el = GetU(w,i);
      if (isAtm(el)) thrF("𝕨⊏𝕩: 𝕨 must be an array of numbers or list of such (𝕨 contained both an array and %S)", genericDesc(el));
    }
  }
}
// Given an index array cw for the given axis of x that is known to contain a bad index,
// finds the first offending element and throws the matching error (non-integer or out-of-bounds).
static NOINLINE NORETURN void select_depth2_bad_inds(B cw, ux axis, B x) {
  assert(axis < RNK(x) && isArr(cw));
  SGetU(cw)
  ux axisLen = SH(x)[axis];
  usz n = IA(cw);
  for (ux pos = 0; pos < n; pos++) {
    f64 idx = o2fG(GetU(cw,pos));
    if (!q_fi64(idx)) thrF("𝕨⊏𝕩: Bad index: %f along axis %z", idx, axis);
    WRAP(idx, axisLen, thrF("𝕨⊏𝕩: Indexing out-of-bounds along axis %z (%f ∊ %z⊑𝕨, %H≡≢𝕩)", axis, idx, axis, x));
  }
  fatal("select_depth2_bad_inds should've errored");
}
B add_c2(B,B,B);
B mul_c2(B,B,B);
B lt_c2(B,B,B);
// Result of select_depth2_parse_inds: the validated per-axis index arrays of a compound 𝕨, plus sizes derived from them
typedef struct {
B inds; // array (rank≤1) of elNum arrays
ux prod; // ×´⥊≠∘⥊¨inds
ux left; // (=x) - ≠𝕨
ux rank; // =w⊏x
bool lastMaybeRange; // whether ¯1⊑r may be an a+↕b
} Depth2Inds;
// Validates a compound (list-of-arrays) 𝕨 against x and normalizes it: every element is squeezed to a numeric
// array, range-checked against the corresponding axis of x, and has negative indices wrapped to non-negative.
// On any problem, select_depth2_bad / select_depth2_bad_inds are used to throw the most specific message.
Depth2Inds select_depth2_parse_inds(B w, B x) { // consumes w; checks that w of w⊏x is valid, and, if so, returns a depth-2 array of number arrays (i.e. `squeeze_numTry¨ ((≠w)↑≢x) | w`)
assert(isArr(w) && isArr(x));
usz wia = IA(w);
assert(wia > 0);
if (RNK(w) > 1) { select_depth2_bad(w,x); fatal("should've errored"); }
SGetU(w)
if (wia > RNK(x)) { select_depth2_bad(w,x); thrF("𝕨⊏𝕩: Compound 𝕨 must not be longer than 𝕩 (%s ≡ ≠𝕨, %H ≡ ≢𝕩)", wia, x); }
Depth2Inds r;
HArr_p inds = m_harr0v(wia);
r.inds = inds.b;
r.prod = 1;
// result rank starts as the number of axes of x not indexed by 𝕨; ranks of the index arrays are added in the loop
r.rank = r.left = RNK(x) - wia;
usz* xsh = SH(x);
i64 bounds[2];
ux lastIA;
for (ux i = 0; i < wia; i++) {
B c = GetU(w, i);
if (!isArr(c)) { select_depth2_bad(w,x); thrF("𝕨⊏𝕩: Elements of compound 𝕨 must be arrays (encountered %S)", genericDesc(c)); }
r.rank+= RNK(c);
u8 ce;
c = squeeze_numTry(incG(c), &ce, SQ_NUM);
lastIA = IA(c);
r.prod*= lastIA;
// empty index arrays are accepted without any element-type or range check
if (lastIA>0) {
if (!elNum(ce)) { select_depth2_bad(w,x); thrM("𝕨⊏𝕩: Elements of compound 𝕨 must be arrays of numbers"); }
// bounds receives the minimum and maximum index; valid range along axis i is [-xsh[i], xsh[i])
if (!getRange_fns[ce](tyany_ptr(c), bounds, lastIA) || bounds[0] < -(i64)xsh[i] || bounds[1] >= xsh[i]) {
select_depth2_bad(w,x);
select_depth2_bad_inds(c, i, x);
}
// wrap negative indices: c + xsh[i]×(c<0)
if (bounds[0] < 0) {
c = C2(add, c, C2(mul, m_f64(xsh[i]), C2(lt, incG(c), m_f64(0))));
}
}
inds.a[i] = c;
}
decG(w);
if (r.rank > UR_MAX) thrM("𝕨⊏𝕩: Result rank too large");
// bounds here still describes the last index array before wrapping: it may be a range if it had negatives
// (nothing known after wrapping) or if max-min+1 equals its length; bounds is only read when lastIA!=0
r.lastMaybeRange = lastIA!=0 && (bounds[0]<0 || bounds[1]+1-bounds[0] == lastIA);
return r;
}
B tbl_c2(Md1D* d, B w, B x);
B mul_c2(B, B, B);
B ud_c1(B, B);
B shape_c1(B, B);
// Result of select_depth2_inds: describes the selection as contiguous runs of `span` elements of ⥊x,
// run i starting at element add + mul×(i⊑⥊starts)
typedef struct {
B starts; // unspecified shape
ux span, mul, add;
} Spans;
// Combines the per-axis index arrays in w into a single array of run starts into ⥊x.
// If the last index array turns out to be a consecutive range l0+↕lia, it is folded into span/add instead of
// being multiplied out, so each run covers lia times as many elements.
Spans select_depth2_inds(B w, B x, bool lastMaybeRange) { // doesn't consume; assumes w is a result of select_depth2_parse_inds; (⥊w⊏x) ≡ ⥊((mul×⥊starts)+⌜add+↕span)⊏x
ur xr = RNK(x);
usz wia = IA(w);
assert(xr>0 && wia>=2 && wia<=xr);
usz* xsh = SH(x);
ux add = 0;
// element count of the trailing axes of x that 𝕨 doesn't index
ux span = shProd(xsh, wia, xr);
ux mul = span;
if (MAY_T(lastMaybeRange) && wia>=2) {
B l = IGetU(w, wia-1);
usz lia = IA(l);
assert(lia!=0);
if (HEURISTIC(lia <= 3)) goto lastNotRange; // TODO improve heuristic
SGetU(l)
usz l0 = o2sG(GetU(l,0));
for (ux i = 1; i < lia; i++) if (o2sG(GetU(l,i)) != l0+i) goto lastNotRange; // TODO do in a less bad way
// last index array is l0+↕lia: absorb it into the run description and drop it from the loop below
add = l0*mul;
span*= lia;
mul*= xsh[wia-1];
wia--;
}
lastNotRange:;
SGet(w)
B c = Get(w, 0);
// row-major index combination: c ← (c×xsh[i]) +⌜ i⊑w for each remaining axis
// (inside M1C2 the name `add` is a macro argument selecting the builtin, not the local variable above — NOTE(review): relies on M1C2 token-pasting its arguments; confirm)
for (ux i = 1; i < wia; i++) {
c = C2(mul, c, m_f64(xsh[i]));
c = M1C2(tbl, add, c, Get(w, i));
}
return (Spans) {c, span, mul, add};
}
// WRAP with the standard single-index out-of-bounds error of 𝕨⊏𝕩; FMT/ARG format the offending index, LEN is reported as ≠𝕩
#define WRAP_SELECT_ONE(VAL, LEN, FMT, ARG) WRAP(VAL, LEN, thrF("𝕨⊏𝕩: Indexing out-of-bounds (" FMT "∊𝕨, %s≡≠𝕩)", ARG, LEN))
// Performs the selection described by ws on x and returns the result elements (shape is set by the caller).
// For span≠1 each run is copied directly; for span==1 the starts are turned into plain indices and passed to Select on ⥊x.
static NOINLINE B select_depth2_select(Spans ws, B x) { // consumes ws.starts
B wst = ws.starts;
ux span = ws.span;
B r0;
if (span != 1) {
u8 xe = TI(x,elType);
ux sia = IA(wst);
UntaggedArr r = m_arrp_copyFill(x, sia*span);
SGetU(wst)
ux ro = 0;
if (xe != el_B) {
// typed x: copy raw element data with a copy function sized for `span` elements
void* xp = tyany_ptr(x);
u8 ewb = elwBitLog(xe);
CFRes f = cf_get(span, 1<<ewb);
for (ux i = 0; i < sia; i++) {
ux xi = ws.add + ws.mul*o2sG(GetU(wst,i));
// offsets are in f.mul units (bytes or bits per element)
cf_call(f, r.data, ro, xp, xi*f.mul);
ro+= span*f.mul;
}
} else {
// generic x: element-wise copy through COPY_TO
for (ux i = 0; i < sia; i++) {
ux xi = ws.add + ws.mul*o2sG(GetU(wst,i));
COPY_TO(r.data, xe, ro, x, xi, span);
ro+= span;
}
NOGC_E;
}
decG(wst);
r0 = taga(r.obj);
} else {
// single-element runs: index = add + mul×starts, then an ordinary select from the deshaped x
if (ws.mul != 1) wst = C2(mul, wst, m_f64(ws.mul));
// if (span != 1) wst = M1C2(tbl, add, wst, C1(ud, m_f64(span)));
if (ws.add != 0) wst = C2(add, wst, m_f64(ws.add));
r0 = C2(select, wst, C1(shape, incG(x)));
}
assert(reusable(r0));
return r0;
}
// Select with a compound (depth-2) 𝕨, and the error path for every 𝕨 that isn't a valid numeric array.
// Consumes w and x. wia is the element count of w as seen by the caller.
// Result shape is the shapes of the index arrays joined together, followed by the unindexed trailing axes of x.
static B select_depth2_impl(ux wia, B w, B x) { // wia<=1 only if invalid; or if x is empty, w may also be a number list
  Depth2Inds wi = select_depth2_parse_inds(w, x); // consumes w, throws on invalid input
  // invalid cases thrown out above; check the count passed in rather than IA(w), as w may already be freed here
  assert(wia>=2);
  w = wi.inds;
  ux rr = wi.rank;
  B r0;
  if (wi.prod==0) {
    // some index array is empty, so the result has no elements
    r0 = taga(emptyArr(x, rr));
  } else {
    assert(IA(x)!=0);
    Spans ws = select_depth2_inds(w, x, wi.lastMaybeRange);
    r0 = select_depth2_select(ws, x);
  }

  if (rr >= 2) {
    ShArr* rsh = m_shArr(rr);
    arr_shReplace(a(r0), rr, rsh);
    usz* rshc = rsh->a;
    SGetU(w)
    // shapes of all index arrays, in order
    for (ux i = 0; i < wia; i++) {
      B wc = GetU(w, i);
      ur wcr = RNK(wc);
      shcpy(rshc, SH(wc), wcr);
      rshc+= wcr;
    }
    // then the axes of x that 𝕨 didn't index
    shcpy(rshc, SH(x)+wia, wi.left);
    assert(rshc+wi.left == rsh->a+rr);
  } else {
    arr_shErase(a(r0), rr); // may re-write the rank of bi_emptyHVec/bi_emptyIVec/bi_emptyCVec/bi_emptySVec. ¯\_(ツ)_/¯
  }
  decG(w); decG(x);
  return r0;
}
// Dyadic Select (𝕨⊏𝕩). Consumes w and x; the parameter t is not read.
// Dispatch order: atom 𝕨; empty or one-element 𝕨; then by element type of 𝕨 (bit, i8, i16, i32 directly;
// f64 and non-numeric types are squeezed and retried). Anything that doesn't end up as a plain numeric 𝕨
// goes to select_depth2_impl, which either handles a compound 𝕨 or throws.
B select_c2(B t, B w, B x) {
if (isAtm(x)) thrM("𝕨⊏𝕩: 𝕩 cannot be an atom");
ur xr = RNK(x);
if (xr==0) thrM("𝕨⊏𝕩: 𝕩 cannot be a unit");
// atom 𝕨: a single cell; new unit array for list 𝕩, slice otherwise
if (isAtm(w)) {
atomw:;
usz wi = WRAP_SELECT_ONE(o2i64(w), *SH(x), "%R", w);
if (xr==1) return select_list_cell(wi, x);
usz csz = arr_csz(x);
Arr* r = TI(x,slice)(incG(x), wi*csz, csz);
usz* sh = arr_shAlloc(r, xr-1);
if (sh) shcpy(sh, SH(x)+1, xr-1);
decG(x);
return taga(r);
}
usz wia = IA(w);
Arr* r;
ur wr = RNK(w);
i32 rr = xr+wr-1; // only for depth-1 w
// 𝕨 with at most one element
if (wia <= 1) {
if (wia == 0) {
emptyRes:
if (0 == *SH(x) && wr==1) {
decG(w);
return x;
}
r = emptyArr(x, rr);
if (rr<=1) goto dec_ret;
goto setsh;
}
B w0 = IGetU(w, 0);
if (isAtm(w0)) {
inc(w0);
decG(w);
w = w0;
if (wr == 0) goto atomw;
assert(rr >= 1);
usz wi = WRAP_SELECT_ONE(o2i64(w), *SH(x), "%R", w);
B r;
usz* sh;
if (xr == 1) {
r = select_list_cell(wi, x);
sh = arr_shAlloc(a(r), rr);
} else {
usz csz = arr_csz(x);
Arr* ra = TI(x,slice)(incG(x), wi*csz, csz);
sh = arr_shAlloc(ra, rr);
if (sh) shcpy(sh+wr, SH(x)+1, xr-1);
r = taga(ra);
decG(x);
}
// leading axes come from 𝕨, which has one element, so they are all 1
if (sh) PLAINLOOP for (ux i = 0; i < wr; i++) sh[i] = 1;
return r;
} else if (isArr(w0) && wr<=1) {
// try to fast-path ⟨numarr⟩ ⊏ 𝕩; if not possible, 𝕨 is definitely erroneous
inc(w0);
decG(w);
u8 w0e = TI(w0,elType);
if (elNum(w0e)) return C2(select, w0, x);
w0 = squeeze_numTry(w0, &w0e, SQ_MSGREQ(SQ_NUM));
if (elNum(w0e)) return C2(select, w0, x);
w = m_vec1(w0); // erroneous, speed doesn't matter
}
goto depth2;
}
// general case: 𝕨 has at least two elements
B xf = getFillR(x);
usz xn = *SH(x);
if (xn==0) goto error_dec_xf; // empty x, non-empty w; error
usz csz = arr_csz(x);
u8 xl = cellWidthLog(x);
usz ria = wia * csz;
usz xia = IA(x);
u8 xe = TI(x,elType);
u8 we = TI(w,elType);
// CPUSEL(W, NEXT): selects cells of 8..64 bits using indices of type W; with SIMD it is one fused kernel call,
// otherwise a scalar loop (per-index wrap for 4-byte W, block-wise range check for narrower W)
#if SINGELI_AVX2 || SINGELI_NEON
#define CPUSEL(W, NEXT) /*assumes 3≤xl≤6*/ \
if (RARE(!SIMD_SELECT(we, xl)(wp, xp, rp, wia, xn))) select_properError(w, x);
#else
#define CASE(S, E) case S: for (usz i=i0; i<i1; i++) ((E*)rp)[i] = ((E*)xp+off)[ip[i]]; break
#define CASEW(S, E) case S: for (usz i=0; i<wia; i++) ((E*)rp)[i] = ((E*)xp)[WRAP_SELECT_ONE(wp[i], xn, "%i", wp[i])]; break
#define CPUSEL(W, NEXT) /*assumes 3≤xl≤6*/ \
if (sizeof(W) >= 4) { \
switch(xl) { default:UD; CASEW(3,u8); CASEW(4,u16); CASEW(5,u32); CASEW(6,u64); } \
} else { \
W* wt = NULL; \
for (usz bl=(1<<14)/sizeof(W), i0=0, i1=0; i0<wia; i0=i1) { \
i1+=bl; if (i1>wia) i1=wia; \
W min=wp[i0], max=min; for (usz i=i0+1; i<i1; i++) { W e=wp[i]; if (e>max) max=e; if (e<min) min=e; } \
if (min<-(i64)xn) select_properError(w, x); \
if (max>=(i64)xn) select_properError(w, x); \
W* ip=wp; usz off=xn; \
if (max>=0) { off=0; if (RARE(min<0)) { \
if (RARE(xn > (1ULL<<(sizeof(W)*8-1)))) { w=taga(NEXT(w)); mm_free((Value*)r); return C2(select, w, x); } \
if (!wt) {wt=TALLOCP(W,i1-i0);} ip=wt-i0; \
for (usz i=i0; i<i1; i++) { W e=wp[i]; ip[i]=e+((W)xn & (W)-(e<0)); } \
} } \
switch(xl) { default:UD; CASE(3,u8); CASE(4,u16); CASE(5,u32); CASE(6,u64); } \
} \
if (wt) TFREE(wt); \
}
#endif
// BOOL_SPECIAL(W): with SIMD, i8 indices into a boolean 𝕩 of at most 128 elements use a dedicated kernel
#if SINGELI_AVX2 || SINGELI_NEON
bool bool_use_simd = we==el_i8 && xl==0 && xia<=128;
#define BOOL_SPECIAL(W) \
if (sizeof(W)==1 && bool_use_simd) { \
if (RARE(!simd_select_bool128(wp, xp, rp, wia, xn))) select_properError(w, x); \
goto setsh; \
}
#else
bool bool_use_simd = 0;
#define BOOL_SPECIAL(W)
#endif
// boolean 𝕩 whose cells aren't whole bytes: when 𝕨 is large enough relative to 𝕩, widen 𝕩, select, and narrow back
if (!bool_use_simd && xe==el_bit && (csz&7)!=0 && HEURISTIC(xl==0? wia>=256 : wia>=4) && csz<128 && TI(w,arrD1)) {
// test widen/narrow on bitarr input
// ShArr* sh = RNK(x)==1? NULL : ptr_inc(shObj(x));
// B t = C2(select, w, widenBitArr(x, 1));
// B r = narrowWidenedBitArr(t, wr, xr-1, sh==NULL? &xn : sh->a+1);
// if (sh!=NULL) ptr_dec(sh);
// return r;
if (csz==1) {
if (wia/4>=xia) return taga(cpyBitArr(C2(select, w, taga(cpyI8Arr(x)))));
} else if (HEURISTIC(csz>64? wia/2>=xn : wia>=xn/2)) {
ShArr* sh = ptr_inc(shObj(x));
B t = C2(select, w, widenBitArr(x, 1));
B r = narrowWidenedBitArr(t, wr, xr-1, sh->a+1);
ptr_dec(sh);
return r;
}
}
// TYPE(W, NEXT): body of the switch case for integer index type W. In order: 1-bit cells (bits gathered from the
// last index to the first, storing a word whenever i reaches a multiple of 64), typed cells of 8..64 bits via
// CPUSEL, generic 64-bit (B) cells; every other cell width goes to generic_l
#define TYPE(W, NEXT) { W* wp = W##any_ptr(w); \
if (xl==0) { u64* xp=bitany_ptr(x); \
u64* rp; r = m_bitarrp(&rp, ria); \
BOOL_SPECIAL(W) \
u64 b=0; \
for (usz i = wia; ; ) { \
i--; \
usz n = WRAP(wp[i], xn, select_properError(w, x)); \
b = 2*b + ((((u8*)xp)[n/8] >> (n%8)) & 1); \
if (i%64 == 0) { rp[i/64]=b; if (!i) break; } \
} \
goto setsh; \
} \
if (xe!=el_B) { \
if (xl<3 || xl==7) goto generic_l; \
void* rp = m_tyarrlp(&r, xl-3, ria, arrNewType(TY(x))); \
void* xp = tyany_ptr(x); \
CPUSEL(W, NEXT) \
goto setsh; \
} \
if (xl!=6) goto generic_l; \
M_HARR(ra, wia); B* xp = arr_bptr(x); \
SLOWIF(xp==NULL) SLOW2("𝕨⊏𝕩", w, x); \
if (xp!=NULL) { for (usz i=0; i<wia; i++) HARR_ADD(ra, i, inc(xp[WRAP_SELECT_ONE(wp[i], xia, "%i", wp[i])])); } \
else { SGet(x); for (usz i=0; i<wia; i++) HARR_ADD(ra, i, Get(x, WRAP_SELECT_ONE(wp[i], xia, "%i", wp[i]) )); } \
r = a(withFill(HARR_FV(ra), xf)); goto setsh; \
}
// re-entered after squeezing 𝕨 to a numeric type
retry:
switch (we) { default: UD;
case el_bit: {
// boolean 𝕨 on high-rank 𝕩: widen 𝕨 to i8 for small typed cells, else the generic path
if (xr!=1) {
if (xe!=el_B && (csz<<elwBitLog(xe)) < 128) {
dec(xf);
return C2(select, taga(cpyI8Arr(w)), x);
} else {
goto generic_l;
}
}
// boolean 𝕨 on list 𝕩: blend between the first two elements of 𝕩
SGetU(x)
B x0 = GetU(x, 0);
B x1;
if (xia<2) {
// 𝕩 has a single element, so any 1 in 𝕨 is out of bounds
u64* wp=bitany_ptr(w);
usz i; for (i=0; i<wia/64; i++) if (wp[i]) break;
if (i<wia/64 || bitp_l0(wp,wia)!=0) thrF("𝕨⊏𝕩: Indexing out-of-bounds (1∊𝕨, %s≡≠𝕩)", xn);
x1 = x0;
} else {
x1 = GetU(x,1);
}
B r = bit_sel(w, x0, x1);
decG(x);
if (noFill(xf) && TI(r,elType)!=el_B) return taga(cpyHArr(r));
return withFill(r, xf);
}
case el_i8: TYPE(i8, cpyI16Arr)
case el_i16: TYPE(i16,cpyI32Arr)
case el_i32: TYPE(i32,cpyF64Arr)
case el_f64: {
if (MAY_T(FL_HAS(w, fl_squoze))) goto generic_l; // either has non-integers (i.e. error, thus don't care about speed) or very large (i.e. will hit memory bandwidth anyway)
// else fallthrough - want to do integer 𝕨 if possible
}
case el_B: case el_c8: case el_c16: case el_c32: {
w = squeeze_numTry(w, &we, SQ_MSGREQ(SQ_NUM));
if (RANDOMIZE_HEURISTICS && we==el_f64) goto generic_l; // avoid infinite loop
if (elNum(we)) goto retry;
goto error_dec_xf; // erroneous input
}
}
#undef CASE
#undef CASEW
// 𝕨 isn't a plain numeric array: release the fill and let the depth-2 code handle it (or throw)
error_dec_xf:;
dec(xf);
depth2: return select_depth2_impl(wia, w, x);
// generic cell size: one copy-function call (typed 𝕩 with integer 𝕨) or mutable-array copy per index
generic_l: {
if (xia==0) goto emptyRes;
SLOW2("𝕨⊏𝕩", w, x);
SGetU(w)
usz csz = arr_csz(x);
CFRes f = cf_get(1, csz<<elwBitLog(xe));
MAKE_MUT_INIT(rm, ria, xe);
usz i = 0;
if (xe<el_B && elInt(we)) {
void* wp = tyany_ptr(w);
void* xp = tyany_ptr(x);
ux ri = 0;
switch(we) { default: UD;
case el_bit: for (; i<wia; i++) { ux c = bitp_get(wp,i); if (c >= xn) { goto bad1; } cf_call(f, rm->a, ri, xp, c*f.mul); ri+= f.mul; } break; // TODO something better
case el_i8: { i8* w0=wp; for (i8* wc=w0; wc<w0+wia; wc++) { usz c = WRAP(*wc, xn, { i=wc-w0; goto bad1; }); cf_call(f, rm->a, ri, xp, c*f.mul); ri+= f.mul; } } break;
case el_i16: { i16* w0=wp; for (i16* wc=w0; wc<w0+wia; wc++) { usz c = WRAP(*wc, xn, { i=wc-w0; goto bad1; }); cf_call(f, rm->a, ri, xp, c*f.mul); ri+= f.mul; } } break;
case el_i32: { i32* w0=wp; for (i32* wc=w0; wc<w0+wia; wc++) { usz c = WRAP(*wc, xn, { i=wc-w0; goto bad1; }); cf_call(f, rm->a, ri, xp, c*f.mul); ri+= f.mul; } } break;
}
assert(!isVal(xf));
r = a(mut_fv(rm));
} else {
MUTG_INIT(rm);
for (; i < wia; i++) {
B cw = GetU(w, i); // assumed number from previous squeeze
if (!q_i64(cw)) { bad_cw: goto bad1; }
usz c = WRAP(o2i64G(cw), xn, goto bad_cw; );
mut_copyG(rm, i*csz, x, csz*c, csz);
}
r = a(withFill(mut_fv(rm), xf));
}
goto setsh;
// i is the position in 𝕨 of the offending index; free the partially filled result before throwing
bad1:;
mut_pfree(rm, i*csz);
f64 badw = o2i64(IGetU(w,i));
thrF("𝕨⊏𝕩: Indexing out-of-bounds (%f∊𝕨, %s≡≠𝕩)", badw, xn);
}
// result shape for depth-1 𝕨: (≢𝕨)∾1↓≢𝕩
setsh:
if (rr>1) {
if (rr > UR_MAX) thrF("𝕨⊏𝕩: Result rank too large (%i≡=𝕨, %i≡=𝕩)", wr, xr);
ShArr* sh = m_shArr(rr);
shcpy(sh->a, SH(w), wr);
shcpy(sh->a+wr, SH(x)+1, xr-1);
arr_shSetUG(r, rr, sh);
} else {
arr_shVec(r);
}
dec_ret:;
decG(w); decG(x); return taga(r);
}
// Structural Under for Select/Pick: writes the cells of rep into the positions of x selected by w.
//   chr      - character of the primitive, used only in the error message
//   w        - typed numeric list of valid (possibly negative) cell indices, wia of them
//   x        - array viewed as cam cells of csz elements each
//   rep      - replacement data, wia×csz elements
// With CHECK_VALID, a byte set over the cam cells detects indices that occur more than once; a repeated index
// is only accepted if the element(s) being written match what was previously written there.
B select_replace(u32 chr, B w, B x, B rep, usz wia, usz cam, usz csz) { // consumes all; (⥊rep)⌾(⥊w⊏cam‿csz⥊⊢) x; assumes csz>0, that w is a typed (elNum) list of valid indices (squeeze already attempted on el_f64), and that rep has the proper element count
  assert(csz > 0);
  #if CHECK_VALID
    TALLOC(bool, set, cam);
    // if 𝕨 is much shorter than 𝕩, only clear the set entries that will actually be visited
    bool sparse = wia < cam/64;
    if (!sparse) for (i64 i = 0; i < cam; i++) set[i] = false;
    #define SPARSE_INIT(WI) \
      if (sparse) for (usz i = 0; i < wia; i++) {                    \
        i64 cw = WI; if (RARE(cw<0)) cw+= (i64)cam; set[cw] = false; \
      }
    #define EQ(ITER,F) if (set[cw]) ITER if (F) thrF("𝔽⌾(a⊸%c): Incompatible result elements", chr); set[cw] = true;
    #define EQ1(F) EQ(,F)
    #define FREE_CHECK TFREE(set)
  #else
    #define SPARSE_INIT(GET)
    #define EQ(ITER,F)
    #define EQ1(F)
    #define FREE_CHECK
  #endif

  // reads index I of wp into N, wrapping a negative index by the cell count
  #define READ_W(N,I) i64 N = (i64)wp[I]; if (RARE(N<0)) N+= (i64)cam
  u8 we = TI(w,elType); assert(elNum(we) || wia==0);
  u8 xe = TI(x,elType);
  u8 re = el_or(xe, TI(rep,elType));
  // w = taga(cpyF64Arr(w)); we = el_f64; // test the float path
  DIRECTARR_COPY(r, re, x);
  B rb = r.obj;
  SLOWIF(rb.u!=x.u && cam>100 && wia<cam/50) SLOW2("⌾(𝕨⊸⊏)𝕩 or ⌾(𝕨⊸⊑)𝕩 because not reusable", w, x);

  // float indices: generic element-wise path
  if (we==el_f64) {
    f64* wp = f64any_ptr(w);
    SPARSE_INIT((i64)wp[i])
    if (csz==1) {
      SGet(rep)
      for (usz i = 0; i < wia; i++) {
        READ_W(cw, i);
        B cn = Get(rep, i);
        EQ1(!compatible(DIRECTARR_GETU(r, cw), cn));
        DIRECTARR_REPLACE(r, cw, cn);
      }
    } else {
      SGetU(rep)
      for (usz i = 0; i < wia; i++) {
        READ_W(cw, i);
        EQ(for (usz j = 0; j < csz; j++), !compatible(DIRECTARR_GETU(r, cw*csz + j), GetU(rep, i*csz + j)));
        DIRECTARR_REPLACE_RANGE(r, cw*csz, rep, i*csz, csz); // TODO use cf_*
      }
    }
    goto dec_ret_rb;
  }

  // integer indices: widen to i32 and specialize on the result element type
  assert(elInt(we) || wia==0);
  w = toI32Any(w);
  i32* wp = i32any_ptr(w);
  SPARSE_INIT(wp[i])
  switch (re) { default: UD;
    case el_i8:  rep = toI8Any(rep);  goto do_u8;
    case el_c8:  rep = toC8Any(rep);  goto do_u8;
    case el_i16: rep = toI16Any(rep); goto do_u16;
    case el_c16: rep = toC16Any(rep); goto do_u16;
    case el_i32: rep = toI32Any(rep); goto do_u32;
    case el_c32: rep = toC32Any(rep); goto do_u32;
    case el_f64: rep = toF64Any(rep); goto do_f64;
    case el_bit: {
      assert(TI(rep,elType)==el_bit);
      u64* np = bitarr_ptr(rep);
      u64* rp = r.data;
      if (csz==1) {
        for (usz i = 0; i < wia; i++) {
          READ_W(cw, i);
          bool cn = bitp_get(np, i);
          EQ1(cn != bitp_get(rp, cw));
          bitp_set(rp, cw, cn);
        }
      } else {
        for (usz i = 0; i < wia; i++) {
          READ_W(cw, i);
          EQ(for (usz j = 0; j < csz; j++), bitp_get(np, i*csz + j) != bitp_get(rp, cw*csz + j));
          COPY_TO(rp, el_bit, cw*csz, rep, i*csz, csz);
        }
      }
      goto dec_ret_rb;
    }
    case el_B: {
      B* rp = r.data;
      if (csz==1) {
        SGet(rep)
        for (usz i = 0; i < wia; i++) {
          READ_W(cw, i);
          B cn = Get(rep, i);
          EQ1(!compatible(cn,rp[cw]));
          dec(rp[cw]);
          rp[cw] = cn;
        }
      } else {
        SGetU(rep)
        for (usz i = 0; i < wia; i++) {
          READ_W(cw, i);
          EQ(for (usz j = 0; j < csz; j++), !compatible(GetU(rep, i*csz + j), rp[cw*csz + j]));
          // release the old elements of the cell before overwriting them
          for (usz j = 0; j < csz; j++) dec(rp[cw*csz + j]);
          COPY_TO(rp, el_B, cw*csz, rep, i*csz, csz);
        }
      }
      goto dec_ret_rb;
    }
  }

  // typed elements with one-element cells: direct scalar stores; wider cells go to do_tycell
  #define IMPL(T, COMPATIBLE) do {    \
    if (csz!=1) goto do_tycell;       \
    T* rp = r.data;                   \
    T* np = tyany_ptr(rep);           \
    for (usz i = 0; i < wia; i++) {   \
      READ_W(cw, i);                  \
      T cn = np[i];                   \
      EQ1(!COMPATIBLE(cn, rp[cw]));   \
      rp[cw] = cn;                    \
    }                                 \
    goto dec_ret_rb;                  \
  } while(0)
  #define INT_EQ(A,B) ((A)==(B))
  do_u8:  IMPL(u8,  INT_EQ);
  do_u16: IMPL(u16, INT_EQ);
  do_u32: IMPL(u32, INT_EQ);
  do_f64: IMPL(f64, compatibleFloats);
  #undef INT_EQ
  #undef IMPL

  do_tycell:;
  // byte width of one cell; must be a full-width integer, as csz×element width can exceed 255
  ux cwidth = (ux)csz * elWidth(re);
  u8* rp = r.data;
  u8* np = tyany_ptr(rep);
  MatchFnObj eq = MATCHR_GET(re,re);
  for (usz i = 0; i < wia; i++) {
    READ_W(cw, i);
    EQ1(!MATCH_CALL(eq, rp + cw*cwidth, np + i*cwidth, csz));
    COPY_TO(rp, re, cw*csz, rep, i*csz, csz);
  }
  goto dec_ret_rb;

  dec_ret_rb:;
  decG(w); decG(rep);
  FREE_CHECK;
  return rb;

  #undef SPARSE_INIT
  #undef EQ
  #undef EQ1
  #undef FREE_CHECK
}
// Allocates a new list of ia elements for element type ty, taking the fill from src.
// Stores the array in *r and returns a pointer to its element data. For generic (el_B) types, *rbp is also set
// to the element pointer and all elements are initialized to the number 0; for typed results *rbp is left untouched.
static void* m_arrv_same_t(B* r, B** rbp, usz ia, u8 ty, B src) {
  assert(isArr(src));
  u8 se = TIi(ty,elType);
  if (se==el_B) {
    B fill = getFillQ(src);
    if (noFill(fill)) {
      HArr_p p = m_harrUv(ia);
      *rbp = p.a;
      *r = p.b;
    } else {
      Arr* ra = m_fillarrp(ia);
      arr_shVec(ra); // m_fillarrp leaves the shape unset; give it list shape like the other two branches do
      fillarr_setFill(ra, fill);
      *rbp = fillarrv_ptr(ra);
      *r = taga(ra);
    }
    // make every element valid before anything can observe the array
    FILL_TO(*rbp, el_B, 0, m_f64(0), ia);
    NOGC_E;
    return *rbp;
  } else {
    return m_tyarrlbv(r, arrTypeBitsLog(ty), ia, arrNewType(ty));
  }
}
// Convenience form of m_arrv_same_t that takes the type from src itself
static void* m_arrv_same(B* r, B** rbp, usz ia, B src) { // makes a new array with same element type and fill as src, but new ia
  u8 srcType = TY(src);
  return m_arrv_same_t(r, rbp, ia, srcType, src);
}
B slash_c2(B, B, B);
B select_cells_base(B inds, B x0, ux csz, ux cam);
extern void (*const si_select_cells_bit_lt64)(u64*,u64*,usz,usz,usz); // from fold.c (fold.singeli)
extern usz (*const si_select_cells_byte)(void*,void*,usz,usz,u8);
// Selects cell `ind` out of every group of l cells: x is viewed as cam‿l‿csz, and the result is the list of the
// cam×csz selected elements. ind is used as-is (no range check or wrapping happens here).
// No decG(x) appears in this function; the l==1 branch takes its own reference with incG.
B select_cells_single(usz ind, B x, usz cam, usz l, usz csz) { // ⥊ ind ⊏˘ cam‿l‿csz ⥊ x
  Arr* ra;
  if (l==1) {
    // only one cell to choose from: the result is all of x, as a list
    ra = cpyWithShape(incG(x));
    arr_shErase(ra, 1);
  } else {
    u8 xe = TI(x,elType);
    u8 ewl= elwBitLog(xe);
    u8 xl = multWidthLog(csz, ewl);
    usz ria = cam*csz;
    if (xl>=7 || (xl<3 && xl>0)) { // generic case
      MAKE_MUT_INIT(rm, ria, TI(x,elType)); MUTG_INIT(rm);
      usz jump = l * csz; // distance between the selected cell of consecutive groups
      usz xi = csz*ind;
      usz ri = 0;
      for (usz i = 0; i < cam; i++) {
        mut_copyG(rm, ri, x, xi, csz);
        xi+= jump;
        ri+= csz;
      }
      ra = mut_fp(rm);
      arr_shVec(ra);
      goto copyFill;
    } else if (xe==el_B) {
      assert(csz == 1);
      SGet(x)
      HArr_p rp = m_harrUv(ria);
      for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*l+ind);
      NOGC_E; ra = (Arr*)rp.c;
      goto copyFill;
    } else {
      void* rp = m_tyarrlbp(&ra, ewl, ria, el2t(xe));
      arr_shVec(ra); // m_tyarrlbp leaves the shape unset; every other branch returns a list, so do the same here
      void* xp = tyany_ptr(x);
      if (xl == 0) {
        // column of bits
        #if SINGELI
        if (l < 64) si_select_cells_bit_lt64(xp, rp, cam, l, ind);
        else
        #endif
        for (usz i=0; i<cam; i++) bitp_set(rp, i, bitp_get(xp, i*l+ind));
      } else {
        // column of 1, 2, 4 or 8-byte cells; the SIMD kernel reports how many rows it handled and the scalar loop finishes the rest
        usz i0 = 0;
        #if SINGELI
        i0 = si_select_cells_byte((u8*)xp + (ind<<(xl-3)), rp, cam, l, xl-3);
        #endif
        switch(xl) { default: UD;
          case 3: PLAINLOOP for (usz i=i0; i<cam; i++) ((u8* )rp)[i] = ((u8* )xp)[i*l+ind]; break;
          case 4: PLAINLOOP for (usz i=i0; i<cam; i++) ((u16*)rp)[i] = ((u16*)xp)[i*l+ind]; break;
          case 5: PLAINLOOP for (usz i=i0; i<cam; i++) ((u32*)rp)[i] = ((u32*)xp)[i*l+ind]; break;
          case 6: PLAINLOOP for (usz i=i0; i<cam; i++) ((f64*)rp)[i] = ((f64*)xp)[i*l+ind]; break;
        }
      }
    }
  }
  return taga(ra);
  // generic results additionally get the fill of x
  copyFill:
  return withFill(taga(ra), getFillQ(x));
}
// 64 minus CLZ of X taken as a u64 (presumably the bit width of X, assuming CLZ counts leading zero bits); used below as CLZC(n-1) for a ceil-log2
#define CLZC(X) (64-(CLZ((u64)(X))))
// Debug tracing for the select-rows code: becomes printf if SELECT_ROWS_PRINTF was defined before this point, and expands to nothing otherwise
#ifdef SELECT_ROWS_PRINTF
#undef SELECT_ROWS_PRINTF
#define SELECT_ROWS_PRINTF(...) printf(__VA_ARGS__)
#else
#define SELECT_ROWS_PRINTF(...)
#endif
// Size in bytes of the on-stack scratch buffer for rewritten indices
#define INDS_BUF_MAX 64 // only need 32 bytes for AVX2 & 16 for NEON, but have more for past-the-end pointers and writes
B select_rows_direct(B x, ux csz, ux cam, void* inds, ux indn, u8 ie) { // ⥊ (indn↑inds As ie)⊸⊏˘ cam‿csz⥊x; if inds are valid and csz<=128, ie must be <=el_i8
assert(csz!=0 && cam!=0 && indn!=0);
assert(csz*cam == IA(x));
assert(ie<=el_i32);
MAYBE_UNUSED u8 inds_buf[INDS_BUF_MAX];
bool generic_allowed = true; // whether required interpretation of x hasn't changed from its real one
if (csz==1) { // TODO maybe move to select_rows_B and require csz>=2 here?
i64 bounds[2];
if (!getRange_fns[ie](inds, bounds, indn) || bounds[0]<-1 || bounds[1]>0) goto generic_any;
return C2(slash, m_f64(indn), taga(arr_shVec(customizeShape(x))));
}
assert(csz>=2);
ux ria = indn * cam;
B r;
B* rbp = NULL;
u8* xp;
u8 xe = TI(x,elType);
u8 lb = arrTypeWidthLog(TY(x));
if (xe==el_B) {
if (sizeof(B) != 8) goto generic_any;
xp = (u8*) arr_bptr(x);
if (xp == NULL) goto generic_any;
} else {
xp = tyany_ptr(x);
if (xe == el_bit) {
#if SINGELI_AVX2 || SINGELI_NEON
if (indn<=8 && csz<=8) goto bit_ok;
#endif
goto generic_any;
goto bit_ok; bit_ok:;
}
}
MAYBE_UNUSED bool fast;
ux xbump = csz<<lb;
ux rbump = indn<<lb;
i64 bounds[2];
if (ie==el_bit) {
// TODO path for xe==el_bit + long indn
if (HEURISTIC_BOUNDED(csz>32 || indn>32 || indn>INDS_BUF_MAX, xe!=el_bit && (csz>8 || indn>8), indn<=32)) { // TODO properly tune
u8* rp = m_arrv_same(&r, &rbp, ria, x);
for (ux i = 0; i < cam; i++) {
bitselFns[lb](rp, inds, loadu_u64(xp), loadu_u64(xp + (1<<lb)), indn);
xp+= xbump;
rp+= rbump;
}
goto decG_B_ret;
} else {
assert(inds_buf != inds);
COPY_TO_FROM(inds_buf, el_i8, inds, el_bit, indn);
inds = inds_buf;
ie = el_i8;
bounds[0] = 0;
bounds[1] = 1; // might be an over-estimate, hopefully doesn't matter (and csz≥2 anyway)
#if SINGELI
fast = true;
goto skip_bounds_check;
#endif
}
}
#if SINGELI
assert(INDS_BUF_MAX_COPY == INDS_BUF_MAX);
{
fast = ie==el_i8;
if (!getRange_fns[ie](inds, bounds, indn)) goto generic_int;
if (bounds[1] >= (i64)csz) goto generic_int;
if (bounds[0] < 0) {
if (bounds[0] < -(i64)csz) goto generic_int;
if (csz < 128 && indn < INDS_BUF_MAX) {
assert(ie == el_i8);
si_wrap_inds[0](inds, inds_buf, indn, csz);
bounds[0] = 0;
bounds[1] = csz-1;
inds = inds_buf;
} else {
fast = false;
}
}
skip_bounds_check:;
assert(ie==el_i8 || csz>128);
#if SINGELI_AVX2 || SINGELI_NEON
if (fast) {
generic_allowed = false;
ux sh = select_rows_widen[lb](inds, inds_buf, bounds[1], indn); // TODO null element in table for guaranteed-zero
if (sh!=0) {
SELECT_ROWS_PRINTF("widening indices by factor of %d:\n", 1<<sh);
SELECT_ROWS_PRINTF(" src: lb=%d, ie=%d, csz=%zu, indn=%zu\n", lb, ie, csz, indn);
inds = inds_buf;
lb-= sh;
csz<<= sh;
indn<<= sh;
SELECT_ROWS_PRINTF(" dst: lb=%d, ie=%d, csz=%zu, indn=%zu\n", lb, ie, csz, indn);
}
}
#endif
#if SINGELI_AVX2 || SINGELI_NEON
if (xe==el_bit) {
assert(ie==el_i8 && csz<=8 && indn<=8 && csz>=2 && indn>=1);
// TODO si_select_cells_bit_lt64 for indn==1
static const u8 rep_lut[9] = {0,3,2,1,1,0,0,0,0};
u8 exp = rep_lut[csz>indn? csz : indn];
ux rindn = indn<<exp;
ux rcsz = csz<<exp;
assert(rcsz<=8 && rindn<=8);
ux rcam = (cam + (1<<exp)-1)>>exp;
if (rcsz!=8) {
u64* xp2;
B x2 = m_bitarrv(&xp2, 8*cam);
bitwiden(xp2, 8, xp, rcsz, cam);
decG(x);
x = x2;
xp = (void*) xp2;
SELECT_ROWS_PRINTF("8bit: widen %zu‿%zu → ⟨%zu,%zu→8⟩\n", cam, csz, rcam, rcsz);
}
if (exp!=0) {
simd_repeat_inds(inds, inds_buf, indn, csz);
inds = inds_buf;
}
u64* rp;
ux ria0 = rindn!=8? 8*rcam : ria;
r = m_bitarrv(&rp, ria0);
SELECT_ROWS_PRINTF("8bit: indn=%zu rindn=%zu csz=%zu rcsz=%zu cam=%zu ria0=%zu rcam=%zu\n", indn, rindn, csz, rcsz, cam, ria0, rcam);
si_select_rows_8bit(inds, rindn, xp, rp, (ria0+7)/8);
if (rindn!=8) {
SELECT_ROWS_PRINTF("8bit: narrow 8 → %zu<<%d\n", csz, exp);
u64* rp2;
B r2 = m_bitarrv(&rp2, 8*rcam);
bitnarrow(rp2, rindn, rp, 8, rcam);
tyarrv_free(r);
r = r2;
ux ria1 = IA(r);
assert(ria <= ria1);
FINISH_OVERALLOC(a(r), offsetof(TyArr,a) + (ria+7)/8, offsetof(TyArr,a) + (ria1+7)/8);
a(r)->ia = ria;
}
goto decG_ret;
}
#endif
u8* rp = m_arrv_same(&r, &rbp, ria, x);
ux slow_cam = cam;
#if SINGELI_AVX2 || SINGELI_NEON
ux lnt = CLZC(csz-1); // ceil-log2 of number of elements in table
if (fast && lnt < select_rows_tab_h) {
u8 max_indn = select_rows_max_indn[lb];
if (indn > max_indn) goto no_fast;
u8 min_lnt = select_rows_min_logcsz[lb];
ux used_lnt;
SELECT_ROWS_PRINTF("csz: %zu/%d; inds: %d/%d\n", csz, 1<<min_lnt, (int)indn, max_indn);
ux indn_real = indn;
ux rep;
if (indn*2 <= max_indn) {
assert(max_indn<=32); // otherwise inds_buf hard-coded size may need to change
rep = simd_repeat_inds(inds, inds_buf, indn, csz);
indn_real = rep*indn;
SELECT_ROWS_PRINTF("rep: %zu; inds: %zu→%zu; csz: %zu→%zu - raw repeat\n", rep, indn, indn*rep, csz, csz*rep);
PLAINLOOP while (rep*indn > max_indn) rep--; // simd_repeat_inds over-estimates
SELECT_ROWS_PRINTF("rep: %zu; inds: %zu→%zu; csz: %zu→%zu - valid inds\n", rep, indn, indn*rep, csz, csz*rep);
used_lnt = min_lnt;
ux fine_csz = 1ULL<<(min_lnt+1); // TODO have a proper per-element-type LUT of "target LUT size"
if (csz < fine_csz) {
ux cap = fine_csz / csz;
if (rep > cap) rep = cap;
} else rep = 1;
ux new_lnt = CLZC(csz*rep-1);
if (new_lnt > used_lnt) used_lnt = new_lnt;
SELECT_ROWS_PRINTF("rep: %zu; inds: %zu→%zu; csz: %zu→%zu - valid table\n", rep, indn, indn*rep, csz, csz*rep);
inds = inds_buf;
} else {
rep = 1;
used_lnt = lnt<min_lnt? min_lnt : lnt;
}
assert(indn*rep <= max_indn);
AUTO fn = select_rows_tab[used_lnt*4 + lb];
if (fn == null_fn) goto no_fast;
ux done = fn(inds, xp, csz*rep, rp, indn*rep, rp + cam*rbump) * rep;
ux left = cam - done;
SELECT_ROWS_PRINTF("done_rows: %zu; left_rows: %zu; left_els: %zu; left_max: %zu\n", done, left, indn*left, indn_real);
if (left) {
xp+= done * xbump;
rp+= done * rbump;
if (left*csz <= 127) {
assert(indn*left <= indn_real);
bool ok = SIMD_SELECT(ie, lb+3)(inds, xp, rp, indn*left, I64_MAX); assert(ok);
} else {
slow_cam = left;
goto no_fast;
}
}
goto decG_B_ret;
}
no_fast:;
#endif
SimdSelectFn fn = SIMD_SELECT(ie, lb+3);
for (ux i = 0; i < slow_cam; i++) {
fn(inds, xp, rp, indn, csz);
xp+= xbump;
rp+= rbump;
}
goto decG_B_ret;
}
#else
(void) bounds;
#endif
generic_any:;
if (ie==el_bit) {
u64* rp;
B indo = m_bitarrv(&rp, indn);
memcpy(rp, inds, (indn+7)>>3);
assert(generic_allowed);
return select_cells_base(indo, x, csz, cam);
}
goto generic_int;
generic_int:;
assert(ie!=el_bit && generic_allowed);
B indo = taga(arr_shVec(m_tyslice(inds, a(emptyIVec()), t_i8slice + ie-el_i8, indn)));
return select_cells_base(indo, x, csz, cam);
decG_B_ret:;
if (rbp != NULL) {
for (ux i = 0; i < ria; i++) inc(rbp[i]); // TODO if only a few columns are selected, could incBy in a stride per selected column
}
decG_ret: MAYBE_UNUSED;
decG(x);
return r;
}
// Row-wise select with an index array given as a B value.
// Handles the degenerate shapes itself, narrows the index element type where useful,
// and hands integer-typed indices to select_rows_direct; everything else goes to select_cells_base.
B select_rows_B(B x, ux csz, ux cam, B inds) { // consumes inds,x; ⥊ inds⊸⊏˘ cam‿csz⥊x; if inds isn't rank 1, result may or may not be high rank
  assert(csz*cam == IA(x));
  if (csz == 0) goto fallback; // empty rows: nothing valid to index, let the generic path deal with it

  if (cam == 0) return taga(emptyVec(x));
  if (cam == 1) {
    // a single row: this is just a regular select on x viewed as a vector
    B row = taga(arr_shVec(TI(x,slice)(x, 0, IA(x))));
    return C2(select, inds, row);
  }

  ux indn = IA(inds);
  if (indn == 0) return taga(emptyVec(x));
  if (indn == 1) {
    // one index: pick a single column out of every row
    B w = IGetU(inds, 0);
    if (!isF64(w)) goto fallback;
    B res = select_cells_single(WRAP_SELECT_ONE(o2i64(w), csz, "%R", w), x, cam, csz, 1);
    decG(x);
    decG(inds);
    return res;
  }

  // Decide whether the current index element type is wider than what the row length needs:
  //   csz≤2   → anything that isn't boolean
  //   csz≤128 → anything wider than i8
  //   larger  → only non-integer types
  u8 ie = TI(inds,elType);
  bool want_squeeze;
  if      (csz <= 2)   want_squeeze = ie != el_bit;
  else if (csz <= 128) want_squeeze = ie > el_i8;
  else                 want_squeeze = !elInt(ie);
  if (want_squeeze) {
    inds = squeeze_numTry(inds, &ie, SQ_BEST);
    if (!elInt(ie)) goto fallback;
  }

  u8* ip = (u8*)tyany_ptr(inds);
  B direct = select_rows_direct(x, csz, cam, ip, indn, ie);
  decG(inds);
  return direct;

  fallback:;
  return select_cells_base(inds, x, csz, cam);
}
// Index accessor for i32 data: reads element i of xp and widens it to i64.
// `bad` is never written; the parameter exists to match the getW callback signature of select_each_impl.
SHOULD_INLINE i64 i64get_i32(void* xp, ux i, bool* bad) {
  i32* src = (i32*)xp;
  return src[i];
}
// Index accessor for f64 data: reads element i of xp and converts it to i64.
// If q_fi64 rejects the value (presumably: not exactly representable as an i64 - confirm against its definition),
// sets *bad and returns 0 instead.
SHOULD_INLINE i64 i64get_f64(void* xp, ux i, bool* bad) {
  f64 val = ((f64*)xp)[i];
  if (!q_fi64(val)) {
    *bad = true;
    return 0;
  }
  return (i64)val;
}
// Stores the constant c at every position of r named by the wia indices in wp.
//   r    - destination; elements are written through r.data
//   re   - element type of r, selects the store width
//   c    - value to store; for el_B it is stored as-is, with no reference count change made here
//   xn   - length that indices are wrapped against (via WRAP)
//   wp   - index buffer, read only through getW
//   getW - fetches index i as an i64, setting its bool argument if the element isn't usable as an index
// Returns true when all indices were valid; returns false at the first invalid one,
// in which case earlier positions have already been overwritten.
SHOULD_INLINE bool select_each_impl(DirectArr r, u8 re, B c, ux xn, void* wp, usz wia, i64 (*getW)(void*, ux, bool*)) {
  // wrapped index for the current loop variable `i`; leaves through `invalid` on a rejected or out-of-range index
  #define WRAPPED_IDX ({ bool bad=false; i64 wc = getW(wp, i, &bad); if (bad) goto invalid; WRAP(wc, xn, goto invalid); })
  // write the low bytes of `bits` as a T at every selected position, then finish
  #define FILL(T) do { for (ux i = 0; i < wia; i++) ((T*)r.data)[WRAPPED_IDX] = bits; return true; } while (0)

  u64 bits;
  switch (re) { default: UD;
    case el_bit: {
      // branch on the bit value once, outside of the loops
      if (o2bG(c)) {
        for (ux i = 0; i < wia; i++) bitp_set(r.data, WRAPPED_IDX, true);
      } else {
        for (ux i = 0; i < wia; i++) bitp_set(r.data, WRAPPED_IDX, false);
      }
      return true;
    }
    case el_i8:  bits = o2iG(c);          FILL(u8);
    case el_i16: bits = o2iG(c);          FILL(u16);
    case el_i32: bits = o2iG(c);          FILL(u32);
    case el_f64: bits = r_f64u(o2fG(c));  FILL(u64); // store the raw bit pattern of the float
    case el_c8:  bits = o2cG(c);          FILL(u8);
    case el_c16: bits = o2cG(c);          FILL(u16);
    case el_c32: bits = o2cG(c);          FILL(u32);
    case el_B: {
      B* cells = (B*)r.data;
      for (ux i = 0; i < wia; i++) {
        B* slot = cells + WRAPPED_IDX;
        dec(*slot); // release the element being replaced
        *slot = c;
      }
      return true;
    }
  }
  UD;

  #undef FILL
  #undef WRAPPED_IDX
  invalid: return false;
}
// Under Select with bound left argument: 𝔽⌾(𝕨⊸⊏) 𝕩.
//   t - passed through unchanged to def_fn_ucw on the fallback path
//   o - the operand 𝔽; either an array (used directly as the replacement) or a function
//   w - indices into the leading axis of x; an atom is wrapped into a unit array
//   x - the array being modified
// Three ways of producing the result:
//   - array o with rank≥1 x: range-check w, then replace the selected cells with o
//   - constant Each (o is a t_md1D whose modifier is n_each and whose operand toConstant accepts) with rank-1 x:
//     write the constant straight into a copy of x converted to a wide enough element type
//   - otherwise: call o on w⊏x and replace the selected cells with what it returns
// Atom x and non-numeric w go to def_fn_ucw.
// NOTE(review): w is released on the paths visible here and o is only ever incremented, which suggests
// "consumes w,x; borrows t,o" - confirm against the ucw calling convention.
B select_ucw(B t, B o, B w, B x) {
  if (RARE(isAtm(x))) { def: return def_fn_ucw(t, o, w, x); }
  u8 we;
  if (isAtm(w)) {
    if (RARE(!isNum(w))) goto def;
    w = m_unit(w);
    we = TI(w,elType);
    assert(elNum(we));
  } else {
    we = TI(w,elType);
    if (!elInt(we)) {
      w = squeeze_numTry(w, &we, SQ_MSGREQ(SQ_NUM));
      if (!elNum(we)) goto def;
    }
  }
  usz wia = IA(w);
  B rep;
  if (MAY_F(isArr(o) && RNK(x)>0)) {
    // array operand: validate indices up front, as 𝔽 is never actually invoked on w⊏x here
    usz xn = *SH(x);
    i64 buf[2]; // filled by getRange_fns; buf[0] is checked as the lower bound and buf[1] as the upper
    // both bounds are compared as i64 so that the comparison is signed regardless of the width of usz
    if (wia!=0 && (!getRange_fns[we](tyany_ptr(w), buf, wia) || buf[0]<-(i64)xn || buf[1]>=(i64)xn)) {
      bad:
      // invalid indices: run the regular select so that it produces the user-facing error
      C2(select, w, x);
      fatal("select_ucw expected to error");
    }
    rep = incG(o);
  } else if (MAY_F(isFun(o) && TY(o)==t_md1D && RNK(x)==1)) {
    Md1D* od = c(Md1D,o);
    if (PRTID(od->m1) != n_each) goto notConstEach;
    B c;
    if (!toConstant(od->f, &c)) goto notConstEach;
    // result element type must be able to hold both the existing elements and the constant
    u8 ce = selfElType(c);
    u8 xe = TI(x,elType);
    u8 re = el_or(ce,xe);
    DirectArr r = toEltypeArr(x, re);
    if (isVal(c)) {
      // c gets stored wia times by select_each_impl without reference count changes there,
      // so adjust the count here to end up with wia references in total
      if (wia==0) decG(c); // TODO could return x; fills?
      else incByG(c, wia-1);
    }
    usz xn = *SH(r.obj);
    bool ok;
    if (elInt(we)) {
      w = toI32Any(w); we = el_i32;
      i32* wp = i32any_ptr(w);
      ok = select_each_impl(r, re, c, xn, wp, wia, i64get_i32);
    } else {
      // annoying amount of code for el_f64 𝕨 vs ≤el_i32, but as el_f64 only should apply to ≥2⋆31-element arrays it shouldn't matter much
      assert(we==el_f64);
      f64* wp = f64any_ptr(w);
      ok = select_each_impl(r, re, c, xn, wp, wia, i64get_f64);
    }
    if (ok) {
      decG(w);
      return r.obj;
    } else {
      // some index was invalid; report it through the same erroring path as the array-operand case
      x = r.obj;
      goto bad;
    }
  } else {
    notConstEach:;
    rep = c1(o, C2(select, incG(w), incG(x)));
  }
  // the replacement must have the shape of w⊏x: shape of w followed by the trailing shape of x
  ur xr = RNK(x);
  ur wr = RNK(w);
  if (isAtm(rep) || xr+wr != RNK(rep)+1 || !eqShPart(SH(w),SH(rep),wr) || !eqShPart(SH(x)+1,SH(rep)+wr,xr-1)) {
    thrF("𝔽⌾(a⊸⊏)𝕩: 𝔽 must return an array with the same shape as its input (expected %0H, got %0H)", C2(select, w, x), rep);
  }
  usz csz = arr_csz(x);
  if (csz == 0) { decG(rep); decG(w); return x; } // empty cells: nothing to write
  return select_replace(U'⊏', w, x, rep, wia, *SH(x), csz);
}