Return max from AVX2 counting function
This commit is contained in:
parent
4b18466ae2
commit
0e5b98c491
@ -85,6 +85,9 @@
|
|||||||
#define SINGELI_FILE constrep
|
#define SINGELI_FILE constrep
|
||||||
#include "../utils/includeSingeli.h"
|
#include "../utils/includeSingeli.h"
|
||||||
|
|
||||||
|
#define SINGELI_FILE count
|
||||||
|
#include "../utils/includeSingeli.h"
|
||||||
|
|
||||||
extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
|
extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3);
|
||||||
extern void (*const avx2_scan_pluswrap_u16)(uint16_t* v0,uint16_t* v1,uint64_t v2,uint16_t v3);
|
extern void (*const avx2_scan_pluswrap_u16)(uint16_t* v0,uint16_t* v1,uint64_t v2,uint16_t v3);
|
||||||
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
|
extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3);
|
||||||
@ -850,15 +853,6 @@ B slash_c2(B t, B w, B x) {
|
|||||||
return c2rt(slash, w, x);
|
return c2rt(slash, w, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if SINGELI_AVX2
|
|
||||||
#define SINGELI_FILE count
|
|
||||||
#include "../utils/includeSingeli.h"
|
|
||||||
#define SINGELI_COUNT_OR(N) \
|
|
||||||
if (N==8) avx2_count_i8(t, (u8*)xp, xia); else
|
|
||||||
#else
|
|
||||||
#define SINGELI_COUNT_OR(N)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
B slash_im(B t, B x) {
|
B slash_im(B t, B x) {
|
||||||
if (!isArr(x) || RNK(x)!=1) thrM("/⁼: Argument must be an array");
|
if (!isArr(x) || RNK(x)!=1) thrM("/⁼: Argument must be an array");
|
||||||
u8 xe = TI(x,elType);
|
u8 xe = TI(x,elType);
|
||||||
@ -912,10 +906,10 @@ B slash_im(B t, B x) {
|
|||||||
usz m=1<<N; \
|
usz m=1<<N; \
|
||||||
if (xia < m/2) { \
|
if (xia < m/2) { \
|
||||||
IIND_INT(N) \
|
IIND_INT(N) \
|
||||||
} else { \
|
} else SINGELI_COUNT_OR(N) { \
|
||||||
TALLOC(usz, t, m); \
|
TALLOC(usz, t, m); \
|
||||||
for (usz j=0; j<m/2; j++) t[j]=0; \
|
for (usz j=0; j<m/2; j++) t[j]=0; \
|
||||||
SINGELI_COUNT_OR(N) for (usz i=0; i<xia; i++) t[(u##N)xp[i]]++; \
|
for (usz i=0; i<xia; i++) t[(u##N)xp[i]]++; \
|
||||||
t[m/2]=xia; usz ria=0; for (u64 s=0; s<xia; ria++) s+=t[ria]; \
|
t[m/2]=xia; usz ria=0; for (u64 s=0; s<xia; ria++) s+=t[ria]; \
|
||||||
if (ria>m/2) thrM("/⁼: Argument cannot contain negative numbers"); \
|
if (ria>m/2) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||||
i32* rp; r = m_i32arrv(&rp, ria); for (usz i=0; i<ria; i++) rp[i]=t[i]; \
|
i32* rp; r = m_i32arrv(&rp, ria); for (usz i=0; i<ria; i++) rp[i]=t[i]; \
|
||||||
@ -924,8 +918,23 @@ B slash_im(B t, B x) {
|
|||||||
} \
|
} \
|
||||||
break; \
|
break; \
|
||||||
}
|
}
|
||||||
|
#if SINGELI_AVX2
|
||||||
|
#define SINGELI_COUNT_OR(N) if (N==8) { \
|
||||||
|
TALLOC(usz, t, m/2); \
|
||||||
|
for (usz j=0; j<m/2; j++) t[j]=0; \
|
||||||
|
i8 max = avx2_count_i8(t, (i8*)xp, xia); \
|
||||||
|
if (max < 0) thrM("/⁼: Argument cannot contain negative numbers"); \
|
||||||
|
usz ria=max+1; \
|
||||||
|
i32* rp; r = m_i32arrv(&rp, ria); for (usz i=0; i<ria; i++) rp[i]=t[i]; \
|
||||||
|
TFREE(t); \
|
||||||
|
r = num_squeeze(r); \
|
||||||
|
} else
|
||||||
|
#else
|
||||||
|
#define SINGELI_COUNT_OR(N)
|
||||||
|
#endif
|
||||||
CASE_SMALL(8) CASE_SMALL(16)
|
CASE_SMALL(8) CASE_SMALL(16)
|
||||||
#undef CASE_SMALL
|
#undef CASE_SMALL
|
||||||
|
#undef SINGELI_COUNT_OR
|
||||||
case el_i32: { i32* xp = i32any_ptr(x); IIND_INT(32) r = num_squeeze(r); break; }
|
case el_i32: { i32* xp = i32any_ptr(x); IIND_INT(32) r = num_squeeze(r); break; }
|
||||||
#undef IIND_INT
|
#undef IIND_INT
|
||||||
case el_f64: {
|
case el_f64: {
|
||||||
|
|||||||
@ -13,7 +13,8 @@ def max{a:i16, b:i16} = minmax{>, a, b}
|
|||||||
def inc{ptr, ind, v} = store{ptr, ind, v + load{ptr, ind}}
|
def inc{ptr, ind, v} = store{ptr, ind, v + load{ptr, ind}}
|
||||||
def inc{ptr, ind} = inc{ptr, ind, 1}
|
def inc{ptr, ind} = inc{ptr, ind, 1}
|
||||||
|
|
||||||
fn count{T}(tab:*usz, x:*ty_u{T}, n:u64) : u1 = {
|
# Write counts /⁼x to tab and return ⌈´x
|
||||||
|
fn count{T}(tab:*usz, x:*T, n:u64) : T = {
|
||||||
def vbits = 256
|
def vbits = 256
|
||||||
def vec = vbits/width{T}
|
def vec = vbits/width{T}
|
||||||
def uT = ty_u{T}
|
def uT = ty_u{T}
|
||||||
@ -21,25 +22,30 @@ fn count{T}(tab:*usz, x:*ty_u{T}, n:u64) : u1 = {
|
|||||||
def block = (2048*8) / vbits # Target vectors per block
|
def block = (2048*8) / vbits # Target vectors per block
|
||||||
def b_max = block + block/4 # Last block max length
|
def b_max = block + block/4 # Last block max length
|
||||||
assert{b_max < 1<<width{T}} # Don't overflow count in vector section
|
assert{b_max < 1<<width{T}} # Don't overflow count in vector section
|
||||||
|
mx:T = -1 # Maximum of x
|
||||||
i:u64 = 0
|
i:u64 = 0
|
||||||
while (i < n) {
|
while (i < n) {
|
||||||
# Number of elements to handle in this iteration
|
# Number of elements to handle in this iteration
|
||||||
r:u64 = n - i; if (r > vec*b_max) r = vec*block
|
r:u64 = n - i; if (r > vec*b_max) r = vec*block
|
||||||
b := r / vec # Vector case does b full vectors if it runs
|
b := r / vec # Vector case does b full vectors if it runs
|
||||||
|
rv:= b * vec
|
||||||
r0:u64 = 0 # Elements actually handled by vector case
|
r0:u64 = 0 # Elements actually handled by vector case
|
||||||
|
|
||||||
# Find range to check for suitability
|
# Find range to check for suitability; return a negative if found
|
||||||
xv := *V~~x
|
xv := *V~~x
|
||||||
jv := load{xv}; mv := jv
|
jv := load{xv}; mv := jv
|
||||||
@for (xv over _ from 1 to b) { jv = min{jv, xv}; mv = max{mv, xv} }
|
@for (xv over _ from 1 to b) { jv = min{jv, xv}; mv = max{mv, xv} }
|
||||||
|
@for (x over _ from rv to r) { if (x<0) return{x}; if (x>mx) mx=x }
|
||||||
jt := fold{min, jv}
|
jt := fold{min, jv}
|
||||||
mt := fold{max, mv} - jt # Counts needed (last one's implicit)
|
mt := fold{max, mv}
|
||||||
if (jt < 0) return{1} # Negative number found!
|
if (jt < 0) return{jt}
|
||||||
|
if (mt > mx) mx = mt
|
||||||
|
|
||||||
if (mt <= 48) {
|
nc := mt - jt # Number of counts to perform: last is implicit
|
||||||
r0 = b * vec
|
if (nc <= 48) {
|
||||||
|
r0 = rv
|
||||||
j0 := promote{u64, uT~~jt} # Starting count
|
j0 := promote{u64, uT~~jt} # Starting count
|
||||||
m := promote{u64, uT~~mt} # Number of iterations
|
m := promote{u64, uT~~nc} # Number of iterations
|
||||||
total := trunc{usz, r0} # To compute last count
|
total := trunc{usz, r0} # To compute last count
|
||||||
def count_each{js, num} = {
|
def count_each{js, num} = {
|
||||||
j := @collect (k to num) js+k
|
j := @collect (k to num) js+k
|
||||||
@ -63,7 +69,7 @@ fn count{T}(tab:*usz, x:*ty_u{T}, n:u64) : u1 = {
|
|||||||
i += r
|
i += r
|
||||||
x += r
|
x += r
|
||||||
}
|
}
|
||||||
0
|
mx
|
||||||
}
|
}
|
||||||
|
|
||||||
export{'avx2_count_i8', count{i8}}
|
export{'avx2_count_i8', count{i8}}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user