From 4b18466ae296c650aaee856617cd2fa83410fbd2 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Fri, 17 Mar 2023 13:22:36 -0400 Subject: [PATCH] Clean up and simplify count.singeli; allow longer final block --- src/singeli/src/count.singeli | 70 ++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/src/singeli/src/count.singeli b/src/singeli/src/count.singeli index bc393e69..28fa8688 100644 --- a/src/singeli/src/count.singeli +++ b/src/singeli/src/count.singeli @@ -17,47 +17,49 @@ fn count{T}(tab:*usz, x:*ty_u{T}, n:u64) : u1 = { def vbits = 256 def vec = vbits/width{T} def uT = ty_u{T} - def V = [vec]uT - def iV = [vec]T + def V = [vec]T def block = (2048*8) / vbits # Target vectors per block - assert{block < 1< vec*b_max) r = vec*block + b := r / vec # Vector case does b full vectors if it runs + r0:u64 = 0 # Elements actually handled by vector case + + # Find range to check for suitability xv := *V~~x - used_eq:u1 = 0 - if (r >= 128) { - b = block; if (r < vec*b) b = r / vec - jv := load{xv}; mv := jv - @for (xv over _ from 1 to b) { jv = min{jv, xv}; mv = max{mv, xv} } - mi := iV~~mv - if (homAny{mi < iV**0}) return{1} - jt := fold{min, jv} - if (homAll{mi <= iV**(48 + i8~~jt)}) { - used_eq = 1 - r = b * vec - j0 := promote{u64, jt} - m := promote{u64, fold{max, mv}} - j0 - total := trunc{usz, b*vec} - def count_each{js, num} = { - j := (@collect (k to num) js+k) - c := copy{tuplen{j}, V**0} - e := each{{j}=>V**trunc{uT, j}, j} - @for (xv over b) each{{c,e} => c -= xv == e, c, e} - def add_sum{c, j} = { - s := promote{usz, sum_vec{iV}(iV~~c)} - total -= s; inc{tab, j, s} - } - each{add_sum, c, j} + jv := load{xv}; mv := jv + @for (xv over _ from 1 to b) { jv = min{jv, xv}; mv = max{mv, xv} } + jt := fold{min, jv} + mt := fold{max, mv} - jt # Counts needed (last one's implicit) + if (jt < 0) return{1} # Negative number found! + + if (mt <= 48) { + r0 = b * vec + j0 := promote{u64, uT~~jt} # Starting count + m := promote{u64, uT~~mt} # Number of iterations + total := trunc{usz, r0} # To compute last count + def count_each{js, num} = { + j := @collect (k to num) js+k + c := copy{tuplen{j}, [vec]uT ** 0} + e := each{{j}=>V**trunc{T, j}, j} + @for (xv over b) each{{c,e} => c -= xv == e, c, e} + def add_sum{c, j} = { + s := promote{usz, sum_vec{V}(V~~c)} + total -= s; inc{tab, j, s} } - m4 := m / 4 - @for (j4 to m4) count_each{j0 + 4*j4, 4} - @for (j from 4*m4 to m) count_each{j0 + j, 1} - inc{tab, j0 + m, trunc{usz,total}} + each{add_sum, c, j} } + m4 := m / 4 + @for (j4 to m4) count_each{j0 + 4*j4, 4} + @for (j from 4*m4 to m) count_each{j0 + j, 1} + inc{tab, j0 + m, trunc{usz,total}} } - if (not used_eq) @for (x over r) inc{tab, x} + + # Scalar fallback and cleanup + @for (x over _ from r0 to r) inc{tab, x} i += r x += r }