Merge pull request #18 from mlochbaum/master

Boolean prefix sum using AVX2 shuffling instead of pdep
This commit is contained in:
dzaima 2022-04-21 18:50:35 +03:00 committed by GitHub
commit dd979e172f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,34 +2,53 @@ include './base'
include './sse3'
include './avx'
include './avx2'
include './bmi2'
include './mask'
avx2_bcs32(x:*u64, r:*i32, l:u64) : void = {
rv:= *[8]u32~~r
xv:= *u8~~x
def step{x} = {
m:u32 = 2b0001_0001_0001_0001_0001_0001_0001_0001
a:= pdep{promote{u32,x},m}
d:= a*m # b:= a + a<<16; c:= b + b<<8; d:= c + c<<4
e:= broadcast{[8]u32, d}
f:= e >> make{[8]u32, iota{8}*4}
tup{f & broadcast{[8]u32, 0xf}, d, e}
}
xv:= *u32~~x
c:= broadcast{[8]u32, 0}
e:= l/8
@for(x in xv, r in rv over _ to e) {
def sr = step{x}
r = tupsel{0, sr}+c
# c+= broadcast{[8]u32, popc{x}}
# c+= broadcast{[8]u32, tupsel{1, sr}>>28}
c+= tupsel{2, sr}>>28
def tail{k,i} = i - i>>k<<k # Last k bits of i
def bit {k,i} = tail{1,i>>k}<<k
def ii32 = iota{32}; def bit{k}=bit{k,ii32}; def tail{k}=tail{k,ii32}
def sums{n} = (if (n==0) tup{0}; else { def s=sums{n-1}; merge{s,s+1} })
def sel8{v, t} = sel{[16]u8, v, make{[32]i8, t}}
def widen{v:T} = unpackQ{shuf{[4]u64, v, 4b3120}, broadcast{T, 0}}
def step{x:u32, i, store1} = {
b:= broadcast{[8]u32, x} >> make{[8]u32, 4*tail{1, iota{8}}}
s:= sel8{[32]u8~~b, ii32>>3 + bit{2}}
p:= s & make{[32]u8, (1<<(1+tail{2})) - 1} # Prefixes
d:= sel{[16]u8, make{[32]u8, merge{sums{4},sums{4}}}, [32]i8~~p}
d+= sel8{d, bit{2}*(1+bit{3}>>2)-1}
d+= sel8{d, bit{3}-1}
#d+= [32]u8~~shuf{[4]u64, [8]i32~~sel8{d, bit{3}<<4-1}, 4b1100}
j:= 4*i
def out{v, k} = each{out, widen{v}, 2*k+iota{2}}
def out{v0:[8]i32, k} = {
v := [8]u32~~v0 + c
if (tail{1,k}) c = sel{[8]u32, v, make{[8]i32, each{{i}=>7,iota{8}}}}
store1{rv, j+k, v}
}
out{[32]i8~~d, 0}
}
if (e*8 != l) {
r:= c+tupsel{0, step{load{xv, e}}}
maskstoreF{rv, maskOf{[8]i32, l - e*8}, e, r}
e:= l/32
@for (xv over i to e) {
step{xv, i, store}
}
if (e*32 != l) {
def st{p, j, v} = {
j8 := 8*j
if (j8+8 <= l) {
store{p, j, v}
} else {
if (j8 < l) maskstoreF{rv, maskOf{[8]i32, l - j8}, j, v}
return{}
}
}
step{load{xv, e}, e, st}
}
}
'avx2_bcs32' = avx2_bcs32
'avx2_bcs32' = avx2_bcs32