// uCBQN/src/builtins/search.c — snapshot dated 2023-02-24 15:53:11 +02:00 (359 lines, 13 KiB, C).
// Note: this file intentionally contains non-ASCII Unicode characters (BQN glyphs such as
// ∊ ⊐ ⊒ 𝕨 𝕩) in comments and error-message strings.
// Dyadic search functions: Member of (∊), Index of (⊐), Progressive index of (⊒)
// 𝕨⊐unit or unit∊𝕩: scalar loop with early-exit
// SHOULD use simd
// SHOULD unify implementations
// 𝕩⊒unit or 𝕨⊒𝕩 where 1≥≠𝕩: defer to 𝕨⊐𝕩
// Both arguments with rank≥1:
// High-rank inputs:
// Convert to a (lower-rank) typed integer array if cells are ≤62 bits
// COULD have special hashing for equal type >64 bit cells, skipping squeezing
// COULD try conditionally squeezing ahead-of-time, and not squeezing in bqn_hash
// p⊐n & n∊p with short p & long n: n⊸=¨ p
// bitarr⊐𝕩: more special arithmetic
// SHOULD have impls for long p & short n
// ≤16-bit elements: lookup tables
// Character elements: reinterpret as integer elements
// Otherwise, generic hashtable
// SHOULD handle up to 64 bit cells via proper typed hash tables
// SHOULD have fast path when cell sizes or element types don't match
// SHOULD properly handle ¯0
#include "../core.h"
#include "../utils/hash.h"
#include "../utils/talloc.h"
#include "../utils/calls.h"
// C2i(F, W, X): invoke dyadic F with W wrapped as an i32 atom,
// i.e. shorthand for C2(F, m_i32(W), X).
#define C2i(F, W, X) C2(F, m_i32(W), X)

// Dyadic primitives (eq, ne, or, add, sub, mul) invoked through C2/C2i below
// to compute some search results arithmetically.
extern B eq_c2 (B t, B w, B x);
extern B ne_c2 (B t, B w, B x);
extern B or_c2 (B t, B w, B x);
extern B add_c2(B t, B w, B x);
extern B sub_c2(B t, B w, B x);
extern B mul_c2(B t, B w, B x);
// Number of distinct bit patterns one element of type eltype can hold,
// 2^(element width in bits). In this file it is only called with bit, i8
// and i16 element types; a 64-bit element width would overflow the shift.
static u64 elRange(u8 eltype) {
  u64 widthBits = 1ull << elWidthLogBits(eltype);
  return 1ull << widthBits;
}
// TABLE(IN, FOR, TY, INIT, SET): lookup-table search for small integer elements.
// Expects in the enclosing scope: arrays IN and FOR, their element types IN##e and
// FOR##e, their lengths IN##ia and FOR##ia, and a B variable r for the result.
//  - Allocates a TY table wide enough for the value ranges of both IN and FOR;
//    tab points at its middle so it can be indexed by signed element values.
//  - Initializes every entry FOR may read to INIT. For a short i16 FOR only the
//    entries it actually reads are touched. This is skipped entirely when IN and
//    FOR are the same array, since SET then covers every entry that is read.
//    For a boolean FOR (ft==2) only tab[0] and tab[1] are initialized.
//  - Walks IN from last to first storing SET (which may use the index i) at each
//    value's entry, so the first occurrence wins. IN is read as i8 if IN##e is
//    el_i8 and as i16 otherwise, so IN must not be boolean.
//  - decG(IN), then fills r with the entry for each element of FOR: through
//    bit_sel for a boolean FOR, otherwise into a new TY array after which FOR
//    is decG'd.
#define TABLE(IN, FOR, TY, INIT, SET) \
usz it = elRange(IN##e); /* Range of writes */ \
usz ft = elRange(FOR##e); /* Range of lookups */ \
usz t = it>ft? it : ft; /* Table allocation width */ \
TALLOC(TY, tab0, t); TY* tab = tab0 + t/2; \
usz m=IN##ia, n=FOR##ia; \
void* ip = tyany_ptr(IN); \
void* fp = tyany_ptr(FOR); \
/* Initialize */ \
if (IN.u != FOR.u) { \
if (FOR##e==el_i16 && n<ft/(64/sizeof(TY))) \
{ for (usz i=0; i<n; i++) tab[((i16*)fp)[i]]=INIT; } \
else { TY* to=tab-(ft/2-(ft==2)); for (i64 i=0; i<ft; i++) to[i]=INIT; } \
} \
/* Set */ \
if (IN##e==el_i8) { for (usz i=m; i--; ) tab[((i8 *)ip)[i]]=SET; } \
else { for (usz i=m; i--; ) tab[((i16*)ip)[i]]=SET; } \
decG(IN); \
/* Lookup */ \
if (FOR##e==el_bit) { \
r = bit_sel(FOR, m_i32(tab[0]), m_i32(tab[1])); \
} else { \
TY* rp; r = m_##TY##arrc(&rp, FOR); \
if (FOR##e==el_i8){ for (usz i=0; i<n; i++) rp[i]=tab[((i8 *)fp)[i]]; } \
else { for (usz i=0; i<n; i++) rp[i]=tab[((i16*)fp)[i]]; } \
decG(FOR); \
} \
TFREE(tab0);
// Argument pair produced by splitCells.
typedef struct {
  B n; // the argument whose cells are looked up
  B p; // the argument whose cells are searched in
} B2;
// Pack each cell of the typed array x into a single integer-like element.
// The first co axes of x index the cells, and the result has shape co↑≢x.
// x first goes through widenBitArr(x, co). The cell width in bits after that
// selects the slice type (8/16/32/64 → t_i8slice/t_i16slice/t_i32slice/
// t_f64slice); any other width is UD. The result is a slice over the widened
// array's data, which it keeps as its parent (r->p).
// csz0 is the cell width in bits computed by the caller. It is not used here;
// the width is recomputed after widening.
static NOINLINE B toIntCell(B x, ux csz0, ur co) {
  assert(TI(x,elType)!=el_B);
  (void)csz0;
  usz ria = shProd(SH(x), 0, co);
  // A shape object is only built for co>1. Start from NULL so the co==1 call to
  // arr_shSetU below never passes an uninitialized pointer (reading an
  // uninitialized automatic variable is undefined behavior).
  ShArr* rsh = NULL;
  if (co>1) { rsh=m_shArr(co); shcpy(rsh->a,SH(x),co); }
  B r0 = widenBitArr(x, co);
  usz csz = shProd(SH(r0),co,RNK(r0)) << elWidthLogBits(TI(r0,elType));
  u8 t;
  if      (csz==8)  t = t_i8slice;
  else if (csz==16) t = t_i16slice;
  else if (csz==32) t = t_i32slice;
  else if (csz==64) t = t_f64slice;
  else UD;
  TySlice* r = m_arr(sizeof(TySlice), t, ria);
  r->p = a(r0);
  r->a = tyany_ptr(r0);
  if (co>=1) arr_shSetU((Arr*)r, co, rsh);
  else arr_shVec((Arr*)r);
  return taga(r);
}
// Copy the typed array x (element type xe) to one whose elements are 2^lb bits wide:
// lb=0 → bit, lb=3/4/5 → 8/16/32-bit (numeric or character, following xe),
// lb=6 → f64. Any other lb is UD.
static NOINLINE B cpyToElLog(B x, u8 xe, u8 lb) {
  switch (lb) {
    case 6:
      return taga(cpyF64Arr(x));
    case 5:
      if (elNum(xe)) return taga(cpyI32Arr(x));
      return taga(cpyC32Arr(x));
    case 4:
      if (elNum(xe)) return taga(cpyI16Arr(x));
      return taga(cpyC16Arr(x));
    case 3:
      if (elNum(xe)) return taga(cpyI8Arr(x));
      return taga(cpyC8Arr(x));
    case 0:
      return taga(cpyBitArr(x));
    default: UD;
  }
}
// Prepare the arguments of ∊ ⊐ ⊒ when the searched-in argument p is not a list.
// n is the argument whose cells are looked up (𝕨 for ∊, 𝕩 for ⊐ and ⊒) and p the
// argument searched in; mode selects the glyph/argument names used in errors.
// Both are converted so that each major cell of p, and each cell of n with p's
// cell shape, becomes one element:
//  - If the shapes agree and both sides are typed with the same kind (numeric or
//    character), and a cell is 1..62 bits at the wider of the two element widths,
//    the narrower side is copied to that width and each cell is packed into one
//    integer by toIntCell.
//    NOTE(review): the 62 (not 64) limit presumably keeps packed cells that end up
//    in 64-bit f64 slices away from NaN/-0 bit patterns — confirm.
//  - Otherwise n and p go through toKCells(n, nco) and toCells(p).
// Throws if p is an atom or has rank 0, or if n's rank is below p's cell rank.
static NOINLINE B2 splitCells(B n, B p, u8 mode) { // 0:∊ 1:⊐ 2:⊒
#define SYMB (mode==0? "∊" : mode==1? "⊐" : "⊒")
#define ARG_N (mode? "𝕩" : "𝕨")
#define ARG_P (mode? "𝕨" : "𝕩")
if (isAtm(p) || RNK(p)==0) thrF("%U: %U cannot have rank 0", SYMB, ARG_P);
ur pr = RNK(p);
// An atom n is treated as a rank-0 unit
if (isAtm(n)) n = m_atomUnit(n);
ur nr = RNK(n);
if (nr < pr-1) thrF("%U: Rank of %U must be at least the cell rank of %U (%H ≡ ≢𝕨, %H ≡ ≢𝕩)", SYMB, ARG_N, ARG_P, mode?p:n, mode?n:p);
// pcr: rank of p's major cells; nco: number of leading axes of n that index its cells
ur pcr = pr-1;
ur nco = nr-pcr;
// Packing is only attempted when n's trailing axes equal p's cell shape
if (nco>0 && eqShPart(SH(n)+nco, SH(p)+1, pcr)) {
u8 ne = TI(n,elType);
u8 pe = TI(p,elType);
if (ne<el_B && pe<el_B && elNum(ne)==elNum(pe)) {
usz csz = arr_csz(p);
u8 neb = elWidthLogBits(ne);
u8 peb = elWidthLogBits(pe);
u8 meb = neb>peb? neb : peb;
// rb: cell size in bits at the common element width; rb==0 (empty cells) is excluded
ux rb = csz<<meb;
if (rb!=0 && rb<=62) {
// Same array on both sides: convert it once and return it as both n and p
if (n.u == p.u) { decG(p); n=toIntCell(n,rb,1); return (B2){.n=n, .p=incG(n)}; }
// Bring the narrower side up to the common element width meb
if (neb!=meb) n = cpyToElLog(n, ne, meb);
else if (peb!=meb) p = cpyToElLog(p, pe, meb);
return (B2){.n=toIntCell(n,rb,nco), .p=toIntCell(p,rb,1)};
}
}
}
return (B2){.n=toKCells(n,nco), .p=toCells(p)};
#undef ARG_N
#undef ARG_P
#undef SYMB
}
// Narrow an i32 index result r to i8 or i16 when every value fits.
// Callers pass count = ≠𝕨, the largest value the result can contain
// (the "not found" index), so checking count alone is sufficient.
static B reduceI32Width(B r, usz count) {
  if (count <= I8_MAX)  return taga(cpyI8Arr(r));
  if (count <= I16_MAX) return taga(cpyI16Arr(r));
  return r;
}
// 𝕨⊐𝕩 (Index of): for each cell of 𝕩 with the shape of a major cell of 𝕨, the
// index of the first equal major cell of 𝕨, or ≠𝕨 if there is none.
// w and x are released before returning.
B indexOf_c2(B t, B w, B x) {
  if (RARE(!isArr(w) || RNK(w)!=1)) {
    // Non-list 𝕨: convert both sides to lists of cells (see splitCells)
    B2 sc = splitCells(x, w, 1);
    w = sc.p;
    x = sc.n;
  }
  if (!isArr(x) || RNK(x)==0) {
    // A single value to look up: linear scan with early exit; rank-0 result
    usz wia = IA(w);
    B el = isArr(x)? IGetU(x,0) : x;
    i32 res = wia;
    u8 we = TI(w,elType);
    if (we<el_B) {
      // Typed 𝕨: convert el to 𝕨's element representation. A value that can't be
      // represented in 𝕨's element type can't occur in 𝕨, so it is not found.
      void* wp = tyany_ptr(w);
      u8 v8; u16 v16; u32 v32; f64 v64f;
      switch(we) { default: UD;
        case el_bit: if (!q_bit(el)) goto notfound; res = bit_find(wp,wia,o2bG(el)); goto checked;
        case el_i8:  if (!q_i8 (el)) goto notfound; v8 = ( u8)( i8)o2iG(el); goto chk8;
        case el_i16: if (!q_i16(el)) goto notfound; v16 = (u16)(i16)o2iG(el); goto chk16;
        case el_i32: if (!q_i32(el)) goto notfound; v32 = (u32)(i32)o2iG(el); goto chk32;
        case el_f64: if (!q_f64(el)) goto notfound; v64f= o2fG(el); goto chk64f;
        case el_c8:  if (!q_c8 (el)) goto notfound; v8 = ( u8) o2cG(el); goto chk8;
        case el_c16: if (!q_c16(el)) goto notfound; v16 = (u16) o2cG(el); goto chk16;
        case el_c32: if (!q_c32(el)) goto notfound; v32 = (u32) o2cG(el); goto chk32;
      }
      chk8:   for (usz i = 0; i < wia; i++) if ((( u8*)wp)[i]== v8 ) { res=i; break; } goto checked;
      chk16:  for (usz i = 0; i < wia; i++) if (((u16*)wp)[i]==v16 ) { res=i; break; } goto checked;
      chk32:  for (usz i = 0; i < wia; i++) if (((u32*)wp)[i]==v32 ) { res=i; break; } goto checked;
      chk64f: for (usz i = 0; i < wia; i++) if (((f64*)wp)[i]==v64f) { res=i; break; } goto checked;
    } else {
      SGetU(w)
      for (usz i = 0; i < wia; i++) if (equal(GetU(w,i), el)) { res = i; goto checked; }
    }
    checked:; notfound:;
    decG(w); dec(x);
    B r = m_vec1(m_f64(res));
    arr_shAtm(a(r)); // replaces shape
    return r;
  } else {
    u8 we = TI(w,elType); usz wia = IA(w);
    u8 xe = TI(x,elType); usz xia = IA(x);
    if (wia == 0) { B r=taga(arr_shCopy(allZeroes(xia), x)); decG(w); decG(x); return r; }
    if (elNum(we) && elNum(xe)) { tyEls:
      if (we==el_bit) {
        // Boolean 𝕨: only w[0] (first at index 0) and ¬w[0] (first at index i,
        // or absent if i==wia) can be found; compute the indices arithmetically
        u64* wp = bitarr_ptr(w);
        u64 w0 = 1 & wp[0];
        u64 i = bit_find(wp, wia, !w0); decG(w);
        if (i!=wia) incG(x);
        B r = C2i(mul, wia , C2i(ne, w0, x)) ;
        return i==wia? r : C2(sub, r, C2i(mul, wia-i, C2i(eq, !w0, x)));
      }
      if (wia<=(we<=el_i16?4:16) && xia>16) {
        // Short 𝕨, long 𝕩: r ← (x≠w[i]) × (1+r), innermost element last
        SGetU(w);
        #define XEQ(I) C2(ne, GetU(w,I), incG(x))
        B r = XEQ(wia-1);
        for (usz i=wia-1; i--; ) r = C2(mul, XEQ(i), C2i(add, 1, r));
        #undef XEQ
        decG(w); decG(x); return r;
      }
      if (xia+wia>20 && we<=el_i16 && xe<=el_i16) {
        // ≤16-bit elements: lookup table indexed by value
        B r;
        TABLE(w, x, i32, wia, i)
        return reduceI32Width(r, wia);
      }
    } else if (elChr(we) && elChr(xe)) {
      // Character data is reinterpreted as signed integers of the same width
      // (c8→i8, c16→i16, c32→i32). That mapping only agrees between the two sides
      // when they share a width: c8 U+0080 reads as i8 -128, but c16 U+0080 reads
      // as i16 +128 and c16 U+FF80 as i16 -128. Widen a c8 side to c16 first so
      // the lookup table sees consistent values.
      if (we!=xe && we<=el_c16 && xe<=el_c16) {
        if (we==el_c8) { w = taga(cpyC16Arr(w)); we = el_c16; }
        else           { x = taga(cpyC16Arr(x)); xe = el_c16; }
      }
      we-= el_c8-el_i8; xe-= el_c8-el_i8;
      goto tyEls;
    }
    // Generic case: hash map from each 𝕨 element to its first index
    i32* rp; B r = m_i32arrc(&rp, x);
    H_b2i* map = m_b2i(64);
    SGetU(x)
    SGetU(w)
    for (usz i = 0; i < wia; i++) {
      bool had; u64 p = mk_b2i(&map, GetU(w,i), &had);
      if (!had) map->a[p].val = i;
    }
    for (usz i = 0; i < xia; i++) rp[i] = getD_b2i(map, GetU(x,i), wia);
    free_b2i(map); decG(w); decG(x);
    return reduceI32Width(r, wia);
  }
}
// Shared rank-0 boolean arrays holding 0 and 1 respectively; built in search_init.
B enclosed_0;
B enclosed_1;
// 𝕨∊𝕩 (Member of): for each cell of 𝕨 with the shape of a major cell of 𝕩,
// 1 if it equals some major cell of 𝕩 and 0 otherwise.
// w and x are released before returning.
B memberOf_c2(B t, B w, B x) {
  if (isAtm(x) || RNK(x)!=1) {
    // Non-list 𝕩: convert both sides to lists of cells (see splitCells)
    B2 sc = splitCells(w, x, 0);
    w = sc.n;
    x = sc.p;
  }
  if (isAtm(w)) goto single;
  ur wr = RNK(w);
  if (wr>0) goto many;
  // Rank-0 𝕨: look up its only element; the result is rank 0 too
  B w0 = IGet(w, 0);
  dec(w);
  w = w0;
  goto single;

  B r;
  single: { // one value: linear scan with early exit, shared rank-0 result
    usz xia = IA(x);
    SGetU(x)
    for (usz i = 0; i < xia; i++) if (equal(GetU(x, i), w)) { r = incG(enclosed_1); goto dec_wx; }
    r = incG(enclosed_0);
    dec_wx:; dec(w);
    goto dec_x;
  }
  many: {
    u8 we = TI(w,elType); usz wia = IA(w);
    u8 xe = TI(x,elType); usz xia = IA(x);
    if (xia == 0) { r=taga(arr_shCopy(allZeroes(wia), w)); decG(w); goto dec_x; }
    if (elNum(we) && elNum(xe)) { tyEls:
      #define WEQ(V) C2(eq, incG(w), V)
      if (xe==el_bit) {
        // Boolean 𝕩: only x[0] and, if present, ¬x[0] can match
        u64* xp = bitarr_ptr(x);
        u64 x0 = 1 & xp[0];
        r = WEQ(m_usz(x0));
        if (bit_has(xp, xia, !x0)) r = C2(or, r, WEQ(m_usz(!x0)));
        decG(w); goto dec_x;
      }
      if (xia<=(xe==el_i16?8:16) && wia>16) {
        // Short 𝕩, long 𝕨: OR together the comparisons 𝕨=x[i]
        SGetU(x);
        r = WEQ(GetU(x,0));
        for (usz i=1; i<xia; i++) r = C2(or, r, WEQ(GetU(x,i)));
        decG(w); goto dec_x;
      }
      #undef WEQ
      if (xia+wia>20 && we<=el_i16 && xe<=el_i16) {
        // ≤16-bit elements: lookup table indexed by value.
        // TABLE disposes of x and w itself and assigns r.
        TABLE(x, w, i8, 0, 1)
        return taga(cpyBitArr(r));
      }
    } else if (elChr(we) && elChr(xe)) {
      // Character data is reinterpreted as signed integers of the same width
      // (c8→i8, c16→i16, c32→i32). That mapping only agrees between the two sides
      // when they share a width: c8 U+0080 reads as i8 -128, but c16 U+0080 reads
      // as i16 +128 and c16 U+FF80 as i16 -128. Widen a c8 side to c16 first so
      // the lookup table sees consistent values.
      if (we!=xe && we<=el_c16 && xe<=el_c16) {
        if (we==el_c8) { w = taga(cpyC16Arr(w)); we = el_c16; }
        else           { x = taga(cpyC16Arr(x)); xe = el_c16; }
      }
      we-= el_c8-el_i8; xe-= el_c8-el_i8;
      goto tyEls;
    }
    // Generic case: hash set of 𝕩's elements
    H_Sb* set = m_Sb(64);
    SGetU(x) SGetU(w)
    bool had;
    for (usz i = 0; i < xia; i++) mk_Sb(&set, GetU(x,i), &had);
    u64* rp; r = m_bitarrc(&rp, w);
    for (usz i = 0; i < wia; i++) bitp_set(rp, i, has_Sb(set, GetU(w,i)));
    free_Sb(set); decG(w);
    goto dec_x;
  }
  dec_x:;
  decG(x);
  return r;
}
// 𝕨⊒𝕩 (Progressive index of): like 𝕨⊐𝕩, but each index of 𝕨 is handed out at
// most once. Successive equal cells of 𝕩 receive successive equal cells of 𝕨,
// and ≠𝕨 once those are used up. w and x are released before returning.
B count_c2(B t, B w, B x) {
  if (RARE(!isArr(w) || RNK(w)!=1)) {
    // Non-list 𝕨: convert both sides to lists of cells (see splitCells)
    B2 sc = splitCells(x, w, 2);
    w = sc.p;
    x = sc.n;
  }
  // With at most one value to look up, ⊒ gives the same result as ⊐
  if (!isArr(x) || IA(x)<=1) return C2(indexOf, w, x);
  u8 we = TI(w,elType); usz wia = IA(w);
  u8 xe = TI(x,elType); usz xia = IA(x);
  i32* rp; B r = m_i32arrc(&rp, x);
  // wnext[i]: the next index after i whose 𝕨 element equals w[i], or wia if none.
  // wnext[wia]=wia makes "not found" a fixed point.
  TALLOC(usz, wnext, wia+1);
  wnext[wia] = wia;
  if (we<=el_i16 && xe<=el_i16) {
    if (we==el_bit) { w = toI8Any(w); we = TI(w,elType); }
    if (xe==el_bit) { x = toI8Any(x); xe = TI(x,elType); }
    el8or16:;
    usz it = elRange(we); // Range of writes
    usz ft = elRange(xe); // Range of lookups
    usz t = it>ft? it : ft; // Table allocation width
    TALLOC(i32, tab0, t); i32* tab = tab0 + t/2; // indexed by signed element value
    usz m=wia, n=xia;
    void* ip = tyany_ptr(w);
    void* fp = tyany_ptr(x);
    // Initialize every entry 𝕩 can read to wia ("not found")
    if (xe==el_i16 && n<ft/(64/sizeof(i32))) {
      for (usz i=0; i<n; i++) tab[((i16*)fp)[i]]=wia;
    } else {
      // Offset the base pointer rather than indexing with i-ft/2: ft is
      // unsigned, so that subtraction would wrap to a huge index.
      i32* to = tab - ft/2;
      for (usz i=0; i<ft; i++) to[i]=wia;
    }
    // Set: walk 𝕨 backwards, pushing each index onto its value's list, so that
    // tab[v] is the first index holding v and wnext chains the later ones
    #define SET(T) for (usz i=m; i--; ) { i32* p=tab+((T*)ip)[i]; wnext[i]=*p; *p=i; }
    if (we==el_i8) { SET(i8) } else { SET(i16) }
    #undef SET
    // Lookup: take the head of each value's list and advance it
    #define GET(T) for (usz i=0; i<n; i++) { i32* p=tab+((T*)fp)[i]; *p=wnext[rp[i]=*p]; }
    if (xe==el_i8) { GET(i8) } else { GET(i16) }
    #undef GET
    TFREE(tab0);
  } else if (we>=el_c8 && we<=el_c16 && xe>=el_c8 && xe<=el_c16) {
    // Characters are reinterpreted as signed integers of the same width
    // (c8→i8, c16→i16). That only agrees between the two sides when they share
    // a width: c8 U+0080 reads as i8 -128, but c16 U+0080 reads as i16 +128 and
    // c16 U+FF80 as i16 -128. Widen the c8 side to c16 first.
    if (we!=xe) {
      if (we==el_c8) { w = taga(cpyC16Arr(w)); we = el_c16; }
      else           { x = taga(cpyC16Arr(x)); xe = el_c16; }
    }
    we-= el_c8-el_i8; xe-= el_c8-el_i8;
    goto el8or16;
  } else {
    // Generic case: hash map from value to the head of its index list
    H_b2i* map = m_b2i(64);
    SGetU(x)
    SGetU(w)
    for (usz i = wia; i--; ) {
      bool had; u64 p = mk_b2i(&map, GetU(w,i), &had);
      wnext[i] = had ? map->a[p].val : wia;
      map->a[p].val = i;
    }
    for (usz i = 0; i < xia; i++) {
      bool had; u64 p = getQ_b2i(map, GetU(x,i), &had);
      usz j = wia;
      if (had) { j = map->a[p].val; map->a[p].val = wnext[j]; }
      rp[i] = j;
    }
    free_b2i(map);
  }
  TFREE(wnext); decG(w); decG(x);
  return reduceI32Width(r, wia);
}
// Build the shared rank-0 boolean arrays enclosed_0 (bit 0) and enclosed_1
// (bit 1), and hand each one to gc_add.
void search_init(void) {
  u64* bits;
  Arr* arr;

  arr = m_bitarrp(&bits, 1);
  arr_shAtm(arr);
  *bits = 0;
  enclosed_0 = taga(arr);
  gc_add(enclosed_0);

  arr = m_bitarrp(&bits, 1);
  arr_shAtm(arr);
  *bits = ~0ULL;
  enclosed_1 = taga(arr);
  gc_add(enclosed_1);
}