72 lines
2.3 KiB
Plaintext
72 lines
2.3 KiB
Plaintext
include './base'
|
||
include './vecfold'
|
||
|
||
if_inline (hasarch{'SSE2'}) {
  # Sum all elements of vector v.
  # v is zipped with an all-zero vector of the same type by mzip128, the
  # results of that are added together with fold{+}, and vfold{+} reduces
  # the remaining vector to a single value.
  # NOTE(review): presumably mzip128 yields wider elements so the sum cannot
  # wrap at the original element width - confirm against './base'.
  fn sum_vec{T}(v:T) = {
    def zeros = T**0
    def zipped = mzip128{v, zeros}
    def combined = fold{+, zipped}
    vfold{+, combined}
  }

  # Element sum of a vector, defined only for vectors with u8 elements
  def fold_addw{v:T=[_](u8)} = sum_vec{T}(v)
}
|
||
|
||
# Add v to the element at index ind of ptr: load it, add, store it back
def inc{ptr, ind, v} = {
  def current = load{ptr, ind}
  store{ptr, ind, v + current}
}

# Two-argument form: add 1 to the element at index ind of ptr
def inc{ptr, ind} = inc{ptr, ind, 1}
|
||
|
||
# Write counts /⁼x to tab and return ⌈´x
# For each of the n elements v at x, tab[v] is incremented. The result is the
# largest element seen, starting from min_allowed (so min_allowed is returned
# when n is 0). If an element below min_allowed is found, the function returns
# early with a value below min_allowed: either that element itself (scalar
# tail) or the minimum of the block's full vectors. Blocks completed before
# that point have already been added to tab; the offending block has not.
fn count{T}(tab:*usz, x:*T, n:u64, min_allowed:T) : T = {
  def vbits = arch_defvw        # Vector width in bits
  def vec = vbits/width{T}      # Elements per vector
  def uT = ty_u{T}
  def V = [vec]T
  def block = (2048*8) / vbits  # Target vectors per block
  def b_max = block + block/4   # Last block max length
  assert{b_max < 1<<width{T}}   # Don't overflow count in vector section
  mx:T = min_allowed  # Maximum of x
  i:u64 = 0
  while (i < n) {
    # Number of elements to handle in this iteration
    r:u64 = n - i; if (r > vec*b_max) r = vec*block
    b := r / vec  # Vector case does b full vectors if it runs
    rv:= b * vec
    r0:u64 = 0    # Elements actually handled by vector case

    # Range check for the elements after the last full vector. Any value
    # below min_allowed is returned right away.
    @for (x over _ from rv to r) { if (x<min_allowed) return{x}; if (x>mx) mx=x }

    # With fewer than vec elements there is no full vector to work on, so
    # skip the vector section entirely: loading one anyway would read past
    # the r valid elements and let those lanes decide the min/max below.
    if (b > 0) {
      # Find range to check for suitability; return a negative if found
      xv := *V~~x
      jv := load{xv}; mv := jv
      @for (xv over _ from 1 to b) { jv = min{jv, xv}; mv = max{mv, xv} }
      jt := vfold{min, jv}
      mt := vfold{max, mv}
      if (jt < min_allowed) return{jt}
      if (mt > mx) mx = mt

      nc := uT~~(mt - jt)  # Number of counts to perform: last is implicit
      # Vector counting compares every vector against each candidate value,
      # so it is used only when the value range of this block is small.
      if (nc <= 24*vbits/128) {
        r0 = rv
        j0 := promote{u64, uT~~jt}  # Starting count
        m := promote{u64, nc}       # Number of iterations
        total := trunc{usz, r0}     # To compute last count

        # Count occurrences of the num values js, js+1, ... in the b vectors
        # and add each count to tab.
        def count_each{js, num} = {
          j := @collect (k to num) trunc{T, js+k}  # Values to count
          c := copy{length{j}, [vec]uT ** 0}       # One counter vector per value
          e := each{{j}=>V**j, j}                  # Each value broadcast to a vector
          # NOTE(review): presumably xv == e gives an all-ones lane on a
          # match, so subtracting it adds one to that lane - confirm.
          @for (xv over b) each{{c,e} => c -= xv == e, c, e}
          # Reduce a counter vector to a scalar, add it to tab[j], and take it
          # off total so that what is left is the count of the final value.
          def add_sum{c, j} = {
            s := promote{usz, fold_addw{c}}
            total -= s; inc{tab, j, s}
          }
          each{add_sum, c, j}
        }

        # Four values per pass while possible, then the rest one at a time
        m4 := m / 4
        @for (j4 to m4) count_each{j0 + 4*j4, 4}
        @for (j from 4*m4 to m) count_each{j0 + j, 1}
        # The last value was never compared: its count is whatever remains
        inc{tab, trunc{T, j0 + m}, trunc{usz,total}}
      }
    }

    # Scalar fallback and cleanup
    @for (x over _ from r0 to r) inc{tab, x}
    i += r
    x += r
  }
  mx
}
|
||
|
||
# Export count specialized to i8 elements under the name 'avx2_count_i8'.
# NOTE(review): the name mentions AVX2, but count takes its vector width from
# arch_defvw - confirm this is only built for AVX2 targets.
export{'avx2_count_i8', count{i8}}
|