better 𝕨⊏𝕩 with non-power-of-two cell sizes

This commit is contained in:
dzaima 2023-02-04 03:26:20 +02:00
parent af253e25e5
commit be9d0c287e

View File

@ -3,21 +3,25 @@
// First Cell is just a slice
// Complications in Select mostly come from range checks and negative 𝕨
// Atom 𝕨 and any rank 𝕩: slice
// Rank-1 𝕩:
// Empty 𝕨: no selection
// Small 𝕩 with Singeli: use shuffles
// Boolean 𝕨: use bit_sel for blend or similar
// Boolean 𝕩 and larger 𝕨: convert to i8, select, convert back
// Boolean 𝕩 otherwise: select/shift bytes, reversed for fast writing
// Atom or enclosed atom 𝕨 and rank-1 𝕩: make new array
// Atom or enclosed atom 𝕨 and high-rank 𝕩: slice
// Empty 𝕨: no selection
// Float or generic 𝕨: attempt to squeeze, go generic cell size path if stays float
// Boolean 𝕩 (cell size = 1 bit):
// 𝕨 larger than 𝕩: convert 𝕩 to i8, select, convert back
// Otherwise: select/shift bytes, reversed for fast writing
// TRIED pext, doesn't seem faster (mask built with shifts anyway)
// SHOULD squeeze 𝕨 if not ≤i32 to get to optimized cases
// 𝕩 with cell sizes of 1, 2, 4, or 8 bytes:
// Small 𝕩 and i8 𝕨 with Singeli: use shuffles
// Boolean 𝕨: use bit_sel for blend or similar
// Integer 𝕨 with Singeli: fused wrap, range-check, and gather
// COULD try selecting from boolean with gather
// COULD detect <Skylake where gather is slow
// i32 𝕨: wrap, check, select one index at a time
// i8 and i16 𝕨: separate range check in blocks to auto-vectorize
// SHOULD optimize simple 𝕨 based on cell size for any rank 𝕩
// Generic cell size 𝕩:
// Computes a function that copies the necessary amount of bytes/bits
// Specializes over i8/i16/i32 𝕨
// SHOULD implement nested 𝕨
// Under Select ⌾(i⊸⊏)
@ -39,6 +43,51 @@
#include "../utils/includeSingeli.h"
#endif
typedef size_t ux; // TODO move to h.h
typedef void (*CFn)(void* r, ux rs, void* x, ux xs, ux data);
typedef struct {
CFn fn;
ux data;
ux mul;
} CFRes;
static void cf_0(void* r, ux rs, void* x, ux xs, ux d) { }
static void cf_1(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 1); }
static void cf_2(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 2); }
static void cf_3(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 3); }
static void cf_4(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 4); }
static CFn cfs_0_4[] = {cf_0, cf_1, cf_2, cf_3, cf_4};
static void cf_8(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 8); }
static void cf_5_7 (void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 4); memcpy(r+d, x+d, 4); }
static void cf_9_16 (void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 8); memcpy(r+d, x+d, 8); }
static void cf_17_24(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 16); memcpy(r+d, x+d, 8); }
static void cf_25_32(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, 24); memcpy(r+d, x+d, 8); }
static void cf_arb(void* r, ux rs, void* x, ux xs, ux d) { r=rs+(u8*)r; x=xs+(u8*)x; memcpy(r, x, d); }
static void cfb_1(void* r, ux rs, void* x, ux xs, ux d) { bitp_set(r, rs, bitp_get(x, xs)); }
static void cfb_arb(void* r, ux rs, void* x, ux xs, ux d) { bit_cpy(r, rs, x, xs, d); }
NOINLINE CFRes cf_get(usz count, usz cszBits) {
if ((cszBits&7)==0) {
ux cszBytes = cszBits/8;
ux bytes = cszBytes * (ux)count;
if (bytes<5) return (CFRes){.mul=cszBytes, .fn = cfs_0_4[bytes]};
if (bytes<8) return (CFRes){.mul=cszBytes, .fn = cf_5_7, .data=bytes-4};
if (bytes==8) return (CFRes){.mul=cszBytes, .fn = cf_8};
if (bytes<=16) return (CFRes){.mul=cszBytes, .fn = cf_9_16, .data=bytes-8};
if (bytes<=24) return (CFRes){.mul=cszBytes, .fn = cf_17_24, .data=bytes-8};
if (bytes<=32) return (CFRes){.mul=cszBytes, .fn = cf_25_32, .data=bytes-8};
return (CFRes){.mul=cszBytes, .fn = cf_arb, .data=bytes};
}
ux bits = count*(ux)cszBits;
if (bits==1) return (CFRes){.mul=cszBits, .fn = cfb_1};
else return (CFRes){.mul=cszBits, .fn = cfb_arb, .data=bits};
}
FORCE_INLINE void cf_call(CFRes f, void* r, ux rs, void* x, ux xs) {
f.fn(r, rs, x, xs, f.data);
}
extern B rt_select;
B select_c1(B t, B x) {
if (isAtm(x)) thrM("⊏: Argument cannot be an atom");
@ -234,15 +283,37 @@ B select_c2(B t, B w, B x) {
SLOW2("𝕨⊏𝕩", w, x);
SGetU(w)
usz csz = arr_csz(x);
MAKE_MUT(rm, ria); mut_init(rm, TI(x,elType));
MUTG_INIT(rm);
for (usz i = 0; i < wia; i++) {
B cw = GetU(w, i); // assumed number from previous squeeze
usz c = WRAP(o2i64(cw), xn, { mut_pfree(rm, i*csz); thrF("⊏: Indexing out-of-bounds (%R∊𝕨, %H≡≢𝕩)", cw, x); });
mut_copyG(rm, i*csz, x, csz*c, csz);
CFRes f = cf_get(1, csz<<elWidthLogBits(xe));
MAKE_MUT(rm, ria); mut_init(rm, xe);
usz i = 0; f64 badw;
if (xe<el_B && elInt(we)) {
void* wp = tyany_ptr(w);
void* xp = tyany_ptr(x);
ux ri = 0;
switch(we) { default: UD;
case el_bit: for (; i<wia; i++) { i8 c =bitp_get(wp,i); if (c>=xn) { badw=c; goto bad1; } cf_call(f, rm->a, ri, xp, c*f.mul); ri+= f.mul; } // TODO something better
case el_i8: for (; i<wia; i++) { i8 c0=((i8* )wp)[i]; usz c = WRAP(c0, xn, { badw=c0; goto bad1; }); cf_call(f, rm->a, ri, xp, c*f.mul); ri+= f.mul; }
case el_i16: for (; i<wia; i++) { i16 c0=((i16*)wp)[i]; usz c = WRAP(c0, xn, { badw=c0; goto bad1; }); cf_call(f, rm->a, ri, xp, c*f.mul); ri+= f.mul; }
case el_i32: for (; i<wia; i++) { i32 c0=((i32*)wp)[i]; usz c = WRAP(c0, xn, { badw=c0; goto bad1; }); cf_call(f, rm->a, ri, xp, c*f.mul); ri+= f.mul; }
}
assert(!isVal(xf));
r = a(mut_fv(rm));
} else {
MUTG_INIT(rm);
for (; i < wia; i++) {
B cw = GetU(w, i); // assumed number from previous squeeze
usz c = WRAP(o2i64(cw), xn, { badw=o2fG(cw); goto bad1; });
mut_copyG(rm, i*csz, x, csz*c, csz);
}
r = a(withFill(mut_fv(rm), xf));
}
r = a(withFill(mut_fv(rm), xf));
goto setsh;
bad1:;
mut_pfree(rm, i*csz);
thrF("⊏: Indexing out-of-bounds (%f∊𝕨, %H≡≢𝕩)", badw, x);
}