Run-based 1-byte /⁼ implementation
This commit is contained in:
parent
092ba4167a
commit
e681f3c09a
@ -28,36 +28,31 @@ fn count{T}(tab:*usz, x:*T, n:u64, min_allowed:T) : T = {
|
|||||||
r0:u64 = 0 # Elements actually handled by vector case
|
r0:u64 = 0 # Elements actually handled by vector case
|
||||||
|
|
||||||
# Find range to check for suitability; return a negative if found
|
# Find range to check for suitability; return a negative if found
|
||||||
|
# Also record number of differences dc
|
||||||
|
# (double-counts at index vec but it doesn't need to be exact)
|
||||||
xv := *V~~x
|
xv := *V~~x
|
||||||
jv := load{xv}; mv := jv
|
jv := load{xv}; mv := jv; dc := -(jv != load{*V~~(x+1)})
|
||||||
@for (xv over _ from 1 to b) { jv = min{jv, xv}; mv = max{mv, xv} }
|
@for (xv, xp in *V~~(x-1) over _ from 1 to b) {
|
||||||
|
jv = min{jv, xv}; mv = max{mv, xv}
|
||||||
|
dc -= xp != xv
|
||||||
|
}
|
||||||
@for (x over _ from rv to r) { if (x<min_allowed) return{x}; if (x>mx) mx=x }
|
@for (x over _ from rv to r) { if (x<min_allowed) return{x}; if (x>mx) mx=x }
|
||||||
jt := vfold{min, jv}
|
jt := vfold{min, jv}
|
||||||
mt := vfold{max, mv}
|
mt := vfold{max, mv}
|
||||||
if (jt < min_allowed) return{jt}
|
if (jt < min_allowed) return{jt}
|
||||||
if (mt > mx) mx = mt
|
if (mt > mx) mx = mt
|
||||||
|
|
||||||
|
# Fast cases
|
||||||
|
dt := promote{u64, fold_addw{dc}}
|
||||||
nc := uT~~(mt - jt) # Number of counts to perform: last is implicit
|
nc := uT~~(mt - jt) # Number of counts to perform: last is implicit
|
||||||
if (nc <= 24*vbits/128) {
|
if (dt < b * (vec/2) and dt*8 < b * promote{u64,nc}) {
|
||||||
|
r0 = count_with_runs{V, vec, x, tab, r}
|
||||||
|
} else if (nc <= 24*vbits/128) {
|
||||||
r0 = rv
|
r0 = rv
|
||||||
j0 := promote{u64, uT~~jt} # Starting count
|
count_by_sum{T, V, [vec]uT, xv, b, tab, r0,
|
||||||
m := promote{u64, nc} # Number of iterations
|
promote{u64, uT~~jt}, # Starting count
|
||||||
total := trunc{usz, r0} # To compute last count
|
promote{u64, nc} # Number of iterations
|
||||||
def count_each{js, num} = {
|
|
||||||
j := @collect (k to num) trunc{T, js+k}
|
|
||||||
c := copy{length{j}, [vec]uT ** 0}
|
|
||||||
e := each{{j}=>V**j, j}
|
|
||||||
@for (xv over b) each{{c,e} => c -= xv == e, c, e}
|
|
||||||
def add_sum{c, j} = {
|
|
||||||
s := promote{usz, fold_addw{c}}
|
|
||||||
total -= s; inc{tab, j, s}
|
|
||||||
}
|
|
||||||
each{add_sum, c, j}
|
|
||||||
}
|
}
|
||||||
m4 := m / 4
|
|
||||||
@for (j4 to m4) count_each{j0 + 4*j4, 4}
|
|
||||||
@for (j from 4*m4 to m) count_each{j0 + j, 1}
|
|
||||||
inc{tab, trunc{T, j0 + m}, trunc{usz,total}}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Scalar fallback and cleanup
|
# Scalar fallback and cleanup
|
||||||
@ -68,4 +63,52 @@ fn count{T}(tab:*usz, x:*T, n:u64, min_allowed:T) : T = {
|
|||||||
mx
|
mx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Sum comparisons against each value (except one) in the range
|
||||||
|
def count_by_sum{T, V, U, xv, b, tab, r0, j0, m} = {
|
||||||
|
total := trunc{usz, r0} # To compute last count
|
||||||
|
def count_each{js, num} = {
|
||||||
|
j := @collect (k to num) trunc{T, js+k}
|
||||||
|
c := copy{length{j}, U**0}
|
||||||
|
e := each{{j}=>V**j, j}
|
||||||
|
@for (xv over b) each{{c,e} => c -= xv == e, c, e}
|
||||||
|
def add_sum{c, j} = {
|
||||||
|
s := promote{usz, fold_addw{c}}
|
||||||
|
total -= s; inc{tab, j, s}
|
||||||
|
}
|
||||||
|
each{add_sum, c, j}
|
||||||
|
}
|
||||||
|
m4 := m / 4
|
||||||
|
@for (j4 to m4) count_each{j0 + 4*j4, 4}
|
||||||
|
@for (j from 4*m4 to m) count_each{j0 + j, 1}
|
||||||
|
inc{tab, trunc{T, j0 + m}, trunc{usz,total}}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Count adjacent equal elements at once, breaking at w-element groups
|
||||||
|
# May read up to index r from x, hitting one element that's not counted
|
||||||
|
def count_with_runs{V, vec, x, tab, r} = {
|
||||||
|
def w = width{ux}
|
||||||
|
m0:ux = 1 << (w-1) # Last element in each chunk ends a run
|
||||||
|
bw := r / w
|
||||||
|
@for (i to bw) {
|
||||||
|
xo := x + i*w
|
||||||
|
m := m0
|
||||||
|
# Mark the end of each run
|
||||||
|
@unroll (j to w / vec) {
|
||||||
|
def jv = j*vec
|
||||||
|
def lv{k} = load{*V~~(xo + k)}
|
||||||
|
m |= promote{ux, homMask{lv{jv} != lv{jv+1}}} << jv
|
||||||
|
}
|
||||||
|
# Iterate over runs
|
||||||
|
jp:usz = - usz~~1
|
||||||
|
while (m > m0) @unroll (2) {
|
||||||
|
j := trunc{usz, ctz{m}}
|
||||||
|
inc{tab, load{xo, j}, j - jp}
|
||||||
|
jp = j; m &= m-1
|
||||||
|
}
|
||||||
|
# One step if popc{m} was odd, reducing branch mispredictions above
|
||||||
|
inc{tab, load{xo, w-1}, ((w-1) - jp) & -trunc{usz, m>>(w-1)}}
|
||||||
|
}
|
||||||
|
bw * w
|
||||||
|
}
|
||||||
|
|
||||||
export{'avx2_count_i8', count{i8}}
|
export{'avx2_count_i8', count{i8}}
|
||||||
|
|||||||
@ -15,6 +15,8 @@ def ntyp{S, ...S2, T if w128{T}} = merge{S, 'q', ...S2, '_', nty{T}}
|
|||||||
def ntyp{S, ...S2, T if w64{T}} = merge{S, ...S2, '_', nty{T}}
|
def ntyp{S, ...S2, T if w64{T}} = merge{S, ...S2, '_', nty{T}}
|
||||||
def ntyp0{S, T} = merge{S, '_', nty{T}}
|
def ntyp0{S, T} = merge{S, '_', nty{T}}
|
||||||
|
|
||||||
|
def __neg{a:T if nvecu{T}} = T~~(-ty_s{T}~~a)
|
||||||
|
|
||||||
def __lt{a:T, 0 if nvecs{T} or nvecf{T}} = emit{ty_u{T}, ntyp{'vcltz', T}, a}
|
def __lt{a:T, 0 if nvecs{T} or nvecf{T}} = emit{ty_u{T}, ntyp{'vcltz', T}, a}
|
||||||
def __le{a:T, 0 if nvecs{T} or nvecf{T}} = emit{ty_u{T}, ntyp{'vclez', T}, a}
|
def __le{a:T, 0 if nvecs{T} or nvecf{T}} = emit{ty_u{T}, ntyp{'vclez', T}, a}
|
||||||
def __gt{a:T, 0 if nvecs{T} or nvecf{T}} = emit{ty_u{T}, ntyp{'vcgtz', T}, a}
|
def __gt{a:T, 0 if nvecs{T} or nvecf{T}} = emit{ty_u{T}, ntyp{'vcgtz', T}, a}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user