uCBQN/src/singeli/src/count.singeli
2024-11-14 22:45:43 +02:00

72 lines
2.3 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

include './base'
include './vecfold'
if_inline (hasarch{'SSE2'}) {
  # Total of all lanes of v. The vector is combined with an all-zero vector
  # through mzip128, the resulting vectors are added to each other, and the
  # sum vector is reduced with + by vfold.
  # (mzip128 and vfold come from the includes; presumably mzip128 zero-extends
  # each lane to a wider element so the additions have headroom - confirm.)
  fn sum_vec{T}(v:T) = {
    def zeros = T**0
    def halves = mzip128{v, zeros}
    vfold{+, fold{+, halves}}
  }
  # Lane sum for vectors with u8 elements, implemented by sum_vec
  def fold_addw{v:T=[_](u8)} = sum_vec{T}(v)
}
# Add v to the element of ptr at index ind, storing the result back in place
def inc{ptr, ind, v} = {
  def prev = load{ptr, ind}
  store{ptr, ind, v + prev}
}
# Two-argument form: the amount added is 1
def inc{ptr, ind} = inc{ptr, ind, 1}
# Write counts /⁼x to tab and return ⌈´x
#
# For each of the n elements v of x, tab[v] is incremented. Counts are added
# to whatever tab already holds (see inc), so tab is presumably zeroed by the
# caller - confirm.
#
# Arguments:
#   tab          table of usz counters, indexed by element value
#   x            input elements of type T
#   n            number of elements in x
#   min_allowed  smallest acceptable element value
# Result:
#   If every element is >= min_allowed: the larger of min_allowed and the
#   maximum element (mx starts at min_allowed, so n=0 gives min_allowed).
#   Otherwise: some element that is < min_allowed, returned as soon as the
#   block containing it is scanned. Counts from earlier blocks have already
#   been written to tab at that point.
#
# The input is processed in blocks. Each block is first scanned for its
# minimum and maximum; if the value range is small, counting is done with
# vector comparisons against each candidate value, otherwise by a scalar loop.
fn count{T}(tab:*usz, x:*T, n:u64, min_allowed:T) : T = {
# Compile-time parameters derived from the default vector width
def vbits = arch_defvw
# Elements of type T per vector
def vec = vbits/width{T}
# Unsigned type with the same width as T
def uT = ty_u{T}
def V = [vec]T
def block = (2048*8) / vbits # Target vectors per block
def b_max = block + block/4 # Last block max length
# Each lane of a counter vector gains at most 1 per input vector, and a block
# has at most b_max vectors, so a lane count must fit in width{T} bits.
assert{b_max < 1<<width{T}} # Don't overflow count in vector section
mx:T = min_allowed # Maximum of x
i:u64 = 0
while (i < n) {
# Number of elements to handle in this iteration
# Take everything that remains if it fits in b_max vectors; otherwise take a
# regular block and leave the rest for later iterations.
r:u64 = n - i; if (r > vec*b_max) r = vec*block
b := r / vec # Vector case does b full vectors if it runs
rv:= b * vec
r0:u64 = 0 # Elements actually handled by vector case
# Find range to check for suitability; return an element below min_allowed if found
# NOTE(review): the first vector is loaded even when b is 0 (r < vec, which
# can only occur when n < vec); lanes past the r valid elements would then
# feed jt/mt. Presumably callers guarantee n >= vec - confirm.
xv := *V~~x
jv := load{xv}; mv := jv
# Lane-wise running minimum (jv) and maximum (mv) over the b full vectors;
# xv inside the loop names the current vector, shadowing the pointer.
@for (xv over _ from 1 to b) { jv = min{jv, xv}; mv = max{mv, xv} }
# Elements after the last full vector are range-checked one at a time
@for (x over _ from rv to r) { if (x<min_allowed) return{x}; if (x>mx) mx=x }
# Reduce the lane-wise results to the block's scalar minimum and maximum
jt := vfold{min, jv}
mt := vfold{max, mv}
if (jt < min_allowed) return{jt}
if (mt > mx) mx = mt
nc := uT~~(mt - jt) # Number of counts to perform: last is implicit
# Vector counting costs one pass over the block per group of values, so it
# is used only when the range is small relative to the vector width.
if (nc <= 24*vbits/128) {
r0 = rv
j0 := promote{u64, uT~~jt} # Starting count
m := promote{u64, nc} # Number of iterations
total := trunc{usz, r0} # To compute last count
# Count occurrences of the num consecutive values starting at js within the
# b full vectors, add them to tab, and subtract them from total.
def count_each{js, num} = {
# j: the values being counted; c: one zeroed counter vector per value;
# e: each value broadcast to a full vector for comparison
j := @collect (k to num) trunc{T, js+k}
c := copy{length{j}, [vec]uT ** 0}
e := each{{j}=>V**j, j}
# Subtracting the comparison result from the counter: presumably a matching
# lane compares as all ones (-1), so this adds 1 per match - confirm.
@for (xv over b) each{{c,e} => c -= xv == e, c, e}
# Sum one counter vector across lanes and record the result
def add_sum{c, j} = {
s := promote{usz, fold_addw{c}}
total -= s; inc{tab, j, s}
}
each{add_sum, c, j}
}
# Values jt .. mt-1 are counted explicitly: four at a time, then the leftovers
m4 := m / 4
@for (j4 to m4) count_each{j0 + 4*j4, 4}
@for (j from 4*m4 to m) count_each{j0 + j, 1}
# Whatever remains in total is the count of the largest value (j0 + m)
inc{tab, trunc{T, j0 + m}, trunc{usz,total}}
}
# Scalar fallback and cleanup
# r0 is rv if the vector case ran (only the tail is left) and 0 otherwise
# (the whole block is counted here).
@for (x over _ from r0 to r) inc{tab, x}
# Advance past the elements handled in this iteration
i += r
x += r
}
mx
}
# Export the i8 instantiation of count under the name 'avx2_count_i8'.
# NOTE(review): the name says avx2, but the vector width used by count comes
# from arch_defvw - confirm the build only uses this export for AVX2 targets.
export{'avx2_count_i8', count{i8}}