Fast /⁼ of sorted arguments using semi-sparse representation
This commit is contained in:
parent
11117fcc67
commit
e6940e73d0
@ -812,6 +812,30 @@ B slash_c2(B t, B w, B x) {
|
|||||||
return c2rt(slash, w, x);
|
return c2rt(slash, w, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if SINGELI_SIMD
|
||||||
|
static B finish_sorted_count(B r, usz* ov, usz* oc, usz on) {
|
||||||
|
// Overflow values in ov are sorted but not unique
|
||||||
|
// Set mo to the greatest sum of oc for equal ov values
|
||||||
|
usz mo = 0, pv = 0, c = 0;
|
||||||
|
for (usz i=0; i<on; i++) {
|
||||||
|
usz sv = pv; pv = ov[i];
|
||||||
|
c = c*(sv==pv) + oc[i];
|
||||||
|
if (c>mo) mo=c;
|
||||||
|
}
|
||||||
|
// Since mo is a multiple of 128 and all of r is less than 128,
|
||||||
|
// values in r can't affect the result type
|
||||||
|
#define RESIZE(T, UT) \
|
||||||
|
r = taga(cpy##UT##Arr(r)); T* rp = tyany_ptr(r); \
|
||||||
|
for (usz i=0; i<on; i++) rp[ov[i]]+= oc[i];
|
||||||
|
if (mo == 0); // No overflow, r is correct already
|
||||||
|
else if (mo < I16_MAX) { RESIZE(i16, I16) }
|
||||||
|
else if (mo < I32_MAX) { RESIZE(i32, I32) }
|
||||||
|
else { RESIZE(f64, F64) }
|
||||||
|
#undef RESIZE
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
B slash_im(B t, B x) {
|
B slash_im(B t, B x) {
|
||||||
if (!isArr(x) || RNK(x)!=1) thrM("/⁼: Argument must be a list");
|
if (!isArr(x) || RNK(x)!=1) thrM("/⁼: Argument must be a list");
|
||||||
u8 xe = TI(x,elType);
|
u8 xe = TI(x,elType);
|
||||||
@ -831,6 +855,7 @@ B slash_im(B t, B x) {
|
|||||||
usz a=1; while (a<xia && xp[a]>xp[a-1]) a++; \
|
usz a=1; while (a<xia && xp[a]>xp[a-1]) a++; \
|
||||||
u##N max=xp[a-1]; \
|
u##N max=xp[a-1]; \
|
||||||
if (a<xia) { \
|
if (a<xia) { \
|
||||||
|
SINGELI_COUNT_SORTED(N) \
|
||||||
for (usz i=a; i<xia; i++) { u##N c=xp[i]; if (c>max) max=c; } \
|
for (usz i=a; i<xia; i++) { u##N c=xp[i]; if (c>max) max=c; } \
|
||||||
if ((i##N)max<0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
if ((i##N)max<0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||||
usz ria = max + 1; \
|
usz ria = max + 1; \
|
||||||
@ -862,7 +887,7 @@ B slash_im(B t, B x) {
|
|||||||
i##N* xp = i##N##any_ptr(x); \
|
i##N* xp = i##N##any_ptr(x); \
|
||||||
usz m=1<<N; \
|
usz m=1<<N; \
|
||||||
usz mh = m/2, sa = SINGELI_COUNT_ALLOC; \
|
usz mh = m/2, sa = SINGELI_COUNT_ALLOC; \
|
||||||
if (xia < mh) { \
|
if (xia < mh || HAS_SINGELI_COUNT_SORTED) { \
|
||||||
TRY_SMALL_OUT(N) \
|
TRY_SMALL_OUT(N) \
|
||||||
if (RIA_SMALL(N)) { sa=mh=ria; goto small_range##N; } \
|
if (RIA_SMALL(N)) { sa=mh=ria; goto small_range##N; } \
|
||||||
INIT_RES(N) FILL_RES \
|
INIT_RES(N) FILL_RES \
|
||||||
@ -883,6 +908,18 @@ B slash_im(B t, B x) {
|
|||||||
i##N max = simd_count_i##N(t, xp, xia, 0); \
|
i##N max = simd_count_i##N(t, xp, xia, 0); \
|
||||||
if (max < 0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
if (max < 0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||||
usz ria=max+1;
|
usz ria=max+1;
|
||||||
|
#define HAS_SINGELI_COUNT_SORTED FL_HAS(x,fl_asc)
|
||||||
|
#define SINGELI_COUNT_SORTED(N) \
|
||||||
|
if (FL_HAS(x,fl_asc)) { \
|
||||||
|
usz ria = xp[xia-1] + 1; \
|
||||||
|
usz os = xia/128; \
|
||||||
|
INIT_RES(8) \
|
||||||
|
TALLOC(usz, ov, 2*os); usz* oc = ov+os; \
|
||||||
|
usz on = si_count_sorted_i##N((u8*)rp, ov, oc, xp, xia); \
|
||||||
|
r = finish_sorted_count(r, ov, oc, on); \
|
||||||
|
TFREE(ov); \
|
||||||
|
break; \
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
#define RIA_SMALL(N) 0
|
#define RIA_SMALL(N) 0
|
||||||
#define SINGELI_COUNT_ALLOC m
|
#define SINGELI_COUNT_ALLOC m
|
||||||
@ -890,6 +927,8 @@ B slash_im(B t, B x) {
|
|||||||
for (usz i=0; i<xia; i++) t[(u##N)xp[i]]++; \
|
for (usz i=0; i<xia; i++) t[(u##N)xp[i]]++; \
|
||||||
t[m/2]=xia; usz ria=0; for (u64 s=0; s<xia; ria++) s+=t[ria]; \
|
t[m/2]=xia; usz ria=0; for (u64 s=0; s<xia; ria++) s+=t[ria]; \
|
||||||
if (ria>m/2) thrM("/⁼: Argument cannot contain negative numbers");
|
if (ria>m/2) thrM("/⁼: Argument cannot contain negative numbers");
|
||||||
|
#define HAS_SINGELI_COUNT_SORTED 0
|
||||||
|
#define SINGELI_COUNT_SORTED(N)
|
||||||
#endif
|
#endif
|
||||||
CASE_SMALL(8) CASE_SMALL(16)
|
CASE_SMALL(8) CASE_SMALL(16)
|
||||||
#undef CASE_SMALL
|
#undef CASE_SMALL
|
||||||
|
|||||||
@ -91,17 +91,36 @@ def count_by_sum{T, V, U, xv, b, tab, r0, j0, m} = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Count adjacent equal elements at once, breaking at w-element groups
|
# Count adjacent equal elements at once, breaking at w-element groups
|
||||||
# May read up to index r from x, hitting one element that's not counted
|
# May read up to index n from x, hitting one element that's not counted
|
||||||
def count_with_runs{x, tab, r} = {
|
def count_with_runs{x, tab, n} = {
|
||||||
def w = width{ux}
|
def w = width{ux}
|
||||||
m0:ux = 1 << (w-1) # Last element in each chunk ends a run
|
m0:ux = 1 << (w-1) # Last element in each chunk ends a run
|
||||||
bw := r / w
|
bw := n / w
|
||||||
@for (i to bw) {
|
@for (i to bw) {
|
||||||
xo := x + i*w
|
xo := x + i*w
|
||||||
m := m0; mark_run_ends{xo, m}
|
m := m0; mark_run_ends{xo, m}
|
||||||
inc_marked_runs{xo, tab, m, m0}
|
inc_marked_runs{xo, tab, m, m0}
|
||||||
}
|
}
|
||||||
bw * w
|
bw * w # Number of elements handled
|
||||||
|
}
|
||||||
|
# Switch to the normal scalar count if there aren't enough runs
|
||||||
|
def count_adapt_runs{x0, tab, n} = {
|
||||||
|
def w = width{ux}
|
||||||
|
m0:ux = 1 << (w-1)
|
||||||
|
x := x0; r := n
|
||||||
|
while (r > 0) {
|
||||||
|
def skip_runs = makelabel{}
|
||||||
|
b:usz = w
|
||||||
|
if (rare{b > r}) { b = r; goto{skip_runs} }
|
||||||
|
m := m0; mark_run_ends{x, m}
|
||||||
|
if (popc{m} < w/2) {
|
||||||
|
inc_marked_runs{x, tab, m, m0}
|
||||||
|
} else {
|
||||||
|
setlabel{skip_runs}
|
||||||
|
@for (x over b) inc{tab, x}
|
||||||
|
}
|
||||||
|
x += b; r -= b
|
||||||
|
}
|
||||||
}
|
}
|
||||||
def mark_run_ends{x:*T, m:(ux)} = {
|
def mark_run_ends{x:*T, m:(ux)} = {
|
||||||
def vec = arch_defvw/width{T}
|
def vec = arch_defvw/width{T}
|
||||||
@ -126,23 +145,56 @@ def inc_marked_runs{x, tab:*T, m, m0} = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
# No count_by_sum: build each run mask then decide whether to use it
|
# No count_by_sum: build each run mask then decide whether to use it
|
||||||
fn count_i32_i32(tab:*i32, x:*i32, n:usz) : void = {
|
fn count_i32_i32(tab:*i32, x:*i32, n:usz) : void = count_adapt_runs{x, tab, n}
|
||||||
def w = width{ux}
|
|
||||||
m0:ux = 1 << (w-1)
|
# For i←/⁼x, store r←128|i, and i-r sparsely: x is ∧(/r)∾oc/ov
|
||||||
while (n > 0) {
|
# ov is sorted but may not be unique, and oc contains multiples of 128
|
||||||
b:usz = w
|
# Return the shared length of ov and oc
|
||||||
if (rare{b > n}) { b = n; goto{'skip_runs'} }
|
fn count_sorted{T}(r:*u8, ov:*usz, oc:*usz, x:*T, n:usz) : usz = {
|
||||||
m := m0; mark_run_ends{x, m}
|
def V = [arch_defvw/width{T}]T
|
||||||
if (popc{m} < w/2) {
|
def block = 128
|
||||||
inc_marked_runs{x, tab, m, m0}
|
i:usz = 0
|
||||||
} else {
|
on:usz = 0
|
||||||
setlabel{'skip_runs'}
|
def overflow{xu,c} = { store{ov, on, xu}; store{oc, on, c}; ++on }
|
||||||
@for (x over b) inc{tab, x}
|
while (i < n) {
|
||||||
|
rem := n - i
|
||||||
|
xo := x + i
|
||||||
|
xi := load{xo}
|
||||||
|
def overflow{c} = overflow{cast_i{usz,xi}, c}
|
||||||
|
xe := xo-1; def bxi{j} = xi == load{xe, j}
|
||||||
|
if (block <= rem and bxi{block}) {
|
||||||
|
# Gallop to find last block ending in xi
|
||||||
|
d:usz = block
|
||||||
|
d2 := undefined{usz}
|
||||||
|
while ((d2=d+d) <= rem and bxi{d2}) d = d2
|
||||||
|
l := (rem &~ (block-1)) - d; if (l > d) l = d
|
||||||
|
# Target is in [d,d+l); shrink l
|
||||||
|
while (l > block) {
|
||||||
|
h := (l/2) &~ (block-1)
|
||||||
|
m := d + h
|
||||||
|
if (bxi{m}) d = m
|
||||||
|
l -= h
|
||||||
|
}
|
||||||
|
overflow{d}
|
||||||
|
rem -= d; if (rem == 0) return{on}
|
||||||
|
i += d; xo += d; xi = load{xo}
|
||||||
}
|
}
|
||||||
x += b; n -= b
|
# Count the next block normally
|
||||||
|
if (rem > block) rem = block
|
||||||
|
count_adapt_runs{xo, r, rem}
|
||||||
|
rxi := load{r, xi}
|
||||||
|
if (rxi >= block) {
|
||||||
|
store{r, xi, rxi - block}
|
||||||
|
overflow{block}
|
||||||
|
}
|
||||||
|
i += rem
|
||||||
}
|
}
|
||||||
|
on
|
||||||
}
|
}
|
||||||
|
|
||||||
export{'simd_count_i8', count{i8}}
|
export{'simd_count_i8', count{i8}}
|
||||||
export{'simd_count_i16', count{i16}}
|
export{'simd_count_i16', count{i16}}
|
||||||
export{'simd_count_i32_i32', count_i32_i32}
|
export{'simd_count_i32_i32', count_i32_i32}
|
||||||
|
export{'si_count_sorted_i8', count_sorted{i8}}
|
||||||
|
export{'si_count_sorted_i16', count_sorted{i16}}
|
||||||
|
export{'si_count_sorted_i32', count_sorted{i32}}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user