diff --git a/src/builtins/grade.h b/src/builtins/grade.h index 406a21e6..06eb3ee2 100644 --- a/src/builtins/grade.h +++ b/src/builtins/grade.h @@ -97,9 +97,27 @@ extern void (*const si_scan_min_i16)(int16_t* v0,int16_t* v1,uint64_t v2); if (e==n) {break;} k=e; \ } #define WRITE_SPARSE(T) WRITE_SPARSE_##T -extern i8 (*const simd_count_i8)(usz*, i8*, u64, i8); -#define SINGELI_COUNT_OR(T) \ - if (1==sizeof(T)) simd_count_i8(c0o, (i8*)xp, n, -128); else +extern i8 (*const simd_count_i8)(u16*, u16*, void*, u64, i8); +#define COUNTING_SORT_i8 \ + usz C=1<<8; \ + TALLOC(u16, c0, C+(n>>15)+1); \ + u16 *c0o=c0+C/2; u16 *ov=c0+C; \ + for (usz j=0; j 0)) { // Overflowed i32! + r = taga(cpyF64Arr(r)); f64* rp = tyany_ptr(r); + for (usz i=0; ixp[a-1]) a++; \ u##N max=xp[a-1]; \ + usz rmax=xia; \ if (a=16 && maxcount<128) { INIT_RES(8) FILL_RES break; } \ + else if (N>=16 && maxcount<128) rmax=127; \ } \ } \ + usz ria = (usz)max + 1; \ if (a==xia) { /* Unique argument */ \ - usz ria = max + 1; \ u64* rp; r = m_bitarrv(&rp, ria); \ for (usz i=0; i>15; \ + TALLOC(u16, ov, os+1); \ + i##N max = simd_count_i##N((u16*)rp, (u16*)ov, xp, xia, 0); \ + if (max < 0) thrM("/⁼: Argument cannot contain negative numbers"); \ + usz ria = (usz)max + 1; \ + if (ria < sa) r = C2(take, m_f64(ria), r); \ + r = finish_small_count(r, ov); \ + TFREE(ov); \ + break; \ } CASE_SMALL(8) CASE_SMALL(16) #undef CASE_SMALL case el_i32: { i32* xp = i32any_ptr(x); TRY_SMALL_OUT(32) - INIT_RES(32) + INIT_RES(32,ria) simd_count_i32_i32(rp, xp, xia); r = num_squeeze(r); break; } #undef TRY_SMALL_OUT #undef INIT_RES - #undef FILL_RES #else #define CASE(N) case el_i##N: { \ i##N* xp = i##N##any_ptr(x); \ diff --git a/src/singeli/src/count.singeli b/src/singeli/src/count.singeli index 0152ae26..7f23fa9e 100644 --- a/src/singeli/src/count.singeli +++ b/src/singeli/src/count.singeli @@ -6,31 +6,24 @@ if_inline (hasarch{'SSE2'}) { def fold_addw{v:T=[_]E if E<=u32} = sum_vec{T}(v) } -def inc{ptr, ind, v} = store{ptr, ind, v + load{ptr, ind}} +def inc{ptr:*T, ind, v} = store{ptr, ind, trunc{T,v} + load{ptr, ind}} def inc{ptr, ind} = inc{ptr, ind, 1} -def block_loop{V=[vec]T, n, iter} = { - def block = (2048*8) / width{V} # Target vectors per block - def b_max = block + block/4 # Last block max length - assert{b_max < 1< vec*b_max) r = vec*block - iter{r} - i += r - } -} - -# Write counts /⁼x to tab and return ⌈´x -fn count{T}(tab:*usz, xp:*void, n:u64, min_allowed:T) : T = { +# Write counts (2⋆15)|/⁼x to tab, overflows to ov, and return ⌈´x +fn count{T if T<=i16}(tab:*u16, ov:*u16, xp:*void, n:u64, min_allowed:T) : T = { def vbits = arch_defvw def vec = vbits/width{T} def uT = ty_u{T} def V = [vec]T + def block = (2048*8) / vbits # Target vectors per block + def b_max = block + block/4 # Last block max length + assert{b_max < 1< { # Handle r elements + i:u64 = 0 + while (i < n) { + # Number of elements to handle in this iteration + r:u64 = n - i; if (r > vec*b_max) r = vec*block b := r / vec # Vector case does b full vectors if it runs rv:= b * vec r0:u64 = 0 # Elements actually handled by vector case @@ -65,11 +58,37 @@ fn count{T}(tab:*usz, xp:*void, n:u64, min_allowed:T) : T = { # Scalar fallback and cleanup @for (x over _ from r0 to r) inc{tab, x} + i += r x += r - }} + + # Keep counts below 1<<15 with the overflow list + # Count from the end to include i==n and handle a long last block nicely + if ((i-n)%(1<<15) < block*vec and i >= 1<<15) { + ov += flush_counts(tab+min_allowed, ov, cast_i{usz,ty_u{mx+min_allowed}} + 1) + } + } + store{ov, 0, maxvalue{u16}} # End marker: note x values fit in i16 mx } +fn flush_counts(tab:*u16, ov:*u16, n:usz) : usz = { + def vl = arch_defvw/16 + def V = [vl]u16 + def bot = 1<<15 - 1 + on:usz = 0 + @for (t in *V~~tab over jv to cdiv{n, vl}) if (rare{topAny{t}}) { + o := if (hasarch{'X86_64'}) topMask{t} else homMask{t > V**bot} + if (jv == n/vl) o &= type{o}~~1<<(n%vl) - 1 + while (o > 0) { + jv := jv*vl + cast_i{usz, ctz{o}} + store{tab, jv, load{tab, jv} & bot} + store{ov, on, trunc{u16, jv}}; ++on + o &= o-1 + } + } + on +} + # Sum comparisons against each value (except one) in the range def count_by_sum{T, V, U, xv, b, tab, r0, j0, m} = { total := trunc{usz, r0} # To compute last count @@ -87,7 +106,7 @@ def count_by_sum{T, V, U, xv, b, tab, r0, j0, m} = { m4 := m / 4 @for (j4 to m4) count_each{j0 + 4*j4, 4} @for (j from 4*m4 to m) count_each{j0 + j, 1} - inc{tab, trunc{T, j0 + m}, trunc{usz,total}} + inc{tab, trunc{T, j0 + m}, total} } # Count adjacent equal elements at once, breaking at w-element groups @@ -137,7 +156,7 @@ def inc_marked_runs{x, tab:*T, m, m0} = { jp:T = - T~~1 while (m > m0) @unroll (2) { j := trunc{T, ctz{m}} - inc{tab, load{x, j}, cast_i{T, j - jp}} + inc{tab, load{x, j}, j - jp} jp = j; m &= m-1 } # One step if popc{m} was odd, reducing branch mispredictions above