aarch64 bool128 select

This commit is contained in:
dzaima 2024-07-18 21:44:34 +03:00
parent 0d5e77766a
commit 2ca488dd66
3 changed files with 50 additions and 33 deletions

View File

@ -41,7 +41,7 @@
#include "../utils/mut.h"
#include "../utils/calls.h"
#if SINGELI_AVX2
#if SINGELI
#define SINGELI_FILE select
#include "../utils/includeSingeli.h"
#endif
@ -185,12 +185,7 @@ B select_c2(B t, B w, B x) {
#if SINGELI_AVX2
#define CPUSEL(W, NEXT) /*assumes 3≤xl≤6*/ \
if (RARE(!avx2_select_tab[4*(we-el_i8)+xl-3](wp, xp, rp, wia, xn))) select_properError(w, x);
bool bool_use_simd = we==el_i8 && xl==0 && xia<=128;
#define BOOL_SPECIAL(W) \
if (sizeof(W)==1 && bool_use_simd) { \
if (RARE(!avx2_select_bool128(wp, xp, rp, wia, xn))) select_properError(w, x); \
goto setsh; \
}
#else
#define CASE(S, E) case S: for (usz i=i0; i<i1; i++) ((E*)rp)[i] = ((E*)xp+off)[ip[i]]; break
#define CASEW(S, E) case S: for (usz i=0; i<wia; i++) ((E*)rp)[i] = ((E*)xp)[WRAP(wp[i], xn, thrF("⊏: Indexing out-of-bounds (%i∊𝕨, %s≡≠𝕩)", wp[i], xn))]; break
@ -214,6 +209,17 @@ B select_c2(B t, B w, B x) {
} \
if (wt) TFREE(wt); \
}
#endif
#if SINGELI_SIMD
bool bool_use_simd = we==el_i8 && xl==0 && xia<=128;
#define BOOL_SPECIAL(W) \
if (sizeof(W)==1 && bool_use_simd) { \
if (RARE(!simd_select_bool128(wp, xp, rp, wia, xn))) select_properError(w, x); \
goto setsh; \
}
#else
bool bool_use_simd = 0;
#define BOOL_SPECIAL(W)
#endif

View File

@ -164,7 +164,7 @@ def {
absu,andAllZero,andnz,b_getBatch,blend,clmul,cvt,extract,fold_addw,half,
homAll,homAny,bitAll,bitAny,homBlend,homMask,homMaskStore,homMaskStoreF,loadBatchBit,
loadLow,make,maskStore,maskToHom,mulw,mulh,narrow,narrowTrunc,narrowPair,packQ,pair,pdep,
pext,popcRand,sel,shl,shr,shuf,shuf16Hi,shuf16Lo,shufHalves,shufInd,storeLow,
pext,popcRand,rbit,sel,shl,shr,shuf,shuf16Hi,shuf16Lo,shufHalves,shufInd,storeLow,
topBlend,topMask,topMaskStore,topMaskStoreF,unord,unzip,vfold,widen,widenUpper,
}

View File

@ -17,14 +17,16 @@ if_inline (hasarch{'AVX2'}) {
if (M{0}) T ~~ emit{[4]i64, '_mm256_mask_i32gather_epi64', d, *void~~b, idx, M{T,'to sign bits'}, elwidth{B}/8}
else T ~~ emit{[4]i64, '_mm256_i32gather_epi64', *void~~b, idx, elwidth{B}/8}
}
def wrapChk{cw0, VI,xlf, M} = {
cw:= cw0 + (xlf & VI~~(cw0<VI**0))
if (homAny{M{ty_u{cw} >= ty_u{xlf}}}) return{0}
cw
}
}
def wrapChk{cw0, VI,xlf, M} = {
cw:= cw0 + (xlf & VI~~(cw0<VI**0))
if (homAny{M{ty_u{cw} >= ty_u{xlf}}}) return{0}
cw
}
def wrapChk{cw0:VI, xlf, M} = wrapChk{cw0, VI,xlf, M}
if_inline (hasarch{'AVX2'}) {
@ -141,32 +143,41 @@ exportT{'avx2_select_tab', join{table{select_fn,
}
if_inline(hasarch{'AVX2'}) {
fn avx2_select_bool128(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = {
if_inline(hasarch{'AVX2'} or hasarch{'AARCH64'}) {
fn simd_select_bool128(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = {
def TI = i8
def VI = [32]TI
w:= *VI ~~ w0
r:= *u32 ~~ r0
def VI = [arch_defvw/8]TI
def VU = ty_u{VI}
w:= *VI ~~ w0
xlf:= VI**cast_i{TI, xl}
if (wl>32 and xl<=16) {
xb:= shuf{[4]u64, spreadBits{[32]u8, load{*u32~~x0}}, 4b1010}
@maskedLoop{32}(cw0 in w, sr in r, M in 'm' over wl) {
cw:= wrapChk{cw0, VI,xlf, M}
sr = homMask{sel{[16]i8, xb, cw}}
if (hasarch{'AARCH64'}) {
def xrev = rbit{load{*VU ~~ x0}}
@maskedLoop{16}(cw0 in w, r in *u16~~r0, M in 'm' over i to wl) {
def cw = ty_u{wrapChk{cw0, xlf, M}}
def byte = sel{[16]u8, xrev, cw>>3}
r = homMask{ty_s{byte << (cw & VU**7)} < VI**0}
}
} else {
x:= shuf{[4]u64, load{*VI ~~ x0}, 4b1010}
low:= VI**7
b := VI~~make{[32]u8, 1 << (iota{32} & 7)}
@maskedLoop{32}(cw0 in w, sr in r, M in 'm' over wl) {
cw:= wrapChk{cw0, VI,xlf, M}
byte:= sel{[16]i8, x, VI~~(([8]u32~~(cw&~low))>>3)}
mask:= sel{[16]i8, b, cw & low}
sr = homMask{(mask & byte) == mask}
if (wl>32 and xl<=16) {
xb:= shuf{[4]u64, spreadBits{[32]u8, load{*u32~~x0}}, 4b1010}
@maskedLoop{32}(cw0 in w, sr in *u32~~r0, M in 'm' over wl) {
cw:= wrapChk{cw0, xlf, M}
sr = homMask{sel{[16]i8, xb, cw}}
}
} else {
x:= shuf{[4]u64, load{*VI ~~ x0}, 4b1010}
low:= VI**7
b := VI~~make{[32]u8, 1 << (iota{32} & 7)}
@maskedLoop{32}(cw0 in w, sr in *u32~~r0, M in 'm' over wl) {
cw:= wrapChk{cw0, xlf, M}
byte:= sel{[16]i8, x, VI~~(([8]u32~~(cw&~low))>>3)}
mask:= sel{[16]i8, b, cw & low}
sr = homMask{(mask & byte) == mask}
}
}
}
1
}
export{'avx2_select_bool128', avx2_select_bool128}
export{'simd_select_bool128', simd_select_bool128}
}