Move select_cells_single (atom⊸⊏˘) to select.c and use for any singleton index

This commit is contained in:
Marshall Lochbaum 2024-10-29 22:38:06 -04:00
parent 0bdc43cc0f
commit d7b508ff3b
2 changed files with 76 additions and 57 deletions

View File

@ -219,9 +219,8 @@ NOINLINE B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr) { //
// fast special-case implementations
extern void (*const si_select_cells_bit_lt64)(u64*,u64*,usz,usz,usz); // from fold.c (fold.singeli)
extern usz (*const si_select_cells_byte)(void*,void*,usz,usz,u8);
static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind {leaf? <∘⊑; ⊏}⎉¯k x; TODO probably can share some parts with takedrop_highrank and/or call ⊏?
B select_cells_single(usz ind, B x, usz cam, usz l, usz csz, bool leaf); // from select.c
static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind {leaf? <∘⊑; ⊏}⎉¯k x
ur xr = RNK(x);
assert(xr>1 && k<xr);
usz* xsh = SH(x);
@ -229,63 +228,15 @@ static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind
usz l = xsh[k];
assert(0<=ind && ind<l);
assert(cam*l*csz == IA(x));
Arr* ra;
usz take = leaf? 1 : csz;
if (l==1 && take==csz) {
ra = cpyWithShape(incG(x));
arr_shErase(ra, 1);
} else {
u8 xe = TI(x,elType);
u8 ewl= elwBitLog(xe);
u8 xl = leaf? ewl : multWidthLog(csz, ewl);
usz ria = cam*take;
if (xl>=7 || (xl<3 && xl>0)) { // generic case
MAKE_MUT_INIT(rm, ria, TI(x,elType)); MUTG_INIT(rm);
usz jump = l * csz;
usz xi = take*ind;
usz ri = 0;
for (usz i = 0; i < cam; i++) {
mut_copyG(rm, ri, x, xi, take);
xi+= jump;
ri+= take;
}
ra = mut_fp(rm);
} else if (xe==el_B) {
assert(take == 1);
SGet(x)
HArr_p rp = m_harrUv(ria);
for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*l+ind);
NOGC_E; ra = (Arr*)rp.c;
} else {
void* rp = m_tyarrlbp(&ra, ewl, ria, el2t(xe));
void* xp = tyany_ptr(x);
if (xl == 0) {
#if SINGELI
if (l < 64) si_select_cells_bit_lt64(xp, rp, cam, l, ind);
else
#endif
for (usz i=0; i<cam; i++) bitp_set(rp, i, bitp_get(xp, i*l+ind));
} else {
usz i0 = 0;
#if SINGELI
i0 = si_select_cells_byte((u8*)xp + (ind<<(xl-3)), rp, cam, l, xl-3);
#endif
switch(xl) { default: UD;
case 3: PLAINLOOP for (usz i=i0; i<cam; i++) ((u8* )rp)[i] = ((u8* )xp)[i*l+ind]; break;
case 4: PLAINLOOP for (usz i=i0; i<cam; i++) ((u16*)rp)[i] = ((u16*)xp)[i*l+ind]; break;
case 5: PLAINLOOP for (usz i=i0; i<cam; i++) ((u32*)rp)[i] = ((u32*)xp)[i*l+ind]; break;
case 6: PLAINLOOP for (usz i=i0; i<cam; i++) ((f64*)rp)[i] = ((f64*)xp)[i*l+ind]; break;
}
}
}
}
B r = select_cells_single(ind, x, cam, l, csz, leaf);
Arr* ra = a(r);
usz* rsh = arr_shAlloc(ra, leaf? k : xr-1);
if (rsh) {
shcpy(rsh, xsh, k);
if (!leaf) shcpy(rsh+k, xsh+k+1, xr-1-k);
}
decG(x);
return taga(ra);
return r;
}
static void set_column_typed(void* rp, B v, u8 e, ux p, ux stride, ux n) { // may write to all elements 0 ≤ i < stride×n, and after that too for masked stores

View File

@ -40,8 +40,15 @@
// Sparse initialization if 𝕨 is much smaller than 𝕩
// COULD call Mark Firsts (∊) for very short 𝕨 to avoid allocation
// Select Cells - inds⊸⊏⎉1 x
// Squeeze indices if too wide for given x
// Select Cells - inds⊸⊏⎉1 𝕩
// Squeeze indices if too wide for given 𝕩
// Single index: (also used for monadic ⊏˘ ⊣˝˘ ⊢˝˘)
// Selecting a column of bits:
// Row size <64: extract as with fold-cells
// Selecting a column of 1, 2, 4, or 8-byte elements:
// Short cells: pack vectors from 𝕩, or blend and permute
// Long cells: dedicated scalar loop for each type
// Otherwise, loop with mutable copy
// Boolean indices:
// Short inds and short cells: Widen to i8
// Otherwise: bitsel call per cell
@ -57,7 +64,7 @@
// COULD generate full list of indices via arith
// 1-element cells: use (≠inds)/⥊x after checking ∧´inds∊0‿¯1
// Used for ⌽⎉1
// SHOULD use for atom⊸⊏⎉k, /⎉k, ⌽⎉k, ↑⎉k, ↓⎉k, ↕⎉k, ⍉⎉k, probably more
// SHOULD use for /⎉k, ⌽⎉k, ↑⎉k, ↓⎉k, ↕⎉k, ⍉⎉k, probably more
#include "../core.h"
#include "../utils/talloc.h"
@ -575,6 +582,62 @@ static void* m_arrv_same(B* r, usz ia, B src) { // makes a new array with same e
B slash_c2(B, B, B);
B select_cells_base(B inds, B x0, ux csz, ux cam);
extern void (*const si_select_cells_bit_lt64)(u64*,u64*,usz,usz,usz); // from fold.c (fold.singeli)
extern usz (*const si_select_cells_byte)(void*,void*,usz,usz,u8);
B select_cells_single(usz ind, B x, usz cam, usz l, usz csz, bool leaf) { // ⥊ ind {leaf? <∘⊑; ⊏}˘ cam‿l‿csz ⥊ x
usz take = leaf? 1 : csz;
Arr* ra;
if (l==1 && take==csz) {
ra = cpyWithShape(incG(x));
arr_shErase(ra, 1);
} else {
u8 xe = TI(x,elType);
u8 ewl= elwBitLog(xe);
u8 xl = leaf? ewl : multWidthLog(csz, ewl);
usz ria = cam*take;
if (xl>=7 || (xl<3 && xl>0)) { // generic case
MAKE_MUT_INIT(rm, ria, TI(x,elType)); MUTG_INIT(rm);
usz jump = l * csz;
usz xi = take*ind;
usz ri = 0;
for (usz i = 0; i < cam; i++) {
mut_copyG(rm, ri, x, xi, take);
xi+= jump;
ri+= take;
}
ra = mut_fp(rm);
} else if (xe==el_B) {
assert(take == 1);
SGet(x)
HArr_p rp = m_harrUv(ria);
for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*l+ind);
NOGC_E; ra = (Arr*)rp.c;
} else {
void* rp = m_tyarrlbp(&ra, ewl, ria, el2t(xe));
void* xp = tyany_ptr(x);
if (xl == 0) {
#if SINGELI
if (l < 64) si_select_cells_bit_lt64(xp, rp, cam, l, ind);
else
#endif
for (usz i=0; i<cam; i++) bitp_set(rp, i, bitp_get(xp, i*l+ind));
} else {
usz i0 = 0;
#if SINGELI
i0 = si_select_cells_byte((u8*)xp + (ind<<(xl-3)), rp, cam, l, xl-3);
#endif
switch(xl) { default: UD;
case 3: PLAINLOOP for (usz i=i0; i<cam; i++) ((u8* )rp)[i] = ((u8* )xp)[i*l+ind]; break;
case 4: PLAINLOOP for (usz i=i0; i<cam; i++) ((u16*)rp)[i] = ((u16*)xp)[i*l+ind]; break;
case 5: PLAINLOOP for (usz i=i0; i<cam; i++) ((u32*)rp)[i] = ((u32*)xp)[i*l+ind]; break;
case 6: PLAINLOOP for (usz i=i0; i<cam; i++) ((f64*)rp)[i] = ((f64*)xp)[i*l+ind]; break;
}
}
}
}
return taga(ra);
}
#define CLZC(X) (64-(CLZ((u64)(X))))
@ -851,6 +914,11 @@ B select_rows_B(B x, ux csz, ux cam, B inds) { // consumes inds,x; ⥊ inds⊸
ux in = IA(inds);
if (in == 0) return taga(emptyArr(x, 1));
u8 ie = TI(inds,elType);
if (in == 1) {
B w = IGetU(inds,0); if (!isF64(w)) goto generic;
B r = select_cells_single(WRAP(o2i64(w), csz, thrF("⊏: Indexing out-of-bounds (%R∊𝕨, %s≡≠𝕩)", w, csz)), x, cam, csz, 1, false);
decG(x); decG(inds); return r;
}
if (csz<=2? ie!=el_bit : csz<=128? ie>el_i8 : !elInt(ie)) {
inds = num_squeeze(inds);
ie = TI(inds,elType);