Clean up and simplify count.singeli; allow longer final block

This commit is contained in:
Marshall Lochbaum 2023-03-17 13:22:36 -04:00
parent 06b4f06e64
commit 4b18466ae2

View File

@ -17,36 +17,37 @@ fn count{T}(tab:*usz, x:*ty_u{T}, n:u64) : u1 = {
def vbits = 256 def vbits = 256
def vec = vbits/width{T} def vec = vbits/width{T}
def uT = ty_u{T} def uT = ty_u{T}
def V = [vec]uT def V = [vec]T
def iV = [vec]T
def block = (2048*8) / vbits # Target vectors per block def block = (2048*8) / vbits # Target vectors per block
assert{block < 1<<width{T}} # Don't overflow count in vector section def b_max = block + block/4 # Last block max length
assert{b_max < 1<<width{T}} # Don't overflow count in vector section
i:u64 = 0 i:u64 = 0
while (i < n) { while (i < n) {
r:u64 = n - i # Number of elements to handle in this iteration
b := r / vec r:u64 = n - i; if (r > vec*b_max) r = vec*block
b := r / vec # Vector case does b full vectors if it runs
r0:u64 = 0 # Elements actually handled by vector case
# Find range to check for suitability
xv := *V~~x xv := *V~~x
used_eq:u1 = 0
if (r >= 128) {
b = block; if (r < vec*b) b = r / vec
jv := load{xv}; mv := jv jv := load{xv}; mv := jv
@for (xv over _ from 1 to b) { jv = min{jv, xv}; mv = max{mv, xv} } @for (xv over _ from 1 to b) { jv = min{jv, xv}; mv = max{mv, xv} }
mi := iV~~mv
if (homAny{mi < iV**0}) return{1}
jt := fold{min, jv} jt := fold{min, jv}
if (homAll{mi <= iV**(48 + i8~~jt)}) { mt := fold{max, mv} - jt # Counts needed (last one's implicit)
used_eq = 1 if (jt < 0) return{1} # Negative number found!
r = b * vec
j0 := promote{u64, jt} if (mt <= 48) {
m := promote{u64, fold{max, mv}} - j0 r0 = b * vec
total := trunc{usz, b*vec} j0 := promote{u64, uT~~jt} # Starting count
m := promote{u64, uT~~mt} # Number of iterations
total := trunc{usz, r0} # To compute last count
def count_each{js, num} = { def count_each{js, num} = {
j := (@collect (k to num) js+k) j := @collect (k to num) js+k
c := copy{tuplen{j}, V**0} c := copy{tuplen{j}, [vec]uT ** 0}
e := each{{j}=>V**trunc{uT, j}, j} e := each{{j}=>V**trunc{T, j}, j}
@for (xv over b) each{{c,e} => c -= xv == e, c, e} @for (xv over b) each{{c,e} => c -= xv == e, c, e}
def add_sum{c, j} = { def add_sum{c, j} = {
s := promote{usz, sum_vec{iV}(iV~~c)} s := promote{usz, sum_vec{V}(V~~c)}
total -= s; inc{tab, j, s} total -= s; inc{tab, j, s}
} }
each{add_sum, c, j} each{add_sum, c, j}
@ -56,8 +57,9 @@ fn count{T}(tab:*usz, x:*ty_u{T}, n:u64) : u1 = {
@for (j from 4*m4 to m) count_each{j0 + j, 1} @for (j from 4*m4 to m) count_each{j0 + j, 1}
inc{tab, j0 + m, trunc{usz,total}} inc{tab, j0 + m, trunc{usz,total}}
} }
}
if (not used_eq) @for (x over r) inc{tab, x} # Scalar fallback and cleanup
@for (x over _ from r0 to r) inc{tab, x}
i += r i += r
x += r x += r
} }