Fast /⁼ of sorted arguments using semi-sparse representation

This commit is contained in:
Marshall Lochbaum 2024-11-15 20:55:15 -05:00
parent 11117fcc67
commit e6940e73d0
2 changed files with 109 additions and 18 deletions

View File

@ -812,6 +812,30 @@ B slash_c2(B t, B w, B x) {
return c2rt(slash, w, x); return c2rt(slash, w, x);
} }
#if SINGELI_SIMD
static B finish_sorted_count(B r, usz* ov, usz* oc, usz on) {
// Overflow values in ov are sorted but not unique
// Set mo to the greatest sum of oc for equal ov values
usz mo = 0, pv = 0, c = 0;
for (usz i=0; i<on; i++) {
usz sv = pv; pv = ov[i];
c = c*(sv==pv) + oc[i];
if (c>mo) mo=c;
}
// Since mo is a multiple of 128 and all of r is less than 128,
// values in r can't affect the result type
#define RESIZE(T, UT) \
r = taga(cpy##UT##Arr(r)); T* rp = tyany_ptr(r); \
for (usz i=0; i<on; i++) rp[ov[i]]+= oc[i];
if (mo == 0); // No overflow, r is correct already
else if (mo < I16_MAX) { RESIZE(i16, I16) }
else if (mo < I32_MAX) { RESIZE(i32, I32) }
else { RESIZE(f64, F64) }
#undef RESIZE
return r;
}
#endif
B slash_im(B t, B x) { B slash_im(B t, B x) {
if (!isArr(x) || RNK(x)!=1) thrM("/⁼: Argument must be a list"); if (!isArr(x) || RNK(x)!=1) thrM("/⁼: Argument must be a list");
u8 xe = TI(x,elType); u8 xe = TI(x,elType);
@ -831,6 +855,7 @@ B slash_im(B t, B x) {
usz a=1; while (a<xia && xp[a]>xp[a-1]) a++; \ usz a=1; while (a<xia && xp[a]>xp[a-1]) a++; \
u##N max=xp[a-1]; \ u##N max=xp[a-1]; \
if (a<xia) { \ if (a<xia) { \
SINGELI_COUNT_SORTED(N) \
for (usz i=a; i<xia; i++) { u##N c=xp[i]; if (c>max) max=c; } \ for (usz i=a; i<xia; i++) { u##N c=xp[i]; if (c>max) max=c; } \
if ((i##N)max<0) thrM("/⁼: Argument cannot contain negative numbers"); \ if ((i##N)max<0) thrM("/⁼: Argument cannot contain negative numbers"); \
usz ria = max + 1; \ usz ria = max + 1; \
@ -862,7 +887,7 @@ B slash_im(B t, B x) {
i##N* xp = i##N##any_ptr(x); \ i##N* xp = i##N##any_ptr(x); \
usz m=1<<N; \ usz m=1<<N; \
usz mh = m/2, sa = SINGELI_COUNT_ALLOC; \ usz mh = m/2, sa = SINGELI_COUNT_ALLOC; \
if (xia < mh) { \ if (xia < mh || HAS_SINGELI_COUNT_SORTED) { \
TRY_SMALL_OUT(N) \ TRY_SMALL_OUT(N) \
if (RIA_SMALL(N)) { sa=mh=ria; goto small_range##N; } \ if (RIA_SMALL(N)) { sa=mh=ria; goto small_range##N; } \
INIT_RES(N) FILL_RES \ INIT_RES(N) FILL_RES \
@ -883,6 +908,18 @@ B slash_im(B t, B x) {
i##N max = simd_count_i##N(t, xp, xia, 0); \ i##N max = simd_count_i##N(t, xp, xia, 0); \
if (max < 0) thrM("/⁼: Argument cannot contain negative numbers"); \ if (max < 0) thrM("/⁼: Argument cannot contain negative numbers"); \
usz ria=max+1; usz ria=max+1;
#define HAS_SINGELI_COUNT_SORTED FL_HAS(x,fl_asc)
#define SINGELI_COUNT_SORTED(N) \
if (FL_HAS(x,fl_asc)) { \
usz ria = xp[xia-1] + 1; \
usz os = xia/128; \
INIT_RES(8) \
TALLOC(usz, ov, 2*os); usz* oc = ov+os; \
usz on = si_count_sorted_i##N((u8*)rp, ov, oc, xp, xia); \
r = finish_sorted_count(r, ov, oc, on); \
TFREE(ov); \
break; \
}
#else #else
#define RIA_SMALL(N) 0 #define RIA_SMALL(N) 0
#define SINGELI_COUNT_ALLOC m #define SINGELI_COUNT_ALLOC m
@ -890,6 +927,8 @@ B slash_im(B t, B x) {
for (usz i=0; i<xia; i++) t[(u##N)xp[i]]++; \ for (usz i=0; i<xia; i++) t[(u##N)xp[i]]++; \
t[m/2]=xia; usz ria=0; for (u64 s=0; s<xia; ria++) s+=t[ria]; \ t[m/2]=xia; usz ria=0; for (u64 s=0; s<xia; ria++) s+=t[ria]; \
if (ria>m/2) thrM("/⁼: Argument cannot contain negative numbers"); if (ria>m/2) thrM("/⁼: Argument cannot contain negative numbers");
#define HAS_SINGELI_COUNT_SORTED 0
#define SINGELI_COUNT_SORTED(N)
#endif #endif
CASE_SMALL(8) CASE_SMALL(16) CASE_SMALL(8) CASE_SMALL(16)
#undef CASE_SMALL #undef CASE_SMALL

View File

@ -91,17 +91,36 @@ def count_by_sum{T, V, U, xv, b, tab, r0, j0, m} = {
} }
# Count adjacent equal elements at once, breaking at w-element groups # Count adjacent equal elements at once, breaking at w-element groups
# May read up to index r from x, hitting one element that's not counted # May read up to index n from x, hitting one element that's not counted
def count_with_runs{x, tab, r} = { def count_with_runs{x, tab, n} = {
def w = width{ux} def w = width{ux}
m0:ux = 1 << (w-1) # Last element in each chunk ends a run m0:ux = 1 << (w-1) # Last element in each chunk ends a run
bw := r / w bw := n / w
@for (i to bw) { @for (i to bw) {
xo := x + i*w xo := x + i*w
m := m0; mark_run_ends{xo, m} m := m0; mark_run_ends{xo, m}
inc_marked_runs{xo, tab, m, m0} inc_marked_runs{xo, tab, m, m0}
} }
bw * w bw * w # Number of elements handled
}
# Switch to the normal scalar count if there aren't enough runs
def count_adapt_runs{x0, tab, n} = {
def w = width{ux}
m0:ux = 1 << (w-1)
x := x0; r := n
while (r > 0) {
def skip_runs = makelabel{}
b:usz = w
if (rare{b > r}) { b = r; goto{skip_runs} }
m := m0; mark_run_ends{x, m}
if (popc{m} < w/2) {
inc_marked_runs{x, tab, m, m0}
} else {
setlabel{skip_runs}
@for (x over b) inc{tab, x}
}
x += b; r -= b
}
} }
def mark_run_ends{x:*T, m:(ux)} = { def mark_run_ends{x:*T, m:(ux)} = {
def vec = arch_defvw/width{T} def vec = arch_defvw/width{T}
@ -126,23 +145,56 @@ def inc_marked_runs{x, tab:*T, m, m0} = {
} }
# No count_by_sum: build each run mask then decide whether to use it # No count_by_sum: build each run mask then decide whether to use it
fn count_i32_i32(tab:*i32, x:*i32, n:usz) : void = { fn count_i32_i32(tab:*i32, x:*i32, n:usz) : void = count_adapt_runs{x, tab, n}
def w = width{ux}
m0:ux = 1 << (w-1) # For i←/⁼x, store r←128|i, and i-r sparsely: x is ∧(/r)∾oc/ov
while (n > 0) { # ov is sorted but may not be unique, and oc contains multiples of 128
b:usz = w # Return the shared length of ov and oc
if (rare{b > n}) { b = n; goto{'skip_runs'} } fn count_sorted{T}(r:*u8, ov:*usz, oc:*usz, x:*T, n:usz) : usz = {
m := m0; mark_run_ends{x, m} def V = [arch_defvw/width{T}]T
if (popc{m} < w/2) { def block = 128
inc_marked_runs{x, tab, m, m0} i:usz = 0
} else { on:usz = 0
setlabel{'skip_runs'} def overflow{xu,c} = { store{ov, on, xu}; store{oc, on, c}; ++on }
@for (x over b) inc{tab, x} while (i < n) {
rem := n - i
xo := x + i
xi := load{xo}
def overflow{c} = overflow{cast_i{usz,xi}, c}
xe := xo-1; def bxi{j} = xi == load{xe, j}
if (block <= rem and bxi{block}) {
# Gallop to find last block ending in xi
d:usz = block
d2 := undefined{usz}
while ((d2=d+d) <= rem and bxi{d2}) d = d2
l := (rem &~ (block-1)) - d; if (l > d) l = d
# Target is in [d,d+l); shrink l
while (l > block) {
h := (l/2) &~ (block-1)
m := d + h
if (bxi{m}) d = m
l -= h
}
overflow{d}
rem -= d; if (rem == 0) return{on}
i += d; xo += d; xi = load{xo}
} }
x += b; n -= b # Count the next block normally
if (rem > block) rem = block
count_adapt_runs{xo, r, rem}
rxi := load{r, xi}
if (rxi >= block) {
store{r, xi, rxi - block}
overflow{block}
}
i += rem
} }
on
} }
export{'simd_count_i8', count{i8}} export{'simd_count_i8', count{i8}}
export{'simd_count_i16', count{i16}} export{'simd_count_i16', count{i16}}
export{'simd_count_i32_i32', count_i32_i32} export{'simd_count_i32_i32', count_i32_i32}
export{'si_count_sorted_i8', count_sorted{i8}}
export{'si_count_sorted_i16', count_sorted{i16}}
export{'si_count_sorted_i32', count_sorted{i32}}