From e6940e73d0b80b197be9f137e10057ad02661714 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Fri, 15 Nov 2024 20:55:15 -0500 Subject: [PATCH] =?UTF-8?q?Fast=20/=E2=81=BC=20of=20sorted=20arguments=20u?= =?UTF-8?q?sing=20semi-sparse=20representation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/builtins/slash.c | 41 ++++++++++++++++- src/singeli/src/count.singeli | 86 ++++++++++++++++++++++++++++------- 2 files changed, 109 insertions(+), 18 deletions(-) diff --git a/src/builtins/slash.c b/src/builtins/slash.c index 69366eab..92af67d2 100644 --- a/src/builtins/slash.c +++ b/src/builtins/slash.c @@ -812,6 +812,30 @@ B slash_c2(B t, B w, B x) { return c2rt(slash, w, x); } +#if SINGELI_SIMD +static B finish_sorted_count(B r, usz* ov, usz* oc, usz on) { + // Overflow values in ov are sorted but not unique + // Set mo to the greatest sum of oc for equal ov values + usz mo = 0, pv = 0, c = 0; + for (usz i=0; imo) mo=c; + } + // Since mo is a multiple of 128 and all of r is less than 128, + // values in r can't affect the result type + #define RESIZE(T, UT) \ + r = taga(cpy##UT##Arr(r)); T* rp = tyany_ptr(r); \ + for (usz i=0; ixp[a-1]) a++; \ u##N max=xp[a-1]; \ if (amax) max=c; } \ if ((i##N)max<0) thrM("/⁼: Argument cannot contain negative numbers"); \ usz ria = max + 1; \ @@ -862,7 +887,7 @@ B slash_im(B t, B x) { i##N* xp = i##N##any_ptr(x); \ usz m=1<m/2) thrM("/⁼: Argument cannot contain negative numbers"); + #define HAS_SINGELI_COUNT_SORTED 0 + #define SINGELI_COUNT_SORTED(N) #endif CASE_SMALL(8) CASE_SMALL(16) #undef CASE_SMALL diff --git a/src/singeli/src/count.singeli b/src/singeli/src/count.singeli index aec7e8a0..0152ae26 100644 --- a/src/singeli/src/count.singeli +++ b/src/singeli/src/count.singeli @@ -91,17 +91,36 @@ def count_by_sum{T, V, U, xv, b, tab, r0, j0, m} = { } # Count adjacent equal elements at once, breaking at w-element groups -# May read up to index r from x, hitting one element that's not counted -def count_with_runs{x, tab, r} = { +# May read up to index n from x, hitting one element that's not counted +def count_with_runs{x, tab, n} = { def w = width{ux} m0:ux = 1 << (w-1) # Last element in each chunk ends a run - bw := r / w + bw := n / w @for (i to bw) { xo := x + i*w m := m0; mark_run_ends{xo, m} inc_marked_runs{xo, tab, m, m0} } - bw * w + bw * w # Number of elements handled +} +# Switch to the normal scalar count if there aren't enough runs +def count_adapt_runs{x0, tab, n} = { + def w = width{ux} + m0:ux = 1 << (w-1) + x := x0; r := n + while (r > 0) { + def skip_runs = makelabel{} + b:usz = w + if (rare{b > r}) { b = r; goto{skip_runs} } + m := m0; mark_run_ends{x, m} + if (popc{m} < w/2) { + inc_marked_runs{x, tab, m, m0} + } else { + setlabel{skip_runs} + @for (x over b) inc{tab, x} + } + x += b; r -= b + } } def mark_run_ends{x:*T, m:(ux)} = { def vec = arch_defvw/width{T} @@ -126,23 +145,56 @@ def inc_marked_runs{x, tab:*T, m, m0} = { } # No count_by_sum: build each run mask then decide whether to use it -fn count_i32_i32(tab:*i32, x:*i32, n:usz) : void = { - def w = width{ux} - m0:ux = 1 << (w-1) - while (n > 0) { - b:usz = w - if (rare{b > n}) { b = n; goto{'skip_runs'} } - m := m0; mark_run_ends{x, m} - if (popc{m} < w/2) { - inc_marked_runs{x, tab, m, m0} - } else { - setlabel{'skip_runs'} - @for (x over b) inc{tab, x} +fn count_i32_i32(tab:*i32, x:*i32, n:usz) : void = count_adapt_runs{x, tab, n} + +# For i←/⁼x, store r←128|i, and i-r sparsely: x is ∧(/r)∾oc/ov +# ov is sorted but may not be unique, and oc contains multiples of 128 +# Return the shared length of ov and oc +fn count_sorted{T}(r:*u8, ov:*usz, oc:*usz, x:*T, n:usz) : usz = { + def V = [arch_defvw/width{T}]T + def block = 128 + i:usz = 0 + on:usz = 0 + def overflow{xu,c} = { store{ov, on, xu}; store{oc, on, c}; ++on } + while (i < n) { + rem := n - i + xo := x + i + xi := load{xo} + def overflow{c} = overflow{cast_i{usz,xi}, c} + xe := xo-1; def bxi{j} = xi == load{xe, j} + if (block <= rem and bxi{block}) { + # Gallop to find last block ending in xi + d:usz = block + d2 := undefined{usz} + while ((d2=d+d) <= rem and bxi{d2}) d = d2 + l := (rem &~ (block-1)) - d; if (l > d) l = d + # Target is in [d,d+l); shrink l + while (l > block) { + h := (l/2) &~ (block-1) + m := d + h + if (bxi{m}) d = m + l -= h + } + overflow{d} + rem -= d; if (rem == 0) return{on} + i += d; xo += d; xi = load{xo} } - x += b; n -= b + # Count the next block normally + if (rem > block) rem = block + count_adapt_runs{xo, r, rem} + rxi := load{r, xi} + if (rxi >= block) { + store{r, xi, rxi - block} + overflow{block} + } + i += rem } + on } export{'simd_count_i8', count{i8}} export{'simd_count_i16', count{i16}} export{'simd_count_i32_i32', count_i32_i32} +export{'si_count_sorted_i8', count_sorted{i8}} +export{'si_count_sorted_i16', count_sorted{i16}} +export{'si_count_sorted_i32', count_sorted{i32}}