From 9d7d330a03b99b0fba53ce0d2e68ebfcbf5afb8f Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Fri, 17 Mar 2023 15:44:31 -0400 Subject: [PATCH] Use AVX2 counting for 1-byte counting sort --- src/builtins/grade.h | 7 ++++++- src/builtins/slash.c | 2 +- src/singeli/src/count.singeli | 18 +++++++++--------- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/builtins/grade.h b/src/builtins/grade.h index bf270303..828a8e69 100644 --- a/src/builtins/grade.h +++ b/src/builtins/grade.h @@ -81,6 +81,9 @@ extern void (*const avx2_scan_min_i16)(int16_t* v0,int16_t* v1,uint64_t v2); if (e==n) {break;} k=e; \ } #define WRITE_SPARSE(T) WRITE_SPARSE_##T +extern i8 (*const avx2_count_i8)(usz*, i8*, u64, i8); +#define SINGELI_COUNT_OR(T) \ + if (1==sizeof(T)) avx2_count_i8(c0o, (i8*)xp, n, -128); else #else #define COUNT_THRESHOLD 16 #define WRITE_SPARSE(T) \ @@ -88,13 +91,14 @@ extern void (*const avx2_scan_min_i16)(int16_t* v0,int16_t* v1,uint64_t v2); usz js = j; \ while (ijmx) mx=x } + @for (x over _ from rv to r) { if (xmx) mx=x } jt := fold{min, jv} mt := fold{max, mv} - if (jt < 0) return{jt} + if (jt < min_allowed) return{jt} if (mt > mx) mx = mt - nc := mt - jt # Number of counts to perform: last is implicit + nc := uT~~(mt - jt) # Number of counts to perform: last is implicit if (nc <= 48) { r0 = rv j0 := promote{u64, uT~~jt} # Starting count - m := promote{u64, uT~~nc} # Number of iterations + m := promote{u64, nc} # Number of iterations total := trunc{usz, r0} # To compute last count def count_each{js, num} = { - j := @collect (k to num) js+k + j := @collect (k to num) trunc{T, js+k} c := copy{tuplen{j}, [vec]uT ** 0} - e := each{{j}=>V**trunc{T, j}, j} + e := each{{j}=>V**j, j} @for (xv over b) each{{c,e} => c -= xv == e, c, e} def add_sum{c, j} = { s := promote{usz, sum_vec{V}(V~~c)} @@ -61,7 +61,7 @@ fn count{T}(tab:*usz, x:*T, n:u64) : T = { m4 := m / 4 @for (j4 to m4) count_each{j0 + 4*j4, 4} @for (j from 4*m4 to m) count_each{j0 + j, 1} - inc{tab, j0 + m, trunc{usz,total}} + inc{tab, trunc{T, j0 + m}, trunc{usz,total}} } # Scalar fallback and cleanup