fast inds⊸⊏˘bits for ≤8-bit input & output cells

This commit is contained in:
dzaima 2024-08-08 23:28:43 +03:00
parent 920a89f019
commit 5748833060
6 changed files with 172 additions and 13 deletions

View File

@ -711,7 +711,11 @@ NOINLINE B for_cells_SA(B f, B w, B x, ur xcr, ur xr, u32 chr) { // w⊸F⎉xcr
case n_rtack: dec(w); return x;
case n_ltack: return const_cells(x, xk, xsh, w, chr);
case n_select:
if (isArr(w) && RNK(w)==1 && xcr==1 && TI(w,arrD1)) { // TODO handle RNK(w)!=1
if (isArr(w) && RNK(w)==1 && xcr==1) { // TODO handle RNK(w)!=1
if (!TI(w,arrD1)) {
w = num_squeezeChk(w);
if (!TI(w,arrD1)) break;
}
assert(xr > 1);
ux wia = IA(w);
ShArr* rsh = m_shArr(xr);

View File

@ -559,16 +559,19 @@ B select_replace(u32 chr, B w, B x, B rep, usz wia, usz cam, usz csz) { // consu
#undef FREE_CHECK
}
static void* m_arrv_same(B* r, usz ia, B src) { // makes a new array with same element type as src, but new ia
u8 se = TI(src,elType); assert(se!=el_bit);
static void* m_arrv_same_t(B* r, usz ia, u8 ty) {
u8 se = TIi(ty,elType);
if (se==el_B) {
HArr_p p = m_harr0v(ia);
*r = p.b;
return p.a;
} else {
return m_tyarrlv(r, arrTypeWidthLog(TY(src)), ia, arrNewType(TY(src)));
return m_tyarrlbv(r, arrTypeBitsLog(ty), ia, arrNewType(ty));
}
}
static void* m_arrv_same(B* r, usz ia, B src) { // makes a new array with same element type as src, but new ia
return m_arrv_same_t(r, ia, TY(src));
}
B slash_c2(B, B, B);
Arr* customizeShape(B x); // from cells.c
@ -585,7 +588,7 @@ B select_cells_base(B inds, B x0, ux csz, ux cam);
#endif
#define INDS_BUF_MAX 64 // only need 32 bytes for AVX2 & 16 for NEON, but have more for past-the-end pointers and writes
B select_rows_direct(B x, ux csz, ux cam, void* inds, ux indn, u8 ie) { // ⥊ (indn↑inds As ie)⊸⊏˘ cam‿csz⥊x
B select_rows_direct(B x, ux csz, ux cam, void* inds, ux indn, u8 ie) { // ⥊ (indn↑inds As ie)⊸⊏˘ cam‿csz⥊x; if inds are valid and csz<=128, ie must be <=el_i8
assert(csz!=0 && cam!=0 && indn!=0);
assert(csz*cam == IA(x));
assert(ie<=el_i32);
@ -597,30 +600,38 @@ B select_rows_direct(B x, ux csz, ux cam, void* inds, ux indn, u8 ie) { // ⥊ (
if (!getRange_fns[ie](inds, bounds, indn) || bounds[0]<-1 || bounds[1]>0) goto generic_any;
return C2(slash, m_f64(indn), taga(arr_shVec(customizeShape(x))));
}
assert(csz>=2);
ux ria = indn * cam;
B r;
u8* xp;
u8 xe = TI(x,elType);
u8 lb = arrTypeWidthLog(TY(x));
u8* xp;
if (xe==el_B) {
if (sizeof(B) != 8) goto generic_any;
xp = (u8*) arr_bptr(x);
if (xp == NULL) goto generic_any;
} else {
if (xe == el_bit) goto generic_any;
xp = tyany_ptr(x);
if (xe == el_bit) {
#if SINGELI_AVX2 || SINGELI_NEON
if (indn<=8 && csz<=8) goto bit_ok;
#endif
goto generic_any;
goto bit_ok; bit_ok:;
}
}
B r;
ux ria = indn * cam;
bool fast; (void) fast;
ux xbump = csz<<lb;
ux rbump = indn<<lb;
i64 bounds[2];
if (ie==el_bit) {
// TODO path for xe==el_bit + long indn
if (csz>32 || indn>32 || indn>INDS_BUF_MAX) { // TODO properly tune
assert(xe!=el_bit && (csz>8 || indn>8));
u8* rp = m_arrv_same(&r, ria, x);
for (ux i = 0; i < cam; i++) {
bitselFns[lb](rp, inds, loadu_u64(xp), loadu_u64(xp + (1<<lb)), indn);
@ -663,6 +674,7 @@ B select_rows_direct(B x, ux csz, ux cam, void* inds, ux indn, u8 ie) { // ⥊ (
}
}
skip_bounds_check:;
assert(ie==el_i8 || csz>128);
#if SINGELI_AVX2 || SINGELI_NEON
if (fast) {
@ -680,11 +692,68 @@ B select_rows_direct(B x, ux csz, ux cam, void* inds, ux indn, u8 ie) { // ⥊ (
}
#endif
#if SINGELI_AVX2 || SINGELI_NEON
if (xe==el_bit) {
assert(ie==el_i8 && csz<=8 && indn<=8 && csz>=2 && indn>=1);
// TODO si_select_cells_bit_lt64 for indn==1
static const u8 rep_lut[9] = {0,3,2,1,1,0,0,0,0};
u8 exp = rep_lut[csz>indn? csz : indn];
ux rindn = indn<<exp;
ux rcsz = csz<<exp;
assert(rcsz<=8 && rindn<=8);
ux rcam = (cam + (1<<exp)-1)>>exp;
if (rcsz!=8) {
Arr* xa = customizeShape(x);
usz* xsh = arr_shAlloc(xa, 2);
xsh[0] = rcam;
xsh[1] = rcsz;
// leave ia unchanged, desynchronizing from product of shape; TODO really really shouldn't do that, and instead pass cam & csz directly to bit widener
x = widenBitArr(taga(xa), 1);
xp = tyany_ptr(x);
SELECT_ROWS_PRINTF("8bit: widen %zu‿%zu → ⟨%zu,%zu→8⟩\n", cam, csz, rcam, rcsz);
}
if (exp!=0) {
simd_repeat_inds(inds, inds_buf, indn, csz);
inds = inds_buf;
}
u64* rp;
ux ria0 = rindn!=8? 8*rcam : ria;
r = m_bitarrv(&rp, ria0);
SELECT_ROWS_PRINTF("8bit: indn=%zu rindn=%zu csz=%zu rcsz=%zu cam=%zu ria0=%zu rcam=%zu\n", indn, rindn, csz, rcsz, cam, ria0, rcam);
si_select_rows_8bit(inds, rindn, xp, rp, (ria0+7)/8);
if (rindn!=8) {
SELECT_ROWS_PRINTF("8bit: narrow %zu → %zu\n", rcsz, csz);
usz* rsh = arr_shAlloc(a(r), 2);
rsh[0] = rcam;
rsh[1] = 8;
usz tgt = rindn;
r = narrowWidenedBitArr(r, 1, 1, &tgt); // TODO this assumes trailing zeroes
Arr* ra = arr_shVec(customizeShape(r));
r = taga(ra);
ux ria1 = ra->ia;
assert(ria <= ria1);
FINISH_OVERALLOC(ra, offsetof(TyArr,a) + (ria+7)/8, offsetof(TyArr,a) + (ria1+7)/8);
ra->ia = ria;
}
goto decG_ret;
}
#endif
u8* rp = m_arrv_same(&r, ria, x);
ux slow_cam = cam;
#if SINGELI_AVX2 || SINGELI_NEON
ux lnt = CLZC(csz-1); // ceil-log2 of number of elements in table
if (fast && lnt < select_rows_tab_h) {
u8 max_indn = select_rows_max_indn[lb];
if (indn > max_indn) goto no_fast;
@ -788,7 +857,7 @@ B select_rows_B(B x, ux csz, ux cam, B inds) { // consumes inds,x; ⥊ inds⊸
ux in = IA(inds);
if (in == 0) return taga(emptyArr(x, 1));
u8 ie = TI(inds,elType);
if (csz<=2? ie!=el_bit : csz<128? ie>el_i8 : !elInt(ie)) {
if (csz<=2? ie!=el_bit : csz<=128? ie>el_i8 : !elInt(ie)) {
inds = num_squeeze(inds);
ie = TI(inds,elType);
if (!elInt(ie)) goto generic;

View File

@ -1,7 +1,10 @@
def __shl{(u16)}{a:T, b} = T~~(re_el{u16,a}<<b) # for x86's lack of u8 shift
def __shr{(u16)}{a:T, b} = T~~(re_el{u16,a}>>b)
def broadcast{[(n*2)]E, x:[n]E} = pair{x, x}
def pow2_up{v, least} = max{least, 1<<ceil_log2{v}} # least ⌈ ⌈⌾(2⊸⋆⁼) v
def has_sel = hasarch{'AVX2'} or hasarch{'AARCH64'}
# make a LUT of at least nt elements in tab, to be indexed by [ni_real≥ni]u8
# E must be unsigned
# mode is a hint on expected usage:
@ -173,6 +176,7 @@ def lut_gen{mode, E==u8, nt, ni if hasarch{'AARCH64'} and nt<=16*4 and ni<=16} =
def lut = each{TG{[16]u8, .}, range{vn}}
{is:([16]u8)} => tup{sel{lut, is}}
}}}}
# TODO E==u8 nt==128 via tbx4(tbl4,d,i-64)
def lut_gen{mode, E, nt, ni if mode=='i' and hasarch{'AVX2'} and (E==u16 or E==u64)} = zip_halves{mode, E, nt, ni}
# def lut_gen{mode, E, nt, ni if mode=='c' and hasarch{'AVX2'} and E==u16} = widen_inds{mode, E, nt, max{ni,16}, 2}
@ -187,3 +191,30 @@ def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and E==u64} = widen_inds{mode,
def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and mode=='c' and E>=u16} = 0
def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and mode=='i' and E==u64 and nt>16} = 0
def lut16{tab:([16]u8), idxs:([16]u8)} = sel{[16]u8, tab, idxs}
def lut16{tab:([16]u8), idxs:([32]u8) if hasarch{'X86_64'}} = sel{[16]u8, pair{tab, tab}, idxs}
def shuf_u8bits{inds:*u8, ni} = 0
def shuf_u8bits{inds:*u8, ni if has_sel} = {
def I = [16]u8
v0:= I**0
v1:= I**0
iv:= I**1
@for (ind in inds over ni) {
c:= (iota{I} & I**(1<<(ind&3))) != I**0
c&= iv
iv = iv+iv
if (ind>=4) v1|= c
else v0|= c
}
{x:V=[_](u8)} => {
def __shr{a:(V), 4 if hasarch{'X86_64'}} = (a >>{u16} 4) & V**0x0f
def lo = lut16{v0, x & V**0x0f}
def hi = lut16{v1, x >> 4}
lo | hi
}
}

View File

@ -9,8 +9,6 @@ def arch_minvw = if (hasarch{'AARCH64'}) 64 else 128
def arch_minv{T=[_]E if width{T}< arch_minvw} = [arch_minvw / width{E}]E
def arch_minv{T if width{T}>=arch_minvw} = T
def has_sel = hasarch{'AVX2'} or hasarch{'AARCH64'}
def gather
if_inline (hasarch{'AVX2'}) {
# def:T - masked original content
@ -306,6 +304,16 @@ exportT{'si_select_tab', join{table{select_fn,
1
}
export{'simd_select_bool128', simd_select_bool128}
fn si_select_rows_8bit(inds:*u8, indn:ux, src:*void, dst:*void, rows:ux) : void = { # leaves zeroes in result cells above indn
def bulk = arch_defvw / 8
def V = [bulk]u8
def lut = shuf_u8bits{inds, indn}
@maskedLoop{bulk}(src in tup{V,*u8~~src}, dst in tup{V,*u8~~dst} over rows) {
dst = lut{src}
}
}
export{'si_select_rows_8bit', si_select_rows_8bit}
})

View File

@ -94,6 +94,14 @@
!"⊏: Indexing out-of-bounds (¯129∊𝕨, 128≡≠𝕩)" % %USE evar 10¯129 {𝕨˘𝕩}_evar 101281
!"⊏: Indexing out-of-bounds (128∊𝕨, 128≡≠𝕩)" % %USE evar 10128 {𝕨˘𝕩}_evar 101281
!"⊏: Indexing out-of-bounds (1∊𝕨, 1≡≠𝕩)" % %USE evar 1001 {𝕨˘𝕩}_evar 1011
!"⊏: Indexing out-of-bounds (1000∊𝕨, 3≡≠𝕩)" % %USE evar (31000) {𝕨˘𝕩}_evar 10031
!"⊏: Indexing out-of-bounds (1000∊𝕨, 4≡≠𝕩)" % %USE evar (41000) {𝕨˘𝕩}_evar 10041
!"⊏: Indexing out-of-bounds (1000∊𝕨, 8≡≠𝕩)" % %USE evar (81000) {𝕨˘𝕩}_evar 10081
12˘ 108100 %% (8×10) + 12
12˘ 1042100 %% (8×10) + [23,45]
12,1,0˘ 1042100 %% (8×10) + [32,54]
[21,45]˘ 108100 %% (8×10) + [21,45]
[11,01]˘ 1022100 %% (4×10) + 22223230123
(
%USE IS_HEAPVERIFY

View File

@ -34,3 +34,42 @@
Test {{@ + 𝕩 •rand.Range 1114111}}
Test {{"foo"@¨ 𝕩 •rand.Range 2}}
)
(
%USE var
F {𝕨˘ 𝕩}
{
𝕩+ 1
is 𝕨 •rand.Range 𝕩
n •rand.Range 200
d n𝕩•rand.Range 2
isn𝕩 ! 1= is F d, is F "Ai8"V d {𝕩=2? ("Ai8"V is) F d; }𝕩
}˜ 10
)
(
%USE var
F {𝕨˘ 𝕩}'e'
{
𝕩+ 1
𝕨+ 1
is 𝕨𝕩
n 1+•rand.Range 200
d n𝕩•rand.Range 2
𝕨𝕩n! ´ 'e'¨ is F d, is F "Ai8"V d {𝕩=2? ("Ai8"V is) F d; }𝕩
}˜ 10
)
(
%USE var
F {𝕨˘ 𝕩}'e'
{𝕊 inds:
csz inds
inds {
csz, 𝕨, 𝕩 ! 1= (V𝕨¨ (´𝕨1), "Ab" "Ai8") F V𝕩¨ "Ab""Ai8""Ai32"
} {𝕩csz •rand.Range 2}¨ (8)01+329
}¨ 22 {𝕩, •rand.Range˜𝕩}¨ 13458916
)