uCBQN/src/singeli/src/search.singeli
2024-06-18 07:46:31 -04:00

652 lines
20 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

include './base'
include './mask'
include './vecfold'
include './hashtab'
# findFirst{C, M, F, ...v1}: walk the equal-length tuples in v1 position by position.
# At the first position where C{...} holds for that position's elements (the final
# position is taken without testing C), assign M{...} of those elements to `args`
# and jump to `exit`; the result is F{...args}.
def findFirst{C, M, F, ...v1} = {
def exit = makelabel{}
# Uninitialized storage shaped like M applied to the first element of each tuple
def args = undef{M{...each{select{., 0}, v1}}}
def am = length{select{v1,0}}
# `last` comes from iota{am} == am-1, so it is set only at the final position
each{{last, ...v2} => {
if (last or C{...v2}) {
each{=, args, M{...v2}}
goto{exit}
}
}, iota{am} == am-1, ...v1}
# The final position always jumps to `exit`, so control cannot fall through to here
unreachable{}
setlabel{exit}
F{...args}
}
# search{E, x, n, OP}: scan n elements of type E at x in vectors of [bulk]E.
# OP is applied to every loaded vector; when any lane of the OR of the results is
# set under the loop mask M, findFirst picks the first vector with a set lane and
# `end` returns from the enclosing function with that lane's element index.
# Evaluates to n when the loop finishes without such a return.
def search{E, x, n:(u64), OP} = {
def bulk = arch_defvw/width{E}
def VT = [bulk]E
# NOTE(review): makeBranch is defined elsewhere; presumably it wraps the body so that
# it can be invoked with an (index, comparison vector) pair — confirm
def end = makeBranch{
tup{u64, ty_u{VT}},
{i,c} => return{i*bulk + promote{u64, ctzX{homMaskX{c}}}}
}
# Unroll factor passed to muLoop is 1 when arch_defvw>=256, otherwise 2
@muLoop{bulk, tern{arch_defvw>=256, 1, 2}}(x in tup{VT,*E~~x}, M in 'm' over is to n) {
eq:= each{OP, x}
if (homAny{M{tree_fold{|, eq}}}) {
findFirst{
{i,c} => homAny{c},
{i,c} => tup{i,c},
end,
is, eq
}
}
}
n
}
# searchOne{A, E}: run `search` over len elements of type E at x, comparing each
# vector for equality with a broadcast of e0. e0 arrives with type A and is
# converted with cast_i when A differs from E. Returns whatever `search` yields.
fn searchOne{A, E}(x:*void, e0:A, len:u64) : u64 = {
  def needle = if (not A==E) cast_i{E, e0} else e0
  search{E, x, len, {chunk:V} => chunk == V**needle}
}
# Lane-wise test for negative zero: compares the u64 reinterpretation of x with that of a broadcast -0.0
def isNegZero{x:T} = re_el{u64,x} == re_el{u64, T ** -f64~~0}
# Run `search` over len f64 values at x, matching lanes that are negative zero
# or that compare unequal to themselves (v!=v).
fn searchNormalizable{}(x:*f64, len:u64) : u64 = {
  search{f64, x, len, {v:V} => (v!=v) | isNegZero{v}}
}
# Copy len f64 values from x to r one vector at a time, adding 0 to each vector
# before it is stored. Returns 1 as soon as a vector has a lane (under the loop
# mask) that compares unequal to itself; returns 0 once everything is copied.
fn copyOrdered{}(r:*f64, x:*f64, len:u64) : u1 = {
  def bulk = arch_defvw/width{f64}
  def V = [bulk]f64
  @maskedLoop{bulk}(v in tup{V,x}, put in tup{'g',r}, M in 'm' over i to len) {
    if (homAny{M{v!=v}}) return{1}
    put{v + V**0}
  }
  0
}
# Export the search and copy kernels only when has_simd is set.
# The searchOne variants take the searched-for value as u64 (integer element types) or f64.
(if (has_simd) {
export{'simd_search_u8', searchOne{u64, u8}}
export{'simd_search_u16', searchOne{u64, u16}}
export{'simd_search_u32', searchOne{u64, u32}}
export{'simd_search_f64', searchOne{f64, f64}}
export{'simd_search_normalizable', searchNormalizable{}}
export{'simd_copy_ordered', copyOrdered{}}
})
# In-register bit table
# arch_vec{T}: vector of T spanning the default vector width arch_defvw
def arch_vec{T} = [arch_defvw/width{T}]T
def TI = i8 # Table values
def VI = arch_vec{TI}
# Selects the vector overloads of the bittab_* generators below
def simd_bittab = hasarch{'SSSE3'}
# bittab_init{tab, z}: set all 256 table entries (type TI) at tab to z.
# Scalar form: one element per iteration.
def bittab_init{tab, z} = {
  @for (slot in *TI~~tab over 256) slot = z
}
# Vector form, chosen when simd_bittab holds: broadcast z once and store it
# across the table with an unrolled loop of 256/vcount{VI} vector writes.
def bittab_init{tab, z if simd_bittab} = {
  fill:= VI**z
  @unroll (vslot in *VI~~tab over 256/vcount{VI}) vslot = fill
}
# bittab_selector{loadtab}: build a bit lookup from two table vectors returned by loadtab{}.
# Evaluates to tup{selector, reload}:
# - selector{x}: for each byte of x, tests one bit of a table byte and returns the
#   per-lane results as a homMask
# - reload{}: refreshes t0,t1 from loadtab{}
def bittab_selector{loadtab} = {
def [nv]_ = VI
{t0, t1}:= loadtab{}
low:= VI**7
hi4:= VI**(-(1<<4))
# b holds the single-bit values 1<<0 .. 1<<7, repeated across the vector
b := VI~~make{[nv]u8, 1 << (iota{nv} & 7)}
def selector{x} = {
# Low 3 bits of x are cleared and the rest shifted right by 3 (as u32 lanes), then offset by hi4
# NOTE(review): presumably `sel` yields 0 for indices with the top bit set, so that
# exactly one of the t0/t1 lookups contributes for each byte — confirm against sel's definition
top := hi4 + VI~~((arch_vec{u32}~~(x&~low))>>3)
byte:= sel{[16]i8, t0, hi4^top} | sel{[16]i8, t1, top}
# Bit within the selected byte, chosen by the low 3 bits of x
mask:= sel{[16]i8, b, x & low}
homMask{(mask & byte) == mask}
}
def reload{} = { tup{t0,t1} = loadtab{} }
tup{selector, reload}
}
# readbytes{vtab}{}: table loader for bittab_selector. Reads the table at vtab as VI
# vectors, reduces each to a homMask, and packs the masks into two VI vectors.
def readbytes{vtab}{} = {
# k lanes per VI vector; l vectors of vtab are read per side
def [k]_ = VI; def l = 128/k
def side{i} = {
def U = arch_vec{ty_u{k}}
def m = @collect (vtab over _ from i to i+l) homMask{vtab}
# If U has more lanes than masks were collected, the mask list is repeated to fill it
VI~~make{U, if (vcount{U}>l) merge{m,m} else m}
}
# Sides start at vector offsets 0 and l
each{side, l*iota{2}}
}
# Look up bits from table
# Scalar form: for each of the n bytes at x0, load the table entry at that byte
# value and place it at the matching bit of a u64 result word; one word is
# written to r0 per group of up to 64 input bytes.
def bittab_lookup{x0:(*void), n:(u64), r0:(*void), tab:(*void)} = {
  src:= *u8~~x0
  entries:= *TI~~tab
  dst:= *u64~~r0
  left:= n
  while (left > 0) {
    cnt:= left; if (cnt>64) cnt=64
    word:u64 = 0
    @for (b in src over j to cnt) word|= u64~~promote{i64,load{entries,b}} & ((u64~~1)<<j)
    store{dst, 0, word}
    src+=cnt; left-=cnt; ++dst
  }
}
# Vector form, chosen when simd_bittab holds: build a selector from the table
# and write one mask per input vector, for cdiv{n,lanes} vectors.
def bittab_lookup{x0:(*void), n:(u64), r0:(*void), tab:(*void) if simd_bittab} = {
  def {lookup, _} = bittab_selector{readbytes{*VI~~tab}}
  def [lanes]_ = VI
  @for (xv in *VI~~x0, bits in *ty_u{lanes}~~r0 over cdiv{n,lanes}) bits = lookup{xv}
}
# Fill table with t (0 or -1) at all bytes in x0
# Stop early if the sum u reaches 0, indicating all bytes in the table
# are equal: by the time it's checked at least one has been set to t,
# so they're all t
# Fill r0 depending on mode:
# - 'none': ignore
# - 'mask': Mark Firsts of x0
# - 'unique': Deduplicate of x0
# - 'index': First index of value x at r0+x
# Arguments: x0/n are the input bytes and their count, tab is the byte table,
# u is the caller's u8 running total (updated in place), t is the value written
# to the table. Evaluates to the final u.
def do_bittab{x0:(*void), n:(u64), tab:(*void), u:(u8), t, mode, r0} = {
def rbit = mode == 'mask'
def rval = mode == 'unique'
def rind = mode == 'index'
# Writes a result word only in 'mask' mode
def storebit{i, v:T} = if (rbit) store{*T~~r0, i, v}
if (rbit or rval) assert{same{t,0}}
btab:= *i8~~tab
# Set table entry x to t given its previous value v; i is the input index.
# Evaluates to the previous value v.
def settab_sub{x, v, i} = {
if (rval) store{*u8~~r0, u, x}
if (rind and v!=t) store{r0, x, cast_i{usz, i}}
u+= u8~~i8~~(t - v) # u tracks the total of btab
store{btab, x, t}
v
}
def settab1{x, i} = settab_sub{x, -1 - t, i} # Known new
def settab{x, i} = settab_sub{x, load{btab, x}, i} # General case
# Same as settab{x, i}, with the previous value sign-extended to the width of T
def settab{T, x, i} = T~~promote{ty_s{T}, settab{x, i}}
x:= *u8~~x0
if (not simd_bittab) {
# Scalar path: process the input in groups of up to 64 bytes
rem:= n
@for (i to cdiv{n,64}) {
k:= rem; if (k>64) k=64
rw:u64 = 0
@for (x over j to k) {
new:= settab{u64, x, i*64+j} # Index usually unused
if (rbit) rw|= new & ((u64~~1)<<j)
}
storebit{i, rw}
x+=k; rem-=k
}
} else {
# Do first few values with a scalar loop
# Avoids the cost of ever loading the table into vectors for n<=48
i:u64 = 32; if (n<=48) i=n
def [k]_ = VI; def uk = ty_u{k}; def ik = ty_s{k}
{rw,rv} := undef{tup{u64,uk}} # Bit results, used if rbit
if (rbit) rw = 0
@for (x over j to i) {
new:= settab{u64, x, j}
if (rbit) rw|= new & ((u64~~1)<<j)
}
storebit{0, rw}
if ((mode == 'none' or rind) and u == 0) return{u} # Won't ever trigger (m != 0)!
def done = makelabel{}
def {bitsel, reload_tab} = bittab_selector{readbytes{*VI~~tab}}
xv:= *VI~~x0
# Vector path: one VI vector of input per iteration, starting at element i
while (i < n) {
i0:= i; iw:= i/k
v:= load{xv, iw}
m:= bitsel{v} # Mask of possibly-new values
if (not same{t,0}) m^= uk~~promote{ik, t}
i+= k
# Drop mask bits for lanes at or beyond n in the final, partial vector
if (i > n) m&= (~uk~~0)>>((-n)%k)
# Any new values?
if (m == 0) {
storebit{iw, m}
} else {
# Add values to the table and filter m
if (rbit) rv = m
im:= i0 + ctz{m}
xi:= load{x, im}
settab1{xi, im}
if ((m&(m-1)) != 0) { # More bits than one
# Filter out values equal to the previous, or first new
def pind = (iota{k}&15) - 1
prev:= make{VI, each{max{0, .}, pind}}
e:= ~homMask{v == VI**TI~~xi}
e&= base{2,pind<0} | ~homMask{v == sel{[16]i8, v, prev}}
if (rbit) rv&= e | -m # Don't remove first bit
m&= e
# Handle each remaining candidate with a table update, lowest bit first
while (m != 0) {
im:= i0 + ctz{m}
new:= settab{uk, load{x, im}, im}
m1:= m-1; m&= m1 # Clear low bit
if (rbit) rv&= m1 | new # Clear if not new
}
}
storebit{iw, rv}
if (u == 0) { # All bytes seen
if (rbit) @for (r in *uk~~r0 over _ from iw+1 to cdiv{n,k}) r = 0
goto{done}
}
# The table changed, so refresh the selector's copies of it
reload_tab{}
}
}
setlabel{done}
}
# When u==0 isn't immediately tested, first result can be overwritten
if (rval) store{*u8~~r0, 0, load{*u8~~x0}}
u
}
# Mark Firsts for n bytes at x0: initialize the table at tab to -1, then run
# do_bittab in 'mask' mode with table value 0, writing result bits to r0.
fn simd_mark_firsts_u8(x0:*void, n:u64, r0:*void, tab:*void) : void = {
  bittab_init{tab, -1}
  tsum:u8 = 0 # running total handed to do_bittab
  do_bittab{x0, n, tab, tsum, 0, 'mask', r0}
}
# Deduplicate n bytes at x0 (n must be nonzero): initialize the table at tab to -1,
# then run do_bittab in 'unique' mode with table value 0, writing values to r0.
# Returns the result count derived from the u8 total.
fn simd_deduplicate_u8(x0:*void, n:u64, r0:*void, tab:*void) : u64 = {
  assert{n != 0}
  bittab_init{tab, -1}
  count:u8 = 0
  do_bittab{x0, n, tab, count, 0, 'unique', r0}
  # count is a u8, so a total of 256 reads back as 0; going through count-1 maps it to 256.
  # With n != 0 the result is in 1 to 256.
  1 + promote{u64, count-1}
}
# Non-inlined wrapper around do_bittab in 'none' mode: sets table entries for the n bytes
# at x0 to t, starting from running total u, and returns the updated total.
fn fill_bittab(x0:*void, n:u64, tab:*void, u:u8, t:i8) : u8 = {
do_bittab{x0, n, tab, u, t, 'none', 0}
}
# Member-of on bytes: for each of the nx bytes at x0, write a result bit to r0
# telling whether it occurs among the nw bytes at w0 (nw must be nonzero).
# tab is scratch space for the 256-entry byte table.
fn simd_member_u8(w0:*void, nw:u64, x0:*void, nx:u64, r0:*void, tab:*void) : void = {
  assert{nw > 0}
  # Reverse lookup when x is under a quarter of w's length: the table starts at -1,
  # entries for x's bytes are set to 0 first, then entries for w's bytes go back to -1
  rev_lookup:u1 = nx < nw/4
  bittab_init{tab, -promote{i8,rev_lookup}}
  tsum:u8 = 0 # Sum of table, either 0 or 256
  if (rev_lookup) tsum = fill_bittab(x0, nx, tab, tsum, 0)
  tsum = fill_bittab(w0, nw, tab, tsum, -1)
  if (tsum != 0) {
    bittab_lookup{x0, nx, r0, tab}
  } else {
    # All found: every result word is all ones
    @for (word in *u64~~r0 over cdiv{nx,64}) word = maxvalue{u64}
  }
}
# Build a 256-entry index table at i0 (element type I) for the nw bytes at w0:
# every entry starts as nw, then do_bittab in 'index' mode stores indices into it.
# Uses the same reverse-lookup setup as simd_member_u8 when nx < nw/4.
fn simd_index_tab_u8{I}(w0:*void, nw:u64, x0:*void, nx:u64, tab:*void, i0:*void) : void = {
  rev_lookup:u1 = nx < nw/4
  bittab_init{tab, -promote{i8,rev_lookup}}
  slots:= *I~~i0
  @for (s in slots over 256) s = cast_i{I, nw}
  tsum:u8 = 0
  if (rev_lookup) tsum = fill_bittab(x0, nx, tab, tsum, 0)
  do_bittab{w0, nw, tab, tsum, -1, 'index', slots}
}
# Exported byte-table kernels; the index table is instantiated with usz entries
export{'simd_mark_firsts_u8', simd_mark_firsts_u8}
export{'simd_deduplicate_u8', simd_deduplicate_u8}
export{'simd_member_u8', simd_member_u8}
export{'simd_index_tab_u8', simd_index_tab_u8{usz}}
# acc{unr, init}: accumulator set for an unrolled loop. Holds one single register (a0v)
# and a tuple of unr registers (a1), all starting at init. Operations:
# - 'get': the single register
# - 'tr', F: fold the unr registers with F into the single register
# - 'upd', is, F: apply F to the single-register tuple when `is` has length 1,
#   otherwise to the unr-register tuple
def acc{unr, init:T} = {
a0v := init
def a0 = tup{a0v}
def a1 = @collect(unr) { reg:=init }
def op{S=='get'} = a0v
# Overwrites a0v; the previous a0v is not part of the fold
def op{S=='tr', F} = { a0v = tree_fold{F, a1} }
def op{S=='upd', is, F} = {
if (length{is}==1) a0 = F{a0}
else a1 = F{a1}
}
}
# isI64{x}: lane-wise test on an f64 vector, used by getRange to reject inputs.
# AARCH64: the value survives conversion to i64 and back.
def isI64{x:T=[_](f64) if hasarch{'AARCH64'}} = x == cvt{f64, cvt{i64, x}}
# SSE4.1: the value equals its floor and its magnitude is at most 1<<53.
def isI64{x:T=[_](f64) if hasarch{'SSE4.1'}} = (x==floor{x}) & (abs{x}<=T**(1<<53))
# maskBlend{b, x, M}: evaluates to x by default; when M{0} holds, blends b and x
# using M converted to homogeneous bits.
# NOTE(review): presumably M{0} is true exactly for masks that disable some lanes — confirm in mask.singeli
def maskBlend{b:T, x:T, M} = x
def maskBlend{b:T, x:T, M if M{0}} = homBlend{b, x, M{T, 'to homogeneous bits'}}
# getRange{E}: compute the minimum and maximum of the n (nonzero) elements of type E
# at x0 and store them as i64 in res[0] and res[1].
# Returns 1 on success. For E==f64, returns 0 without storing when some element
# fails the integer test (isI64 in the vector path, an i64 round trip in the scalar path).
fn getRange{E}(x0:*void, res:*i64, n:u64) : u1 = {
  assert{n>0}
  x:= *E~~x0
  # Both extremes start from the first element
  min1:E = *x
  max1:E = *x
  if (has_simd and (not E==f64 or hasarch{'AARCH64'} or hasarch{'SSE4.1'})) {
    def bulk = arch_defvw/width{E}
    def VT = [bulk]E
    def unr = tern{E==f64 and hasarch{'X86_64'}, 1, 2}
    # One accumulator register per unrolled vector, matching the factor given to muLoop;
    # each accumulator is seeded from its own scalar
    def minA = acc{unr, VT**min1}
    def maxA = acc{unr, VT**max1}
    # The third muLoop argument folds the unrolled accumulators into the single one
    @muLoop{bulk, unr, {} => { minA{'tr',min}; maxA{'tr',max} }}(cx in tup{VT,x}, M in 'm' over is to n) {
      if (E==f64 and homAny{M{tree_fold{|, each{{c} => ~isI64{c}, cx}}}}) return{0}
      # maskBlend keeps the old accumulator value in lanes excluded by M
      minA{'upd', is, {a} => eachx{maskBlend, a, each{min, a, cx}, M}}
      maxA{'upd', is, {a} => eachx{maskBlend, a, each{max, a, cx}, M}}
    }
    min1 = vfold{min, minA{'get'}}
    max1 = vfold{max, maxA{'get'}}
  } else {
    @for (x over i to n) {
      if (E==f64 and rare{x != emit{f64, '', emit{i64, '', x}}}) return{0}
      min1 = min{min1, x}
      max1 = max{max1, x}
    }
  }
  store{res, 0, cast_i{i64, min1}}
  store{res, 1, cast_i{i64, max1}}
  1
}
# Export getRange instantiated for i8, i16, i32 and f64 under a single name
exportT{'simd_getRangeRaw', each{getRange, tup{i8,i16,i32,f64}}}
# Hash tables
# Result element type: i8 for '∊', i32 for the other primitives
def rty{name} = if (to_prim{name}=='∊') i8 else i32
# Types of the extra `links` arguments: one *u32 for '⊒', none otherwise
def ity{name} = (to_prim{name}=='⊒')**(*u32)
# hashtab{T, name}: hash-table search for the primitive selected by to_prim{name}
# ('∊', '⊐' or '⊒'), with hash type T. Results are written through rpi.
# Returns whether i == m at the end, i.e. whether all insertions finished;
# paths that give up jump to `abort` before that holds.
fn hashtab{T, name}(rpi:*rty{name}, iv:*void, mi:usz, fv:*void, ni:usz, links:ity{name}) = {
# iv,mi/ip,m - searched-in; fv,ni/fp,n - searched-for; may get swapped around & back
def prim = to_prim{name}
def U = if (prim=='∊') usz else u32
m := cast_i{U,mi}; n := cast_i{U,ni}
def wt = width{T}
ip := *T~~iv; fp := *T~~fv; rp := if (prim=='∊') rpi else *u32~~rpi
# Exchange the roles of the two arguments (pointers and lengths)
def swap_sides{} = each{{a,b}=>{t:=a; a=b; b=t}, tup{ip,m}, tup{fp,n}}
swap:u1 = n+(1024*(prim!='⊒')) < (if (prim!='∊') m else m-m/4)
if (swap) swap_sides{}
log := clzc{m}
# Max size
msl := max{clzc{(m+m/2)|4}+1, min{ux~~14,clzc{m+n/4}}}
msz := usz~~1 << msl
# Starting log-size (try_vec_memb requires size>4)
sl := msl; if (msl>=14) sl = 12+(msl&1)
b:U = 64 # Block size
# Filling e slots past the end requires e*(e+1)/2 collisions, so
# m entries with <2 each can fill <sqrt(4*m)
def cc_stop = 2*cast_i{u64,m}
ext := promote{usz, tern{m<=b, max{m,U~~4}, b + (U~~1 << (log/2 + 1))}}
# maxh is the value marking empty slots in the hash table
maxh := T~~maxvalue{T}
# An auxiliary u32 table is allocated alongside the hashes for every primitive except '∊'
def aux = prim!='∊'
def {tabs, sz, sh, div_thresh, hash_resize, hash_free} = hash_alloc{
sl, msz, ext, tup{T, ...aux**u32}, tup{maxh, ...aux**'any'}, 0, 1
}
def {hash,...vals} = tabs
def abort = makelabel{}
i:U = 0 # Saved to determine if hashing finished
# insert_all{set_tab, set_maxh, dup, ...uniq}: insert hashes of ip[i..m) in blocks of
# at most b. set_maxh{h==maxh, i, j} runs before probing, set_tab{j, h} after storing,
# and control jumps to label `dup` when an equal hash is already present.
# After each block the collision counter decides whether to resize or abort.
def insert_all{set_tab, set_maxh, dup, ...uniq} = {
cc:u64 = 0 # Collision counter
while (i < m) {
e := tern{m-i>b, i+b, m}
while (i < e) {
# '⊒' inserts elements in reverse order
def ii = if (prim!='⊒') i else m-i-1
h := hash_val{load{ip,ii}}; j := h>>sh
set_maxh{h==maxh, i, j}
kv := each{load{.,j}, tabs}; def {k,...kr} = kv
# Robin Hood insertion
j0 := j; je := j # Save value; end of chain (insert at j)
if (k != maxh) {
if (k == h) goto{dup}
do {
++je; knv := each{load{.,je}, tabs}; def {kn,..._} = knv
def c = promote{T, h >= k}
j += c
if (kn == h) goto{dup}
each{store{.,je-c,.}, tabs, kv}
each{=, kv, knv}
} while (k != maxh)
cc += cast_i{u64, je-j0}
}
each{{u} => { u += promote{usz,h!=maxh} }, uniq}
store{hash, j, h}
set_tab{j, h}
++i
}
# Check collision counter and possibly resize
def p64 = promote{u64,.}
dc := p64{cc} - p64{div_thresh{i}}
if (tern{i<m, i64~~dc>=0, sz<msz}) {
if (sz == msz) goto{abort}
rdc := p64{m-i}*dc # 0 if i==m, no need to recompute
mm := p64{i}*p64{m+i}
def recheck = setlabel{}
if (cc>=cc_stop or p64{n/4}*p64{cc} + rdc >= mm>>(5+log-(wt-sh))) {
hash_resize{cc, 2} # Factor of 4
if (i==m and sz<msz) goto{recheck}
if (cc >= cc_stop) { i=0; goto{abort} }
}
}
}
}
# First slot at or after maxh's home position that holds maxh
def get_end{} = {
end := maxh>>sh
while (load{hash,end}!=maxh) ++end
end
}
# Move j forward by ext when found is set, placing a maxh hash in the extension area
def sequester_maxh{j, found} = { j += cast_i{T, ext &- found} }
# Copy the sequestered entry of tab back to the slot returned by get_end{}
def unsequester_maxh{tab} = {
store{tab, get_end{}, load{tab, maxh>>sh + cast_i{T,ext}}}
}
# Remove the entry at slot j: shift following entries back by one, stopping after
# writing maxh, which happens at an empty slot or at an entry already in its home slot
def hash_remove{j,h} = {
do {
jp:=j; ++j
h=load{hash,j}
if (h>>sh == j) h = maxh
store{hash, jp, h}
each{{t} => store{t, jp, load{t,j}}, vals}
} while (h!=maxh)
}
# For each element of ip, remove its hash from the table if present. Once nothing
# is left (uniq==0 and no maxh), fill rp with 1 and jump to abort.
def memb_remove{uniq, has_maxh} = {
@for (ip over m) {
h := hash_val{ip}; j := h>>sh
def shortcut = makelabel{}
if (h == maxh) {
if (uniq==0) goto{shortcut}
has_maxh = 0
} else {
k := load{hash,j}
if (k <= h) {
while (k < h) { ++j; k = load{hash,j} }
if (k == h) {
--uniq
if (uniq==0 and not has_maxh) {
setlabel{shortcut}
@for (rp over n) rp = 1
goto{abort}
}
hash_remove{j,h}
}
}
}
}
}
# For each element of ip found in the table, store its index i into rp at the
# position saved in inds, then remove the entry; stops once uniq reaches 0.
def ind_rev_lookup{uniq, has_maxh} = {
def shortcut = makelabel{}
@for (ip over i to m) {
h := hash_val{ip}; j := h>>sh
k := load{hash,j}
if (k <= h) {
while (k < h) { ++j; k = load{hash,j} }
# Reads has_maxh and clears it, so a maxh match is accepted only once
def had_maxh{} = { s:=has_maxh; has_maxh=0; s }
if (k == h and (h<maxh or had_maxh{})) {
def {inds} = vals
store{rp, load{inds, j}, i}
--uniq; if (uniq==0) goto{shortcut}
hash_remove{j,h}
}
}
}
setlabel{shortcut}
}
def prog_lookup{swap} = { # Progressive Index-of lookup
# rev swaps its two arguments when `swap` is set
def rev{a,b} = if (swap) tup{b,a} else tup{a,b}
memset{*u32~~rp, ...rev{m,n}}
c := m; def shortcut = makelabel{}
@for (fp over i to n) {
h := hash_val{fp}; j := h>>sh
k := load{hash,j}
if (k <= h) {
while (k < h) { ++j; k = load{hash,j} }
if (k==h) {
def {inds} = vals; def {link} = links
ti := load{inds, j}
if (ti > 0) {
store{rp, ...rev{i, m-ti}}
--c; if (c==0) goto{shortcut}
# Advance to the next entry in the link chain for this value
ti = load{link, ti}
if (ti > 0 or h==maxh) store{inds, j, ti}
else hash_remove{j,h}
}
}
}
}
setlabel{shortcut}
}
# For each searched-for element: probe from its home slot while the stored hash is
# below h, then store get_res{k==h, j} as the result
def lookup_all{get_res} = {
@for (rp, fp over n) {
h := hash_val{fp}; j := h>>sh
k := undefined{T}; while ((k=load{hash,j}) < h) ++j
rp = get_res{k==h, j}
}
}
# Select the algorithm by primitive and by which auxiliary tables exist
match (prim, ...vals, ...links) {
{('∊')} => {
has_maxh:u1 = 0
uniq:usz = 0 # Uniques inserted
def dup = makelabel{}
insert_all{
{j, h} => setlabel{dup},
{found, i, j} => has_maxh |= found,
dup, uniq
}
if (swap) {
swap_sides{}; i=m # i==m return value is kind of dumb
memb_remove{uniq,has_maxh} # Remove values in ip from hash
}
try_vec_memb{T, hash, sz, sh, maxh, has_maxh, swap, rp, fp, n, abort}
end := get_end{} # Clip trailing maxh if it shouldn't be in the table
if (has_maxh) ++end # Don't clip
lookup_all{{found, j} => promote{i8, swap ^ (found & (j<end))}}
}
{('⊐'), inds} => if (not swap) {
has_maxh:u1 = 0
ind_maxh:u32 = 0
def dup = makelabel{}
def set{j, h} = {
store{inds, j, m-i} # So it can be cleared with one &- in get{}
setlabel{dup}
}
def set_maxh{found, i, j} = {
ind_maxh |= i &- (found&~has_maxh)
has_maxh |= found
}
insert_all{set, set_maxh, dup}
store{inds, get_end{}, (m-ind_maxh) &- has_maxh}
lookup_all{{found, j} => m - (load{inds, j} &- found)}
} else { # swap
# After insert_all, position i in rp contains:
# - ≠𝕨, if i is the first occurrence of its value in 𝕩, or
# - j-≠𝕩, where j<i is the index of the first occurrence
uniq:usz = 0
has_maxh:u1 = 0
ri:u32 = 0 # Placed in rp (should be scoped to insert_all loop body)
def dup = makelabel{}
def set_maxh{found, i, j} = {
sequester_maxh{j, found}
ri = cast_i{u32, n} # Initialize to not-found
if (found & has_maxh) goto{dup}
has_maxh |= found
}
def set{j, h} = {
store{inds, j, i}
# The block under the constant-false test is entered only by a jump to `dup`
if (u1~~0) {
setlabel{dup}
ri = load{inds, j} - m
}
store{rp, i, ri}
}
insert_all{set, set_maxh, dup, uniq}
# Lookup places correct result index at each first occurrence
swap_sides{}; i=m
if (has_maxh) { ++uniq; unsequester_maxh{inds} }
ind_rev_lookup{uniq, has_maxh}
# Propagate to later occurrences
@for (r in rp over i to n) r = load{rp, min{i, u32~~r+n}}
}
{('⊒'), inds, link} => {
store{link,0,0}
store{inds, maxh>>sh + cast_i{T,ext}, 0}
def dup = makelabel{}
def set{j, h} = {
store{inds, j, load{inds,j} &- (h==maxh)}
setlabel{dup}
# Push i+1 onto the front of the chain kept for slot j
i1 := i+1
store{link, i1, load{inds,j}}
store{inds, j, i1}
}
def set_maxh{found, i, j} = sequester_maxh{j, found}
insert_all{set, set_maxh, dup}
unsequester_maxh{inds}
if (not swap) prog_lookup{0} else prog_lookup{1}
}
}
setlabel{abort}
hash_free{}
i == m # Whether it finished
}
# try_vec_memb: optional vectorized membership lookup over a filled hash table.
# The default overload does nothing. The overload for T==u32 with SSE4.2 measures the
# largest offset of any entry from its home slot and, if it fits within 1 or 2 vectors,
# answers all n lookups with vector compares and jumps to `done`; otherwise it
# falls through and the caller continues.
def try_vec_memb{..._} = {}
def try_vec_memb{T==u32, hash, sz, sh, maxh, has_maxh, swap, rp, fp, n, done
if hasarch{'SSE4.2'}} = {
# Hash h wants bin h>>sh, so the offset for h in slot i is (in infinite-precision ints)
# i-h>>sh = i+((1<<sh-1)-h)>>sh = (((i+1)<<sh-1)-h)>>sh
# We maintain io = (i+1)<<sh-1
def vl = 4; def V = [vl]T
assert{sz > 4}
assert{sz%vl == 0}
io := make{V, each{{k} => T~~k<<sh - 1, 1+iota{vl}}}
id := V**(T~~vl<<sh)
mv := V**0
# min{h,io} clamps the subtraction at 0, so empty slots (maxh) contribute nothing
@for (h in *V~~hash over sz/vl) { mv=max{mv, io-min{h,io}}; io+=id }
max_off := vfold{max, mv} >> sh
# sz==1<<sh, so when i>=sz the above overflows
# And we have to handle maxh specially anyway
def mx{i,h} = { max_off = max{max_off, i-h>>sh} }
i:=cast_i{T,sz}; h:T=i; while ((h=load{hash,i})<maxh) { mx{i,h}; ++i }
if (has_maxh) mx{i,maxh}
# When swap is set, vswap has the low bit of each result byte set, to flip results
vswap := base{256,vl**1} * promote{T,swap}
# memb{test}: write all n results using test{x}, then jump to `done`
def memb{test} = {
def R = i8; def rw = width{R}; def u = width{T}/rw
l := n/u
@for (r in *T~~rp over i to l) {
c := V**0 # Will combine u results to avoid folding too much
@unroll (f in fp+u*i over a to u) c |= V**(1<<(rw*a)) & test{f}
r = vswap ^ vfold{|, c}
}
# Remaining results that do not fill a whole T are written one at a time
@for (rp, fp over _ from u*l to n) rp = promote{R, swap ^ homAny{test{fp}}}
goto{done}
}
# try{nv}: use nv vectors per lookup if every entry lies within nv*vl slots of its home
def try{nv} = {
if (max_off < nv*vl) {
# Avoid matching maxh if it shouldn't be in the table
clear := V**(maxh &- ~has_maxh)
@for (hv in *V~~(hash + maxh>>sh) over vl) hv = hv &~ (hv == clear)
# Test against nv vectors
def test{x} = {
h := hash_val{x}; vh := V**h
def any{{...r,a}} = any{r}|a; def any{{a}}=a
any{@collect (k in *V~~(hash+h>>sh) over nv) vh == k}
}
memb{test}
}
}
each{try, tup{1,2}}
}
# Export hashtab for the given primitive name with u32 and u64 hash types,
# under the symbol 'si_' + name + '_c2_hash'
def exp_hash{name} = {
exportT{merge{'si_',name,'_c2_hash'}, each{hashtab{., name}, tup{u32,u64}}}
}
# `names` is defined outside this view
each{exp_hash, names}