Merge pull request #55 from mlochbaum/plusscanbool

Typed boolean prefix sums
This commit is contained in:
dzaima 2022-11-08 17:01:17 +02:00 committed by GitHub
commit bde4ac17cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 64 additions and 23 deletions

View File

@ -151,6 +151,36 @@ static i64 bit_diff(u64* x, u64 am) {
#include "../singeli/gen/scan.c"
#pragma GCC diagnostic pop
#endif
B slash_c1(B f, B x);
B scan_bit_sum(B x, u64* xp, u64 ia, u64 xs) { // consumes x
u8 re = xs<=I8_MAX? el_i8 : xs<=I16_MAX? el_i16 : el_i32;
if (xs < ia/128) {
B ones = slash_c1(m_f64(0), x);
MAKE_MUT(r0, ia) mut_init(r0, re); MUTG_INIT(r0);
SGetU(ones)
usz ri = 0;
for (usz i = 0; i < xs; i++) {
usz e = o2s(GetU(ones, i));
mut_fillG(r0, ri, m_i32(i), e-ri);
ri = e;
}
if (ri<ia) mut_fillG(r0, ri, m_i32(xs), ia-ri);
decG(ones);
return mut_fv(r0);
}
B r;
void* rp = m_tyarrv(&r, elWidth(re), ia, el2t(re));
#if SINGELI
#define SUM(W,T) avx2_bcs##W(xp, rp, ia);
#else
#define SUM(W,T) { T c=0; for (usz i=0; i<ia; i++) { c+= bitp_get(xp,i); ((T*)rp)[i]=c; } }
#endif
#define CASE(W) case el_i##W: SUM(W, i##W) break;
switch (re) { default:UD; CASE(8) CASE(16) CASE(32) }
#undef CASE
#undef SUM
decG(x); return r;
}
#if !USE_VALGRIND
static u64 vg_rand(u64 x) { return x; }
@ -167,15 +197,14 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
u8 rtid = v(f)->flags-1;
if (xe==el_bit) {
u64* xp=bitarr_ptr(x);
if (rtid==n_add && ia<I32_MAX) { i32* rp; B r=m_i32arrv(&rp, ia);
#if SINGELI
avx2_bcs32(xp, rp, ia);
#else
i32 c=0; for (usz i=0; i<ia; i++) { c+= bitp_get(xp,i); rp[i]=c; }
#endif
decG(x); return r; }
if (rtid==n_or | rtid==n_ceil ) { u64* rp; B r=m_bitarrv(&rp,ia); usz n=BIT_N(ia); u64 xi; usz i=0; while(i<n) if ((xi= vg_rand(xp[i]))!=0) { rp[i] = -(xi&-xi) ; i++; while(i<n) rp[i++] = ~0LL; break; } else rp[i++]= 0 ; decG(x); return r; }
if (rtid==n_and | rtid==n_mul | rtid==n_floor) { u64* rp; B r=m_bitarrv(&rp,ia); usz n=BIT_N(ia); u64 xi; usz i=0; while(i<n) if ((xi=~vg_rand(xp[i]))!=0) { rp[i] = (xi&-xi)-1; i++; while(i<n) rp[i++] = 0 ; break; } else rp[i++]=~0LL; decG(x); return r; }
if (rtid==n_add) {
u64 xs = bit_sum(xp, ia);
if (xs>I32_MAX) goto base;
if (xs<=1) { if (xs==0) return x; goto bit_or; }
return FL_SET(scan_bit_sum(x, xp, ia, xs), fl_asc|fl_squoze);
}
if (rtid==n_or | rtid==n_ceil ) { bit_or:; u64* rp; B r=m_bitarrv(&rp,ia); usz n=BIT_N(ia); u64 xi; usz i=0; while(i<n) if ((xi= vg_rand(xp[i]))!=0) { rp[i] = -(xi&-xi) ; i++; while(i<n) rp[i++] = ~0LL; break; } else rp[i++]= 0 ; decG(x); return r; }
if (rtid==n_and | rtid==n_mul | rtid==n_floor) { u64* rp; B r=m_bitarrv(&rp,ia); usz n=BIT_N(ia); u64 xi; usz i=0; while(i<n) if ((xi=~vg_rand(xp[i]))!=0) { rp[i] = (xi&-xi)-1; i++; while(i<n) rp[i++] = 0 ; break; } else rp[i++]=~0LL; decG(x); return r; }
if (rtid==n_ne) { B r=scan_ne(0, xp, ia); decG(x); return r; }
goto base;
}

View File

@ -86,28 +86,38 @@ avx2_scan_assoc_0{T, op}(x:*T, r:*T, len:u64, init:T) : void = {
'avx2_scan_pluswrap_u32' = avx2_scan_assoc_0{u32, +}
# Boolean cumulative sum
avx2_bcs32(x:*u64, r:*i32, l:u64) : void = {
rv:= *[8]u32~~r
avx2_bcs{T}(x:*u64, r:*T, l:u64) : void = {
def U = ty_u{T}
def w = width{T}
def vl= 256 / w
def V = [vl]U
rv:= *V~~r
xv:= *u32~~x
c:= broadcast{[8]u32, 0}
c:= broadcast{V, 0}
def ii32 = iota{32}; def bit{k}=bit{k,ii32}; def tail{k}=tail{k,ii32}
def sums{n} = (if (n==0) tup{0}; else { def s=sums{n-1}; merge{s,s+1} })
def widen{v:T} = unpackQ{shuf{[4]u64, v, 4b3120}, broadcast{T, 0}}
def step{x:u32, i, store1} = {
def sumlanes{x:u32} = {
b:= broadcast{[8]u32, x} >> make{[8]u32, 4*tail{1, iota{8}}}
s:= sel8{[32]u8~~b, ii32>>3 + bit{2}}
p:= s & make{[32]u8, (1<<(1+tail{2})) - 1} # Prefixes
d:= sel{[16]u8, make{[32]u8, merge{sums{4},sums{4}}}, [32]i8~~p}
d+= sel8{d, bit{2}*(1+bit{3}>>2)-1}
d+= sel8{d, bit{3}-1}
#d+= [32]u8~~shuf{[4]u64, [8]i32~~sel8{d, bit{3}<<4-1}, 4b1100}
j:= 4*i
d + sel8{d, bit{3}-1}
}
def step{x:u32, i, store1} = {
d:= sumlanes{x}
if (w==8) d+= [32]u8~~shuf{[4]u64, [8]i32~~sel8{d, bit{3}<<4-1}, 4b1100}
j:= (w/8)*i
def out{v, k} = each{out, widen{v}, 2*k+iota{2}}
def out{v0:[8]i32, k} = {
v := [8]u32~~v0 + c
if (tail{1,k}) c = sel{[8]u32, v, make{[8]i32, broadcast{8, 7}}}
def out{v0:[vl]T, k} = {
v := V~~v0 + c
# Update carry at the lane boundary
if (w!=32 or tail{1,k}) {
c = sel{[8]u32, spread{v}, make{[8]i32, broadcast{8, 7}}}
}
store1{rv, j+k, v}
}
out{[32]i8~~d, 0}
@ -120,15 +130,17 @@ avx2_bcs32(x:*u64, r:*i32, l:u64) : void = {
if (e*32 != l) {
def st{p, j, v} = {
j8 := 8*j
if (j8+8 <= l) {
jv := vl*j
if (jv+vl <= l) {
store{p, j, v}
} else {
if (j8 < l) maskstoreF{rv, maskOf{[8]i32, l - j8}, j, v}
if (jv < l) maskstoreF{rv, maskOf{V, l - jv}, j, v}
return{}
}
}
step{load{xv, e}, e, st}
}
}
'avx2_bcs32' = avx2_bcs32
'avx2_bcs8' = avx2_bcs{i8}
'avx2_bcs16' = avx2_bcs{i16}
'avx2_bcs32' = avx2_bcs{i32}