Version of count_i32_i32 without large blocks

This commit is contained in:
Marshall Lochbaum 2024-11-15 17:37:11 -05:00
parent 3b103aadd0
commit 11117fcc67

View File

@ -54,7 +54,7 @@ fn count{T}(tab:*usz, xp:*void, n:u64, min_allowed:T) : T = {
dt := promote{u64, fold_addw{dc}}
nc := uT~~(mt - jt) # Number of counts to perform: last is implicit
if (dt < b * (vec/2) and dt*8 < b * promote{u64,nc}) {
r0 = count_with_runs{V, vec, x, tab, r}
r0 = count_with_runs{x, tab, r}
} else if (nc <= 24*vbits/128) {
r0 = rv
count_by_sum{T, V, [vec]uT, xv, b, tab, r0,
@ -92,49 +92,55 @@ def count_by_sum{T, V, U, xv, b, tab, r0, j0, m} = {
# Count adjacent equal elements at once, breaking at w-element groups
# May read up to index r from x, hitting one element that's not counted
def count_with_runs{V, vec, x, tab:*T, r} = {
def count_with_runs{x, tab, r} = {
def w = width{ux}
m0:ux = 1 << (w-1) # Last element in each chunk ends a run
bw := r / w
@for (i to bw) {
xo := x + i*w
m := m0
# Mark the end of each run
@unroll (j to w / vec) {
def jv = j*vec
def lv{k} = load{*V~~(xo + k)}
m |= promote{ux, homMask{lv{jv} != lv{jv+1}}} << jv
}
# Iterate over runs
jp:T = - T~~1
while (m > m0) @unroll (2) {
j := trunc{T, ctz{m}}
inc{tab, load{xo, j}, cast_i{T, j - jp}}
jp = j; m &= m-1
}
# One step if popc{m} was odd, reducing branch mispredictions above
inc{tab, load{xo, w-1}, ((w-1) - jp) & -trunc{T, m>>(w-1)}}
m := m0; mark_run_ends{xo, m}
inc_marked_runs{xo, tab, m, m0}
}
bw * w
}
# Condensed version without count_by_sum
fn count_i32_i32(tab:*i32, x:*i32, n:u64) : void = {
def T = i32
def vbits = arch_defvw
def vec = vbits/width{T}
def mark_run_ends{x:*T, m:(ux)} = {
def vec = arch_defvw/width{T}
def V = [vec]T
block_loop{V, n, {r} => {
b := r / vec
xv := *V~~x
dc := -(load{xv} != load{*V~~(x+1)})
@for (xv, xp in *V~~(x-1) over _ from 1 to b) dc -= xp != xv
dt := promote{u64, fold_addw{dc}}
r0:u64 = 0
if (dt < b * (vec/2)) r0 = count_with_runs{V, vec, x, tab, r}
@for (x over _ from r0 to r) inc{tab, x}
x += r
}}
@unroll (j to width{ux} / vec) {
def jv = j*vec
def lv{k} = load{*V~~(x + k)}
m |= promote{ux, homMask{lv{jv} != lv{jv+1}}} << jv
}
}
def inc_marked_runs{x, tab:*T, m, m0} = {
def w = width{ux}
# Iterate over runs marked in m
jp:T = - T~~1
while (m > m0) @unroll (2) {
j := trunc{T, ctz{m}}
inc{tab, load{x, j}, cast_i{T, j - jp}}
jp = j; m &= m-1
}
# One step if popc{m} was odd, reducing branch mispredictions above
inc{tab, load{x, w-1}, ((w-1) - jp) & -trunc{T, m>>(w-1)}}
}
# No count_by_sum: build each run mask then decide whether to use it
fn count_i32_i32(tab:*i32, x:*i32, n:usz) : void = {
def w = width{ux}
m0:ux = 1 << (w-1)
while (n > 0) {
b:usz = w
if (rare{b > n}) { b = n; goto{'skip_runs'} }
m := m0; mark_run_ends{x, m}
if (popc{m} < w/2) {
inc_marked_runs{x, tab, m, m0}
} else {
setlabel{'skip_runs'}
@for (x over b) inc{tab, x}
}
x += b; n -= b
}
}
export{'simd_count_i8', count{i8}}