Merge pull request #89 from mlochbaum/pext

Boolean compress
This commit is contained in:
dzaima 2023-08-08 14:11:52 +03:00 committed by GitHub
commit a175c48104
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 124 additions and 70 deletions

View File

@ -2,8 +2,11 @@
// In the notes 𝕨 might indicate 𝕩 for Indices too // In the notes 𝕨 might indicate 𝕩 for Indices too
// Boolean 𝕨 (Where/Compress) general case based on result type width // Boolean 𝕨 (Where/Compress) general case based on result type width
// Size 1: pext // Size 1: compress 64-bit units, possibly SIMD
// Emulate if unavailable // pext if BMI2 is present
// Pairwise combination, SIMD if AVX2
// SIMD shift-by-offset if there's CLMUL but no AVX2
// SHOULD use with polynomial multiply in NEON
// COULD return boolean result from Where // COULD return boolean result from Where
// Size 8, 16, 32, 64: mostly table-based // Size 8, 16, 32, 64: mostly table-based
// Where: direct table lookup, widening for 16 and 32 if available // Where: direct table lookup, widening for 16 and 32 if available
@ -18,7 +21,7 @@
// None for 8-bit Where, too short // None for 8-bit Where, too short
// COULD try per-block adaptivity for 16-bit Compress // COULD try per-block adaptivity for 16-bit Compress
// Sparse if +´𝕨 is small, branchless unless it's very small // Sparse if +´𝕨 is small, branchless unless it's very small
// Chosen per-argument for 8, 16 and per-block for larger // Chosen per-argument for 1, 8, 16 and per-block for larger
// Careful when benchmarking, branch predictor has a long memory // Careful when benchmarking, branch predictor has a long memory
// Grouped if +´»⊸≠𝕨 is small, always branching // Grouped if +´»⊸≠𝕨 is small, always branching
// Chosen per-argument with a threshold that gives up early // Chosen per-argument with a threshold that gives up early
@ -439,32 +442,22 @@ static B compress(B w, B x, usz wia, u8 xl, u8 xt) {
switch(xl) { switch(xl) {
default: r = compress_grouped(wp, x, wia, wsum, xt); break; default: r = compress_grouped(wp, x, wia, wsum, xt); break;
case 0: { case 0: {
u64* xp = bitarr_ptr(x); u64* rp; u64* xp = bitarr_ptr(x);
#if defined(__BMI2__) || SINGELI u64* rp; r = m_bitarrv(&rp,wsum);
r = m_bitarrv(&rp,wsum+128); a(r)->ia = wsum; #if SINGELI
u64 cw = 0; // current word if (wsum>=wia/si_thresh_compress_bool) {
u64 ro = 0; // offset in word where next bit should be written; never 64 si_compress_bool(wp, xp, rp, wia); break;
for (usz i=0; i<BIT_N(wia); i++) {
u64 wv = wp[i];
#if defined(__BMI2__)
u64 v = _pext_u64(xp[i], wv);
#else
u64 v = si_pext_u64(xp[i], wv);
#endif
u64 c = rand_popc64(wv);
cw|= v<<ro;
u64 ro2 = ro+c;
if (ro2>=64) {
*(rp++) = cw;
cw = ro? v>>(64-ro) : 0;
}
ro = ro2&63;
} }
if (ro) *rp = cw;
#else
r = m_bitarrv(&rp,wsum);
for (usz i=0, ri=0; i<wia; i++) { bitp_set(rp,ri,bitp_get(xp,i)); ri+= bitp_get(wp,i); }
#endif #endif
u64 o = 0;
usz j = 0;
for (usz i=0; i<BIT_N(wia); i++) {
for (u64 v=wp[i], x=xp[i]; v; v&=v-1) {
o = o>>1 | (x>>CTZ(v))<<63;
++j; if (j%64==0) rp[j/64-1] = o;
}
}
usz q=(-j)%64; if (q) rp[j/64] = o>>q;
break; break;
} }
#define COMPRESS_BLOCK_PREP(T, PREP) \ #define COMPRESS_BLOCK_PREP(T, PREP) \

View File

@ -69,12 +69,12 @@ def packQQ{{a, b}} = packQQ{a, b}
# arith # arith
def __mul{a:T,b:T & [16]i16==T} = emit{T, '_mm256_mullo_epi16', a, b} def __mul{a:T,b:T & w256i{T, 16}} = emit{T, '_mm256_mullo_epi16', a, b}
def mulHi{a:T,b:T & [16]i16==T} = emit{T, '_mm256_mulhi_epi16', a, b} def mulHi{a:T,b:T & [16]i16==T } = emit{T, '_mm256_mulhi_epi16', a, b}
def mulHi{a:T,b:T & [16]u16==T} = emit{T, '_mm256_mulhi_epu16', a, b} def mulHi{a:T,b:T & [16]u16==T } = emit{T, '_mm256_mulhi_epu16', a, b}
def __mul{a:T,b:T & [ 8]i32==T} = emit{T, '_mm256_mullo_epi32', a, b} def __mul{a:T,b:T & w256i{T, 32}} = emit{T, '_mm256_mullo_epi32', a, b}
def mul32{a:T,b:T & [ 4]i64==T} = emit{T, '_mm256_mul_epi32', a, b} # reads only low 32 bits of arguments def mul32{a:T,b:T & [ 4]i64==T } = emit{T, '_mm256_mul_epi32', a, b} # reads only low 32 bits of arguments
def mul32{a:T,b:T & [ 4]u64==T} = emit{T, '_mm256_mul_epu32', a, b} # reads only low 32 bits of arguments def mul32{a:T,b:T & [ 4]u64==T } = emit{T, '_mm256_mul_epu32', a, b} # reads only low 32 bits of arguments
def abs{a:T & w256i{T,8 }} = emit{T, '_mm256_abs_epi8', a} def abs{a:T & w256i{T,8 }} = emit{T, '_mm256_abs_epi8', a}
def abs{a:T & w256i{T,16}} = emit{T, '_mm256_abs_epi16', a} def abs{a:T & w256i{T,16}} = emit{T, '_mm256_abs_epi16', a}

View File

@ -286,11 +286,15 @@ export{'si_2slash16', slash{1, i16}}; export{'si_thresh_2slash16', u64~~thresh{1
export{'si_2slash32', slash{1, i32}}; export{'si_thresh_2slash32', u64~~thresh{1, i32}} export{'si_2slash32', slash{1, i32}}; export{'si_thresh_2slash32', u64~~thresh{1, i32}}
export{'si_2slash64', slash{1, i64}}; export{'si_thresh_2slash64', u64~~thresh{1, i64}} export{'si_2slash64', slash{1, i64}}; export{'si_thresh_2slash64', u64~~thresh{1, i64}}
def scalwidth{T} = if (isvec{T}) elwidth{T} else width{T}
# pext, or boolean compress # pext, or boolean compress
fn pext{T}(x:T, m:T) { def pext_width{} = if (hasarch{'AVX2'}) 4 else 1
def w = width{T} def thresh_bool{} = if (hasarch{'AVX2'}) 128 else 16
def mod{a} = a % (1<<w) def pext_popc{x:T, m:T} = {
def lowbits{k} = base{1<<k, cdiv{w,k}**1} def w = scalwidth{T}
def scal{v} = if (isvec{T}) T**v else v
def lowbits{w,k} = base{1<<k, cdiv{w,k}**1}
# At each step, x and z are split into groups of length k # At each step, x and z are split into groups of length k
# - z tells how many bits in the group are NOT used # - z tells how many bits in the group are NOT used
# - x contains the bits, with z zeros above # - x contains the bits, with z zeros above
@ -298,60 +302,117 @@ fn pext{T}(x:T, m:T) {
def build{k & k > 1} = { def build{k & k > 1} = {
def h = k>>1 # Increase size from h to k def h = k>>1 # Increase size from h to k
{x,z} := build{h} {x,z} := build{h}
def low = lowbits{k} # Low bit in each new group def low_s = lowbits{w,k} # Low bit in each new group
if (k <= 3) { def low = scal{low_s}
if (k == 2) {
z0 := z & low z0 := z & low
zm := z>>1 & low zm := z>>1 & low
if (k == 2) tup{ tup{ x - (x>>1 & z0), zm + z0 }
x - (x>>1 & z0), } else if (hasarch{'AVX2'} and isvec{T} and k >= 32) {
z0 + zm # We have variable shifts at these sizes
} else tup{ # Faster 1->3 jump, currently unused lh := scal{low_s*(1<<h - 1)}
x - ((x>>1&mod{low*3}) & (z|z0<<1)) - (x>>2 & (z & zm)), zl := z & lh
(z0 + zm) + (z>>2 & low) def S = re_el{ty_u{k}, T}
} tup{T~~(S~~(x&~lh) >> S~~zl) | (x&lh), T~~(S~~z >> h) + zl}
} else { } else {
# Shift high x group down by low z, then add halves of z
even:T = mod{low*(1<<h - 1)}
# SWAR shifter: shift x by sh*o, in length-k groups # SWAR shifter: shift x by sh*o, in length-k groups
def shift{sh, o, x} = { def shift{sh, o, x} = {
l := o & low; m := l<<k - l l := o & low; m := l<<k - l
s := (x & m)>>sh | (x &~ m) s := (x & m)>>sh | (x &~ m)
if (2*sh<=k/2) shift{2*sh, o>>1, s} else s if (2*sh<k/2) shift{2*sh, o>>1, s} else s
} }
# Shift high x group down by low z, then add halves of z
odd:T = scal{low_s*(1<<k - 1<<h)} # Top half
ze := z&~odd
z1 := ze + scal{low_s*(1<<(k-1) - 1)} # z-1, as signed k-bit
move := odd &~ (z1<<1) # Only groups where z>0 move
tup{ tup{
(x&even) | shift{1, z, x&~even}, (x&~move) | shift{1, z1, x&move}>>1,
if (k>4) (z + z>>h)&even else ((z&~even)>>h) + (z&even) (z&odd)>>h + ze
} }
} }
} }
# Finally, compose groups with regular shifts # Compose k/g groups with k/g-1 regular shifts
def g = 8 # 12 performs about the same def multi_shift{x, z, g, k, sc} = {
{b,z} := build{g} o := z * sc{lowbits{k,g}} # Offsets by prefix sum
o := z*lowbits{g} # Offsets by prefix sum def s = 1<<g - 1
def s = 1<<g - 1 def s0 = sc{s}
def gr{sh} = (b & mod{s<<sh}) >> (o>>(sh-g) & s) def oo{sh} = if (sh==g) z else o>>(sh-g) # Offset for group
fold{|, b&s, each{gr, g*slice{iota{cdiv{w,g}},1}}} def gr{sh} = (x & sc{s<<sh}) >> (oo{sh} & s0) # Shifted group
pe := fold{|, x&s0, each{gr, g*slice{iota{k/g},1}}}
tup{pe, o>>(k-g)}
}
def build{k==32 & hasarch{'AVX2'} & isvec{T}} = {
def S = re_el{ty_u{k}, T}
def c{T,vs} = each{{v}=>T~~v, vs}
c{T, multi_shift{...c{S, build{8}}, 8, k, {s}=>S**s}}
}
def build{k & ~isvec{T} & k > 8} = {
multi_shift{...build{8}, 8, k, {s}=>s}
}
# Final result
def {pe, z} = build{w}
tup{pe, scal{w} - z}
} }
fn pext{T & hasarch{'PCLMUL'} & T==u64}(xs:T, ms:T) { def pext_width {..._ & hasarch{'PCLMUL'} > hasarch{'AVX2'}} = 2
def num = lb{width{T}} def thresh_bool{..._ & hasarch{'PCLMUL'} > hasarch{'AVX2'}} = 32
def vec{s} = make{[2]T, s, 0} def pext_popc{x0:V, m0:V & hasarch{'PCLMUL'} & V==[2]u64} = {
m := vec{ms} def clmul{a, b} = zipLo{...@collect (j to 2) clmul{a,b,j}}
x := vec{xs} & m m := m0
x := x0 & m
d := ~m << 1 # One bit of the position difference at x d := ~m << 1 # One bit of the position difference at x
c := vec{1<<64-1} c := V**(1<<64-1)
@unroll (i to num) { @unroll (i to lb{scalwidth{V}}) {
def sh = 1 << i def sh = 1 << i
def shift_at{v, s} = { v = (v&~s) | (v&s)>>sh } def shift_at{v, s} = { v = (v&~s) | (v&s)>>sh }
p := clmul{d, c, 0} # xor-scan p := clmul{d, c} # xor-scan
d = d &~ p # Remove even bits d = d &~ p # Remove even bits
p &= m p &= m
shift_at{m, p} shift_at{m, p}
shift_at{x, p} shift_at{x, p}
} }
extract{x, 0} tup{x, @collect (j to 2) popc{extract{m0,j}}}
} }
fn pext{T & hasarch{'BMI2'}}(x:T, m:T) = pext{x, m} def pext_width {..._ & hasarch{'BMI2'}} = 1
def thresh_bool{..._ & hasarch{'BMI2'}} = 512
def pext_popc{x:T, m:T & hasarch{'BMI2'} & T==u64} = tup{pext{x, m}, popc{m}}
export{'si_pext_u64', pext{u64}} fn compress_bool(w:*u64, x:*u64, r:*u64, n:u64) : void = {
cw:u64 = 0; # current word
ro:u64 = 0; # offset in word where next bit should be written; never 64
def add_bits{{v, c}} = {
cw |= v<<ro
ro2 := ro+c
if (ro2 >= 64) {
store{r, 0, cw}; ++r
cw = 0; if (ro>0) cw = v>>(64-ro)
}
ro = ro2%64
}
def extract{t, i & istup{t}} = tupsel{i,t}
def v = pext_width{}
if (v > 1) {
def V = [v]u64
d := cdiv{n,64}; e := d/v
@for (w in *V~~w, x in *V~~x over i to cdiv{d,v}) {
vc := pext_popc{x, w}
def add{j} = add_bits{each{extract{., j}, vc}}
if (i < e) {
@unroll (j to v) add{j}
} else {
# last write: between 1 and v-1 words
m := d%v
def ar{j} = { add{j}; def jn=j+1; if (jn<v-1 and jn<m) ar{jn} }
ar{0}
}
}
} else {
@for (w, x over i to cdiv{n,64}) add_bits{pext_popc{x, w}}
}
if (ro > 0) store{r, 0, cw}
}
export{'si_compress_bool', compress_bool}
export{'si_thresh_compress_bool', u64~~thresh_bool{}}

View File

@ -119,7 +119,7 @@ def __sub{a:T,b:T & w128i{T, 16}} = emit{T, '_mm_sub_epi16', a, b}
def __sub{a:T,b:T & w128i{T, 32}} = emit{T, '_mm_sub_epi32', a, b} def __sub{a:T,b:T & w128i{T, 32}} = emit{T, '_mm_sub_epi32', a, b}
def __sub{a:T,b:T & w128i{T, 64}} = emit{T, '_mm_sub_epi64', a, b} def __sub{a:T,b:T & w128i{T, 64}} = emit{T, '_mm_sub_epi64', a, b}
def __mul{a:T,b:T & [8]i16==T} = emit{T, '_mm_mullo_epi16', a, b} def __mul{a:T,b:T & w128i{T, 16}} = emit{T, '_mm_mullo_epi16', a, b}
def mulHi{a:T,b:T & [8]i16==T} = emit{T, '_mm_mulhi_epi16', a, b} def mulHi{a:T,b:T & [8]i16==T} = emit{T, '_mm_mulhi_epi16', a, b}
def mulHi{a:T,b:T & [8]u16==T} = emit{T, '_mm_mulhi_epu16', a, b} def mulHi{a:T,b:T & [8]u16==T} = emit{T, '_mm_mulhi_epu16', a, b}
def mul32{a:T,b:T & [2]u64==T} = emit{T, '_mm_mul_epu32', a, b} # reads only low 32 bits of arguments def mul32{a:T,b:T & [2]u64==T} = emit{T, '_mm_mul_epu32', a, b} # reads only low 32 bits of arguments