Merge pull request #88 from mlochbaum/compress

Singeli where/compress
This commit is contained in:
dzaima 2023-07-19 20:46:52 +03:00 committed by GitHub
commit 0a4a0f3e2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 422 additions and 94 deletions

View File

@ -216,7 +216,7 @@ po ← { # parsed options
has Lowercase GetOpt "has" has Lowercase GetOpt "has"
has has ("slow-pdep"<has)/"bmi2" has has ("slow-pdep"<has)/"bmi2"
{𝕊: "Error: Unsupported 'has' options; options:"1", "¨𝕩}_assert_(´has) "avx2""bmi2""slow-pdep" {𝕊: "Error: Unsupported 'has' options; options:"1", "¨𝕩}_assert_(´has) "avx2""bmi2""pclmul""slow-pdep"
{𝕊: "Error: Cannot have 'has' options on architecture '"arch"'; add an argument of "compat"arch=x86-64""target_arch=x86-64"}_assert_¬ (arch"x86-64") 0has {𝕊: "Error: Cannot have 'has' options on architecture '"arch"'; add an argument of "compat"arch=x86-64""target_arch=x86-64"}_assert_¬ (arch"x86-64") 0has
avx2 (arch"x86-64") singeli native "avx2"<has avx2 (arch"x86-64") singeli native "avx2"<has
@ -638,7 +638,7 @@ cachedBin‿linkerCache ← {
"2..""src/builtins/select.c""select", "2..""src/builtins/scan.c""scan", "2..""src/builtins/select.c""select", "2..""src/builtins/scan.c""scan",
"2a.""src/builtins/slash.c""constrep", "2..""src/builtins/scan.c""neq", "2a.""src/builtins/slash.c""constrep", "2..""src/builtins/scan.c""neq",
"2..""src/builtins/slash.c""slash", "2..""src/builtins/slash.c""count" "xag""src/builtins/slash.c""slash", "2..""src/builtins/slash.c""count"
objs objs
@ -672,7 +672,7 @@ cachedBin‿linkerCache ← {
singeliArgs po.singeliFlags"-l", "gen="AtRoot singeliCache.folder, "-c", "usz=u"•Repr po.usz{ singeliArgs po.singeliFlags"-l", "gen="AtRoot singeliCache.folder, "-c", "usz=u"•Repr po.usz{
po.native? ; po.native? ;
"-a" ({"x86-64":"X86_64"; "aarch64":"AARCH64"; "none"} po.arch) ','¨ Uppercase "avx2""bmi2"/po.has "-a" ({"x86-64":"X86_64"; "aarch64":"AARCH64"; "none"} po.arch) ','¨ Uppercase "avx2""bmi2""pclmul"/po.has
} }
{𝕊: "Singeli args: "•Repr singeliArgs} _verboseLog @ {𝕊: "Singeli args: "•Repr singeliArgs} _verboseLog @
singeliObjs {MakeSingeliInv singeliArgs, {𝕊:UpdateSubmodule po.singeliDir}, singeliCache, 𝕩, "src/singeli/src/"•file.At 𝕩".singeli", (𝕩"dyarith")/gaRule}¨ 1¨singeliMap singeliObjs {MakeSingeliInv singeliArgs, {𝕊:UpdateSubmodule po.singeliDir}, singeliCache, 𝕩, "src/singeli/src/"•file.At 𝕩".singeli", (𝕩"dyarith")/gaRule}¨ 1¨singeliMap

View File

@ -2,12 +2,16 @@
// In the notes 𝕨 might indicate 𝕩 for Indices too // In the notes 𝕨 might indicate 𝕩 for Indices too
// Boolean 𝕨 (Where/Compress) general case based on result type width // Boolean 𝕨 (Where/Compress) general case based on result type width
// COULD use AVX-512 // Size 1: pext
// Size 1: pext, or bit-at-a-time // Emulate if unavailable
// SHOULD emulate pext if unavailable
// COULD return boolean result from Where // COULD return boolean result from Where
// Size 8, 16: pdep/pext, or branchless // Size 8, 16, 32, 64: mostly table-based
// SHOULD try vector lookup-shuffle if unavailable or old AMD // Where: direct table lookup, widening for 16 and 32 if available
// Compress: table lookup plus shuffle
// AVX2 permutevar8x32 for 32 and 64 if available
// Sparse method using table-based Where fills in if no shuffle
// SHOULD implement for NEON
// AVX-512: compress instruction, separate store not compressstore
// Size 32, 64: 16-bit indices from where_block_u16 // Size 32, 64: 16-bit indices from where_block_u16
// Other sizes: always used grouped code // Other sizes: always used grouped code
// Adaptivity based on 𝕨 statistics // Adaptivity based on 𝕨 statistics
@ -70,15 +74,19 @@
#define _pdep_u64 vg_pdep_u64 #define _pdep_u64 vg_pdep_u64
#else #else
#define vg_loadLUT64(p, i) p[i] #define vg_loadLUT64(p, i) p[i]
#define rand_popc64(X) POPC(X)
#endif #endif
static void storeu_u64(u64* p, u64 v) { memcpy(p, &v, 8); } #endif
static u64 loadu_u64(u64* p) { u64 v; memcpy(&v, p, 8); return v; }
#if SINGELI_AVX2 #if !USE_VALGRIND
#define SINGELI_FILE slash #define rand_popc64(X) POPC(X)
#include "../utils/includeSingeli.h" #endif
#endif static void storeu_u64(u64* p, u64 v) { memcpy(p, &v, 8); }
static u64 loadu_u64(u64* p) { u64 v; memcpy(&v, p, 8); return v; }
#if SINGELI
#define SINGELI_FILE slash
#include "../utils/includeSingeli.h"
#endif #endif
#if SINGELI_AVX2 || SINGELI_NEON #if SINGELI_AVX2 || SINGELI_NEON
@ -159,8 +167,8 @@ static void bsp_u16(u64* src, u16* dst, usz len, usz sum) {
static void where_block_u16(u64* src, u16* dst, usz len, usz sum) { static void where_block_u16(u64* src, u16* dst, usz len, usz sum) {
assert(len <= bsp_max); assert(len <= bsp_max);
#if SINGELI_AVX2 && FAST_PDEP #if SINGELI
if (sum >= len/8) bmipopc_1slash16(src, (i16*)dst, len); if (sum >= len/si_thresh_1slash16) si_1slash16(src, (i16*)dst, len, sum);
#else #else
if (sum >= len/4+len/8) WHERE_DENSE(src, dst, len, 0); if (sum >= len/4+len/8) WHERE_DENSE(src, dst, len, 0);
#endif #endif
@ -232,19 +240,17 @@ static B where(B x, usz xia, u64 s) {
u64* xp = bitarr_ptr(x); u64* xp = bitarr_ptr(x);
usz q=xia%64; if (q) xp[xia/64] &= ((u64)1<<q) - 1; usz q=xia%64; if (q) xp[xia/64] &= ((u64)1<<q) - 1;
if (xia <= 128) { if (xia <= 128) {
#if SINGELI_AVX2 && FAST_PDEP #if SINGELI
i8* rp = m_tyarrvO(&r, 1, s, t_i8arr, 8); i8* rp = m_tyarrv(&r, 1, s, t_i8arr);
bmipopc_1slash8(xp, rp, xia); si_1slash8(xp, rp, xia, s);
FINISH_OVERALLOC_A(r, s, 8);
#else #else
i8* rp; r=m_i8arrv(&rp,s); WHERE_SPARSE(xp,rp,s,0,); i8* rp; r=m_i8arrv(&rp,s); WHERE_SPARSE(xp,rp,s,0,);
#endif #endif
} else if (xia <= 32768) { } else if (xia <= 32768) {
#if SINGELI_AVX2 && FAST_PDEP #if SINGELI
if (s >= xia/8) { if (s >= xia/si_thresh_1slash16) {
i16* rp = m_tyarrvO(&r, 2, s, t_i16arr, 16); i16* rp = m_tyarrv(&r, 2, s, t_i16arr);
bmipopc_1slash16(xp, rp, xia); si_1slash16(xp, rp, xia, s);
FINISH_OVERALLOC_A(r, s*2, 16);
} }
#else #else
if (s >= xia/4+xia/8) { if (s >= xia/4+xia/8) {
@ -273,10 +279,9 @@ static B where(B x, usz xia, u64 s) {
} else { } else {
bs = bit_sum(xp,b); bs = bit_sum(xp,b);
} }
#if SINGELI_AVX2 && FAST_PDEP #if SINGELI
if (bs >= b/8+b/16) { if (bs >= b/si_thresh_1slash32) {
bmipopc_1slash16(xp, buf, b); si_1slash32(xp, i, rq, b, bs);
for (usz j=0; j<bs; j++) rq[j] = i+buf[j];
} }
#else #else
if (bs >= b/2) { if (bs >= b/2) {
@ -359,36 +364,36 @@ B grade_bool(B x, usz xia, bool up) {
u64* xp = bitarr_ptr(x); u64* xp = bitarr_ptr(x);
u64 sum = bit_sum(xp, xia); u64 sum = bit_sum(xp, xia);
u64 l0 = up? xia-sum : sum; // Length of first set of indices u64 l0 = up? xia-sum : sum; // Length of first set of indices
#if SINGELI_AVX2 && FAST_PDEP #if SINGELI
if (xia < 16) { BRANCHLESS_GRADE(i8) } if (xia < 16) { BRANCHLESS_GRADE(i8) }
else if (xia <= 1<<15) { else if (xia <= 1<<15) {
B notx = bit_negate(incG(x)); B notx = bit_negate(incG(x));
u64* xp0 = bitarr_ptr(notx); u64* xp0 = bitarr_ptr(notx);
u64* xp1 = xp; u64* xp1 = xp;
u64 q=xia%64; if (q) { usz e=xia/64; u64 m=((u64)1<<q)-1; xp0[e]&=m; xp1[e]&=m; }
if (!up) { u64* t=xp1; xp1=xp0; xp0=t; } if (!up) { u64* t=xp1; xp1=xp0; xp0=t; }
#define BMI_GRADE(W) \ #define SI_GRADE(W) \
i##W* rp = m_tyarrvO(&r, W/8, xia, t_i##W##arr, W); \ i##W* rp = m_tyarrv(&r, W/8, xia, t_i##W##arr); \
bmipopc_1slash##W(xp0, rp , xia); \ si_1slash##W(xp0, rp , xia, l0 ); \
bmipopc_1slash##W(xp1, rp+l0, xia); \ si_1slash##W(xp1, rp+l0, xia, xia-l0);
FINISH_OVERALLOC_A(r, xia*(W/8), W); if (xia <= 128) { SI_GRADE(8) } else { SI_GRADE(16) }
if (xia <= 128) { BMI_GRADE(8) } else { BMI_GRADE(16) } #undef SI_GRADE
#undef BMI_GRADE
decG(notx); decG(notx);
} else if (xia <= 1ull<<31) { } else if (xia <= 1ull<<31) {
i32* rp0; r = m_i32arrv(&rp0, xia); i32* rp0; r = m_i32arrv(&rp0, xia);
i32* rp1 = rp0 + l0; i32* rp1 = rp0 + l0;
if (!up) { i32* t=rp1; rp1=rp0; rp0=t; } if (!up) { i32* t=rp1; rp1=rp0; rp0=t; }
usz b = 256; TALLOC(u8, buf, b); usz b = 256;
u64 xp0[4]; // 4 ≡ b/64 u64 xp0[4]; // 4 ≡ b/64
u64* xp1 = xp; u64* xp1 = xp;
for (usz i=0; i<xia; i+=b) { for (usz i=0; i<xia; i+=b) {
for (usz j=0; j<BIT_N(b); j++) xp0[j] = ~xp1[j]; for (usz j=0; j<BIT_N(b); j++) xp0[j] = ~xp1[j];
usz b2 = b>xia-i? xia-i : b; usz b2 = b>xia-i? xia-i : b;
usz s0=bit_sum(xp0,b2); bmipopc_1slash8(xp0, (i8*)buf, b2); for (usz j=0; j<s0; j++) *rp0++ = i+buf[j]; if (b2<b) { u64 q=b2%64; usz e=b2/64; u64 m=((u64)1<<q)-1; xp0[e]&=m; xp1[e]&=m; }
usz s1=b2-s0; bmipopc_1slash8(xp1, (i8*)buf, b2); for (usz j=0; j<s1; j++) *rp1++ = i+buf[j]; usz s0=bit_sum(xp0,b2); si_1slash32(xp0, i, rp0, b2, s0); rp0+=s0;
usz s1=b2-s0; si_1slash32(xp1, i, rp1, b2, s1); rp1+=s1;
xp1+= b2/64; xp1+= b2/64;
} }
TFREE(buf);
} }
#else #else
if (xia <= 128) { BRANCHLESS_GRADE(i8) } if (xia <= 128) { BRANCHLESS_GRADE(i8) }
@ -437,13 +442,17 @@ static B compress(B w, B x, usz wia, u8 xl, u8 xt) {
default: r = compress_grouped(wp, x, wia, wsum, xt); break; default: r = compress_grouped(wp, x, wia, wsum, xt); break;
case 0: { case 0: {
u64* xp = bitarr_ptr(x); u64* rp; u64* xp = bitarr_ptr(x); u64* rp;
#if defined(__BMI2__) #if defined(__BMI2__) || SINGELI
r = m_bitarrv(&rp,wsum+128); a(r)->ia = wsum; r = m_bitarrv(&rp,wsum+128); a(r)->ia = wsum;
u64 cw = 0; // current word u64 cw = 0; // current word
u64 ro = 0; // offset in word where next bit should be written; never 64 u64 ro = 0; // offset in word where next bit should be written; never 64
for (usz i=0; i<BIT_N(wia); i++) { for (usz i=0; i<BIT_N(wia); i++) {
u64 wv = wp[i]; u64 wv = wp[i];
#if defined(__BMI2__)
u64 v = _pext_u64(xp[i], wv); u64 v = _pext_u64(xp[i], wv);
#else
u64 v = si_pext_u64(xp[i], wv);
#endif
u64 c = rand_popc64(wv); u64 c = rand_popc64(wv);
cw|= v<<ro; cw|= v<<ro;
u64 ro2 = ro+c; u64 ro2 = ro+c;
@ -474,25 +483,36 @@ static B compress(B w, B x, usz wia, u8 xl, u8 xt) {
TFREE(buf) TFREE(buf)
#define COMPRESS_BLOCK(T) COMPRESS_BLOCK_PREP(T, ) #define COMPRESS_BLOCK(T) COMPRESS_BLOCK_PREP(T, )
#define WITH_SPARSE(W, CUTOFF, DENSE) { \ #define WITH_SPARSE(W, CUTOFF, DENSE) { \
i##W *xp=tyany_ptr(x), *rp; \ i##W *xp=tyany_ptr(x), *rp; \
if (wsum<wia/CUTOFF) { rp=m_tyarrv(&r,W/8,wsum,xt); COMPRESS_BLOCK(i##W); } \ if (CUTOFF!=1) { \
else if (groups_lt(wp,wia, wia/128)) r = compress_grouped(wp, x, wia, wsum, xt); \ if (wsum<wia/CUTOFF) { rp=m_tyarrv(&r,W/8,wsum,xt); COMPRESS_BLOCK(i##W); } \
else { DENSE; } \ else if (groups_lt(wp,wia, wia/128)) r = compress_grouped(wp, x, wia, wsum, xt); \
else { DENSE; } \
} else { \
if (wsum>=wia/8 && groups_lt(wp,wia, wia/W)) r = compress_grouped(wp, x, wia, wsum, xt); \
else { rp=m_tyarrv(&r,W/8,wsum,xt); COMPRESS_BLOCK(i##W); } \
} \
break; } break; }
#if SINGELI_AVX2 && FAST_PDEP #if SINGELI
case 3: WITH_SPARSE( 8, 32, rp=m_tyarrvO(&r,1,wsum,xt, 8); bmipopc_2slash8 (wp, xp, rp, wia); FINISH_OVERALLOC_A(r, wsum, 8)) #define DO(W) \
case 4: WITH_SPARSE(16, 16, rp=m_tyarrvO(&r,2,wsum,xt, 16); bmipopc_2slash16(wp, xp, rp, wia); FINISH_OVERALLOC_A(r, wsum*2, 16)) WITH_SPARSE(W, si_thresh_2slash##W, rp=m_tyarrv(&r,W/8,wsum,xt); si_2slash##W(wp, xp, rp, wia, wsum))
case 3: DO(8) case 4: DO(16) case 5: DO(32)
case 6: if (TI(x,elType)!=el_B) {
DO(64)
} // else follows
#undef DO
#else #else
case 3: WITH_SPARSE( 8, 2, rp=m_tyarrv(&r,1,wsum,xt); for (usz i=0; i<wia; i++) { *rp = xp[i]; rp+= bitp_get(wp,i); }) case 3: WITH_SPARSE( 8, 2, rp=m_tyarrv(&r,1,wsum,xt); for (usz i=0; i<wia; i++) { *rp = xp[i]; rp+= bitp_get(wp,i); })
case 4: WITH_SPARSE(16, 2, rp=m_tyarrv(&r,2,wsum,xt); for (usz i=0; i<wia; i++) { *rp = xp[i]; rp+= bitp_get(wp,i); }) case 4: WITH_SPARSE(16, 2, rp=m_tyarrv(&r,2,wsum,xt); for (usz i=0; i<wia; i++) { *rp = xp[i]; rp+= bitp_get(wp,i); })
#endif
#undef WITH_SPARSE
#define BLOCK_OR_GROUPED(T) \ #define BLOCK_OR_GROUPED(T) \
if (wsum>=wia/8 && groups_lt(wp,wia, wia/16)) r = compress_grouped(wp, x, wia, wsum, xt); \ if (wsum>=wia/8 && groups_lt(wp,wia, wia/16)) r = compress_grouped(wp, x, wia, wsum, xt); \
else { T* xp=tyany_ptr(x); T* rp=m_tyarrv(&r,sizeof(T),wsum,xt); COMPRESS_BLOCK(T); } else { T* xp=tyany_ptr(x); T* rp=m_tyarrv(&r,sizeof(T),wsum,xt); COMPRESS_BLOCK(T); }
case 5: BLOCK_OR_GROUPED(i32) break; case 5: BLOCK_OR_GROUPED(i32) break;
case 6: case 6:
if (TI(x,elType)!=el_B) { BLOCK_OR_GROUPED(u64) } if (TI(x,elType)!=el_B) { BLOCK_OR_GROUPED(u64) }
#undef BLOCK_OR_GROUPED
#endif
#undef WITH_SPARSE
else { else {
B xf = getFillR(x); B xf = getFillR(x);
B* xp = arr_bptr(x); B* xp = arr_bptr(x);
@ -509,7 +529,6 @@ static B compress(B w, B x, usz wia, u8 xl, u8 xt) {
} }
} }
break; break;
#undef BLOCK_OR_GROUPED
#undef COMPRESS_BLOCK #undef COMPRESS_BLOCK
} }
ur xr = RNK(x); ur xr = RNK(x);

View File

@ -160,6 +160,9 @@ def minvalue{T & issigned{T}} = - (1<<(width{T}-1))
def maxvalue{T & issigned{T}} = (1<<(width{T}-1))-1 def maxvalue{T & issigned{T}} = (1<<(width{T}-1))-1
# base cases # base cases
def pdep{...x} = assert{'pdep not supported', show{...x}}
def pext{...x} = assert{'pext not supported', show{...x}}
def popcRand{...x} = assert{'popcRand not supported', show{...x}}
def andnz{...x} = assert{'andnz not supported', show{...x}} def andnz{...x} = assert{'andnz not supported', show{...x}}
def topBlend{...x} = assert{'topBlend not supported', show{...x}} def topBlend{...x} = assert{'topBlend not supported', show{...x}}
def topMask{...x} = assert{'topMask not supported', show{...x}} def topMask{...x} = assert{'topMask not supported', show{...x}}
@ -171,6 +174,8 @@ def unpackLo{...x} = assert{'unpackLo not supported', show{...x}}
def unpackHi{...x} = assert{'unpackHi not supported', show{...x}} def unpackHi{...x} = assert{'unpackHi not supported', show{...x}}
def unpackQ{...x} = assert{'unpackQ not supported', show{...x}} def unpackQ{...x} = assert{'unpackQ not supported', show{...x}}
def packQ{...x} = assert{'packQ not supported', show{...x}} def packQ{...x} = assert{'packQ not supported', show{...x}}
def shl{...x} = assert{'shl not supported', show{...x}}
def shr{...x} = assert{'shr not supported', show{...x}}
def __mulhi{...x} = assert{'__mulhi not supported', show{...x}} def __mulhi{...x} = assert{'__mulhi not supported', show{...x}}
def fold_addw{...x} = assert{'fold_addw not supported', show{...x}} def fold_addw{...x} = assert{'fold_addw not supported', show{...x}}
def vfold{...x} = assert{'vfold not supported', show{...x}} def vfold{...x} = assert{'vfold not supported', show{...x}}
@ -204,6 +209,7 @@ def extract{...x} = assert{'extract not supported', show{...x}}
def abs{...x} = assert{'abs not supported', show{...x}} def abs{...x} = assert{'abs not supported', show{...x}}
def homBlend{...x} = assert{'homBlend not supported', show{...x}} def homBlend{...x} = assert{'homBlend not supported', show{...x}}
def zip{...x} = assert{'zip not supported', show{...x}} def zip{...x} = assert{'zip not supported', show{...x}}
def clmul{...x} = assert{'clmul not supported', show{...x}}
def andnot{a, b & anyNum{a} & anyNum{b}} = a & ~b def andnot{a, b & anyNum{a} & anyNum{b}} = a & ~b
oper &~ andnot infix none 35 oper &~ andnot infix none 35

View File

@ -2,6 +2,3 @@ def pdep{x:u64, m:u64} = emit{u64, '_pdep_u64', x, m}
def pdep{x:u32, m:u32} = emit{u32, '_pdep_u32', x, m} def pdep{x:u32, m:u32} = emit{u32, '_pdep_u32', x, m}
def pext{x:u64, m:u64} = emit{u64, '_pext_u64', x, m} def pext{x:u64, m:u64} = emit{u64, '_pext_u64', x, m}
def pext{x:u32, m:u32} = emit{u32, '_pext_u32', x, m} def pext{x:u32, m:u32} = emit{u32, '_pext_u32', x, m}
def popcRand{x:T & isint{T} & width{T}==64} = emit{u8, 'rand_popc64', x} # under valgrind, return a random result in the range of possible ones
def popcRand{x:T & isint{T} & width{T}<=32} = emit{u8, 'rand_popc64', x}

View File

@ -0,0 +1,2 @@
def clmul{a:T, b:T, imm & w128i{T}} = emit{T, '_mm_clmulepi64_si128', a, b, imm}
def clmul{a, b} = clmul{a, b, 0}

View File

@ -1,7 +1,7 @@
include './base' include './base'
include './sse' include './sse'
include './clmul'
def clmul{a:T, b:T, imm & w128i{T}} = emit{T, '_mm_clmulepi64_si128', a, b, imm}
def unpacklo{a:T,b:T & T==[2]u64} = emit{T, '_mm_unpacklo_epi64', a, b} def unpacklo{a:T,b:T & T==[2]u64} = emit{T, '_mm_unpacklo_epi64', a, b}
fn clmul_scan_ne_any(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = { fn clmul_scan_ne_any(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = {

View File

@ -1,51 +1,355 @@
include './base' include './base'
include './bmi2' if (hasarch{'X86_64'}) include './sse'
if (hasarch{'PCLMUL'}) include './clmul'
if (hasarch{'AVX2'}) { include './avx'; include './avx2' }
if (hasarch{'BMI2'}) include './bmi2'
if (hasarch{'AARCH64'}) include './neon'
if (hasarch{'AVX512F'}) {
local def mti{s,T} = merge{'_mm512_',s,'_epi',fmtnat{elwidth{T}}}
def load{a:T, n & 512==width{eltype{T}}} = emit{eltype{T}, '_mm512_loadu_si512', a+n}
def make{T, xs & 512==width{T} & tuplen{xs}==vcount{T}} = {
def p = each{{c}=>promote{eltype{T},c},reverse{xs}}
emit{T, mti{'set',T}, ...p}
}
def iota{T & isvec{T} & 512==width{T}} = make{T, iota{vcount{T}}}
def broadcast{T, v & isvec{T} & 512==width{T}} = {
emit{T, mti{'set1',T}, promote{eltype{T},v}}
}
def __add{a:T,b:T & 512==width{T}} = emit{T, mti{'add',T}, a, b}
}
include './mask'
include 'util/tup'
def storeu{p:T, i, v:eltype{T} & *u64==T} = emit{void, 'storeu_u64', p+i, v} def storeu{p:T, i, v:eltype{T} & *u64==T} = emit{void, 'storeu_u64', p+i, v}
def loadu{p:T & *u64==T} = emit{eltype{T}, 'loadu_u64', p} def loadu{p:T & *u64==T} = emit{eltype{T}, 'loadu_u64', p}
def comp8{w:*u64, X, r:*i8, l:u64} = { def popcRand{x:T & isint{T} & width{T}==64} = emit{u8, 'rand_popc64', x} # under valgrind, return a random result in the range of possible ones
@for(w in *u8~~w over i to cdiv{l,8}) { def popcRand{x:T & isint{T} & width{T}<=32} = emit{u8, 'rand_popc64', x}
pc:= popc{w}
storeu{*u64~~r, 0, pext{promote{u64,X{}}, pdep{promote{u64, w}, cast{u64,0x0101010101010101}}*255}} # Table from l bits to w-bit indices, shifted left by s, and G applied afterwards
def maketab{l,w,s,G} = {
def bot = fold{
{t,k} => join{each{tup, t, G{k} + t<<w}},
tup{0},
reverse{iota{l}<<s}
}
# Store popcnt-1 in the high element
def top = (fold{bind{flat_table,+}, l**iota{2}} - 1)%(1<<(w-s))
top<<(l*w-w+s) | bot # Overlaps for all-1 value only
}
def maketab{l,w,s} = maketab{l,w,s,{x}=>x}
def maketab{l,w} = maketab{l,w,0}
itab:*u64 = maketab{8,8} # 256 elts, 2KB; shared by many methods
# Recover popcount, for when POPCNT isn't there
def has_popc = hasarch{'POPCNT'}
def tab_popc{i:I, w} = (i>>(width{I}-w) + 1) & (1<<w - 1)
def popc_alt{v, i, w} = if (has_popc) popc{v} else tab_popc{i, w}
# slash{c, T} defines:
# if c==1: w/x
# if c==0 & (T==i8 or T==i16): /w
# if c==0 & T==i32: x + /w (assumes x is a multiple of 8 for topper)
# if sum(w) < len/thresh{c,T}, sparse Where will be used
def arg{c,T} = if (c) *T else if (T==i32) T else tup{} # type of x
# Modifies the input variable r
# Assumes iter{} will increment r, by at most write_len
def for_special_buffered{r, write_len}{vars,begin,sum,iter} = {
assert{isreg{r}}; assert{begin==0}
def T = eltype{type{r}}; def tw = width{T}
def ov = write_len-1
buf := undefined{T, 2*(ov+1)}
r0 := r
end := r + sum - ov
i:u64 = 0; buf_used:u1 = 0
def restart = setlabel{}
while (r < end) {
iter{i, vars}
++i
}
if (not buf_used) {
end += buf - r + ov
if (buf < end) {
r0 = r
r = buf
buf_used = 1; goto{restart}
}
} else {
if (has_simd) {
def vc = arch_defvw/tw;
def R = [vc]T
@unroll ((ov/vc)>>0) if (end-buf>vc) { store{*R~~r0, 0, load{*R~~buf}}; r0+=vc; buf+=vc }
homMaskStoreF{*R~~r0, maskOf{R, end-buf}, load{*R~~buf}}
} else {
@for (r0, buf over u64~~(end-buf)) r0 = buf
}
}
}
# Assumes w is trimmed, so the last 1 appears at index l-1
# Unused because an index buffer and select is faster
def thresh{c, T} = 1
fn slash{c, T}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def bitp_get{arr, n} = (load{arr,n>>6} >> (n&63)) & 1
@for (i to l) {
store{r, 0, if (c) load{x,i} else if (T==i32) cast_i{T,i}+x else i}
r+= bitp_get{w,i}
}
}
def getter{c, V, x} = {
if (c) {
i:u64 = 0
{} => { v:=load{*V~~x, i}; ++i; v }
} else {
i := iota{V}
if (isreg{x}) i += V**cast_i{eltype{V},x}
ii := V**vcount{V}
{} => { v:=i; i+=ii; v }
}
}
# Top bits to convert 1-byte indices to 2 or 4
# These can only change between loop iterations, provided the
# given x for i32 is a multiple of the loop step
def topper{T, U, k, x} = {
def make_top{S} = to_el{S,U}**(if (T<i32) 0 else cast_i{S, x>>width{S}})
top := each{make_top, replicate{{S}=>S<T, tup{i8,i16}}}
def i_off = if (T<i32) 0 else { assert{x%k==0}; cast_i{u64, x/k} }
# Increment top vector when i*k passes width of bottom vector
def vb = lb{k}
def inc{i, {}} = {}
def inc{i, {t:V, ...ts}} = {
if ((i+1+i_off)%(1<<(elwidth{V}-vb)) == 0) { t += V**1; inc{i,ts} }
}
tup{top, inc}
}
# i8 & i16 /w; 64 bits/iter; SWAR
itab_4_16:*u64 = maketab{4,16} # 16 elts, 128B
def thresh{c==0, T==i8 } = 32
def thresh{c==0, T==i16} = 16
fn slash{c==0, T & T<=i16}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def tw = width{T}
def n = 64/tw
def tab = if (tw==8) itab else itab_4_16
j:u64 = 0
def inc = base{1<<tw, n**n}
@for_special_buffered{r,8} (w in *u8~~w over sum) {
def rn = if (has_popc) r+popc{w} else 0
def step{w} = {
i := load{tab, w}
storeu{*u64~~r, 0, j + i}
r += popc_alt{w, i, tw}
j += inc
}
if (tw==8) { step{w} }
else { step{w&0xf}; step{w>>4} }
if (has_popc) r = rn # Shorter dependency chain
}
}
# i16 /w & i32 x+/w; 8 elts/iter; 64 bit table input, expanded to 128 or 256 via topper
def thresh{c==0, T==i16 & hasarch{'X86_64'}} = 32
def thresh{c==0, T==i32 & hasarch{'X86_64'}} = 16
fn slash{c==0, T & hasarch{'X86_64'} & i16<=T & T<=i32}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def I = [16]i8
j := I**(if (T==i16) 0 else cast_i{i8,x})
def {top, inctop} = topper{T, I, 8, x}
@for_special_buffered{r,8} (w in *u8~~w over i to sum) {
ind := load{itab, w}
pc := popc_alt{w, ind, 8}
v := unpackLo{j + I~~make{[2]u64, ind, 0}, tupsel{0,top}}
def st{k, v:V} = store{*V~~r, k, v}
if (T==i16) st{0, v}
else each{st, iota{2}, unpack{v, tupsel{1,top}}}
r += pc
j += I**8
inctop{i, top}
}
}
# i8 & i16 w/x; 128 bits/iter; [16]i8 shuffle
def thresh{c==1, T==i8 & hasarch{'SSSE3'}} = 64
def thresh{c==1, T==i16 & hasarch{'SSSE3'}} = 32
fn slash{c==1, T & T<=i16 & hasarch{'SSSE3'}}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def V = [16]i8
@for_special_buffered{r,8} (w in *u8~~wp over i to sum) {
ind := load{itab, w}
pc := popc_alt{w, ind, 8}
s := V~~make{[2]u64, ind,0}
if (T==i16) { s+=s; s = V~~unpackLo{s, s+V**1} }
res := sel{V, load{*V~~(x+8*i)}, s}
if (T==i8) store{*u64~~r, 0, extract{[2]u64~~res, 0}}
else store{*V~~r, 0, res}
r+= pc r+= pc
} }
} }
def tab{n,l} = { # i32 w/x; 8 elts/iter into 2 steps; [16]i8 shuffle
def m=n-1; def t=tab{m,l} i32tab:*u32 = maketab{4,8,2} # 16 elts, 64B
def k = (1<<l - 1) << (m*l) def thresh{c==1, T==i32 & hasarch{'SSSE3'}} = 8
merge{t, k+t} fn slash{c==1, T==i32 & hasarch{'SSSE3'}}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
} def V = [16]i8
def tab{n==0,l} = tup{0} expander := make{V, iota{16}>>2}
c16lut:*u64 = tab{4,16} trail := make{V, tail{2,iota{16}}}
def step{w,i} = {
def vgLoad{p:T, i & T == *u64} = emit{eltype{T}, 'vg_loadLUT64', p, i} ind := load{i32tab, w}
pc := popc_alt{w, ind, 6}
def comp16{w:*u64, X, r:*i16, l:u64} = { s := sel{[16]i8, V~~make{[4]u32, ind, ... 3**0}, expander} | trail
@for(w in *u8~~w over i to cdiv{l,8}) { res := sel{V, load{*V~~(x+4*i)}, s}
def step{w} = { store{*V~~r, 0, res}
pc:= popcRand{w} r+= pc
storeu{*u64~~r, 0, pext{promote{u64,X{}}, vgLoad{c16lut, w}}} }
r+= pc @for_special_buffered{r,8} (w in *u8~~wp over i to sum) {
} def rn = if (has_popc) r+popc{w} else 0
step{w&15} step{w&0xf, 2*i}
step{w>>4} # this runs even if the above step was all that's required, so it'll act on the invalid result of "r+= pc", so we need to overallocate even more to compensate step{w>>4, 2*i+1}
if (has_popc) r = rn
} }
} }
fn slash2{F, T}(w:*u64, x:*T, r:*T, l:u64) : void = { # i32 & i64 w/x & x+/w; 256 bits/step, 8 elts/iter; [8]i32 shuffle
xv:= *u64~~x i64tab:*u64 = maketab{4,16,1,{x}=>(1+x)*0x100 + x} # 16 elts, 128B
F{w, {} => {c:= loadu{xv}; xv+= 1; c}, r, l} def thresh{c, T==i32 & hasarch{'AVX2'}} = 32
def thresh{c, T==i64 & hasarch{'AVX2'}} = 8
fn slash{c, T & hasarch{'AVX2'} & T>=i32}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def tw = width{T}
def V = [8]u32
expander := make{[32]u8, merge{...each{{i}=>tup{i, ... 3**128}, iota{8}>>lb{tw/32}}}}
def from_ind = if (c) {
i:u64 = 0
{j} => { v:=load{*V~~x, i}; ++i; sel{V, v, j} }
} else if (T==i32) {
def VT = [8]T
i := VT**x
ii := VT**8
{j} => { v:=i+VT~~j; i+=ii; v }
}
def tab = if (tw==32) itab else i64tab
def step{r, w} = {
s:= loadBatch{*u8~~(tab+w), 0, V}
store{*V~~r, 0, from_ind{s}}
}
@for_special_buffered{r,8} (w in *u8~~wp to sum) {
pc := popc{w}
if (tw==32) {
step{r,w}
} else {
h := w&0xf
step{r, h}
step{r+popcRand{h}, w>>4}
}
r += pc
}
} }
fn slash1{F, T, iota, add}(w:*u64, r:*T, l:u64) : void = { # everything; 512 bits/iter; AVX-512 compress
x:u64 = iota def thresh{c, T==i8 & hasarch{'AVX512VBMI2'}} = 256
F{w, {} => {c:= x; x+= add; c}, r, l} def thresh{c, T==i16 & hasarch{'AVX512VBMI2'}} = 128
def thresh{c, T==i32 & hasarch{'AVX512F'}} = 64
def thresh{c, T==i64 & hasarch{'AVX512F'}} = 16
fn slash{c, T & hasarch{if (width{T}>=32) 'AVX512F' else 'AVX512VBMI2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def f = fmtnat
def wt = width{T}
def vl = 512/wt
def V = [vl]T
def X = getter{c, V, x}
def wu = max{32,vl}
@for (w in *(ty_u{vl})~~w over cdiv{l,vl}) {
def I = ty_u{wu}
def emitT{O, name, ...a} = emit{O, merge{'_mm512_',name,'_epi',f{wt}}, ...a}
def to_mask{a} = emit{[vl]u1, merge{'_cvtu',f{wu},'_mask',f{vl}}, a}
m := to_mask{promote{I,w}}
c := popc{w}
x := X{}
# The compress-store instruction performs very poorly on Zen4,
# and is also a lot worse than the following on Tiger Lake
# emitT{void, 'mask_compressstoreu', r, m, x}
cs := cast_i{I,promote{i64,1}<<(c%64) - 1}
if (wu==64) cs -= cast_i{I,c}>>6
v := emitT{V, 'mask_compress', x, m, x}
emitT{void, 'mask_storeu', r, to_mask{cs}, v}
r += c
}
} }
# 8-bit writes ~8 bytes of garbage past end, 16-bit writes ~16 bytes export{'si_1slash8' , slash{0, i8 }}
export{'bmipopc_2slash8', slash2{comp8, i8}} export{'si_1slash16', slash{0, i16}}; export{'si_thresh_1slash16', u64~~thresh{0, i16}}
export{'bmipopc_2slash16', slash2{comp16, i16}} export{'si_1slash32', slash{0, i32}}; export{'si_thresh_1slash32', u64~~thresh{0, i32}}
export{'bmipopc_1slash8', slash1{comp8, i8, 0x0706050403020100, 0x0808080808080808}} export{'si_2slash8' , slash{1, i8 }}; export{'si_thresh_2slash8' , u64~~thresh{1, i8 }}
export{'bmipopc_1slash16', slash1{comp16, i16, 0x0003000200010000, 0x0004000400040004}} export{'si_2slash16', slash{1, i16}}; export{'si_thresh_2slash16', u64~~thresh{1, i16}}
export{'si_2slash32', slash{1, i32}}; export{'si_thresh_2slash32', u64~~thresh{1, i32}}
export{'si_2slash64', slash{1, i64}}; export{'si_thresh_2slash64', u64~~thresh{1, i64}}
# pext, or boolean compress
fn pext{T}(x:T, m:T) {
def w = width{T}
def mod{a} = a % (1<<w)
def lowbits{k} = base{1<<k, cdiv{w,k}**1}
# At each step, x and z are split into groups of length k
# - z tells how many bits in the group are NOT used
# - x contains the bits, with z zeros above
def build{k==1} = tup{x&m, ~m}
def build{k & k > 1} = {
def h = k>>1 # Increase size from h to k
{x,z} := build{h}
def low = lowbits{k} # Low bit in each new group
if (k <= 3) {
z0 := z & low
zm := z>>1 & low
if (k == 2) tup{
x - (x>>1 & z0),
z0 + zm
} else tup{ # Faster 1->3 jump, currently unused
x - ((x>>1&mod{low*3}) & (z|z0<<1)) - (x>>2 & (z & zm)),
(z0 + zm) + (z>>2 & low)
}
} else {
# Shift high x group down by low z, then add halves of z
even:T = mod{low*(1<<h - 1)}
# SWAR shifter: shift x by sh*o, in length-k groups
def shift{sh, o, x} = {
l := o & low; m := l<<k - l
s := (x & m)>>sh | (x &~ m)
if (2*sh<=k/2) shift{2*sh, o>>1, s} else s
}
tup{
(x&even) | shift{1, z, x&~even},
if (k>4) (z + z>>h)&even else ((z&~even)>>h) + (z&even)
}
}
}
# Finally, compose groups with regular shifts
def g = 8 # 12 performs about the same
{b,z} := build{g}
o := z*lowbits{g} # Offsets by prefix sum
def s = 1<<g - 1
def gr{sh} = (b & mod{s<<sh}) >> (o>>(sh-g) & s)
fold{|, b&s, each{gr, g*slice{iota{cdiv{w,g}},1}}}
}
fn pext{T & hasarch{'PCLMUL'} & T==u64}(xs:T, ms:T) {
def num = lb{width{T}}
def vec{s} = make{[2]T, s, 0}
m := vec{ms}
x := vec{xs} & m
d := ~m << 1 # One bit of the position difference at x
c := vec{1<<64-1}
@unroll (i to num) {
def sh = 1 << i
def shift_at{v, s} = { v = (v&~s) | (v&s)>>sh }
p := clmul{d, c, 0} # xor-scan
d = d &~ p # Remove even bits
p &= m
shift_at{m, p}
shift_at{x, p}
}
extract{x, 0}
}
fn pext{T & hasarch{'BMI2'}}(x:T, m:T) = pext{x, m}
export{'si_pext_u64', pext{u64}}

View File

@ -47,8 +47,8 @@ def topBlend{f:T, t:T, m:M & w128{T} & w128i{M,32}} = T ~~ emit{[4]f32, '_mm_ble
def topBlend{f:T, t:T, m:M & w128{T} & w128i{M,64}} = T ~~ emit{[2]f64, '_mm_blendv_pd', v2d{f}, v2d{t}, v2d{m}} def topBlend{f:T, t:T, m:M & w128{T} & w128i{M,64}} = T ~~ emit{[2]f64, '_mm_blendv_pd', v2d{f}, v2d{t}, v2d{m}}
def topBlend{f:T, t:T, m:M & w128{T} & w128i{M, 8}} = T ~~ emit{[16]i8, '_mm_blendv_epi8', v2i{f}, v2i{t}, v2i{m}} def topBlend{f:T, t:T, m:M & w128{T} & w128i{M, 8}} = T ~~ emit{[16]i8, '_mm_blendv_epi8', v2i{f}, v2i{t}, v2i{m}}
# assumes all bits are the same in each mask item # assumes all bits are the same in each mask item
def homBlend{f:T, t:T, m:M & w128{T} & w128{M} & elwidth{M}!=16} = topBlend{f, t, m} def homBlend{f:T, t:T, m:M & hasarch{'SSE4.1'} & w128{T} & w128{M} & elwidth{M}!=16} = topBlend{f, t, m}
def homBlend{f:T, t:T, m:M & w128{T} & w128{M,16}} = topBlend{f, t, [16]i8~~m} def homBlend{f:T, t:T, m:M & hasarch{'SSE4.1'} & w128{T} & w128{M,16}} = topBlend{f, t, [16]i8~~m}