aarch64 8→n bitnarrow

This commit is contained in:
dzaima 2024-08-14 02:21:13 +03:00
parent c72ed51149
commit eccbac37ab
3 changed files with 107 additions and 14 deletions

View File

@ -129,22 +129,79 @@ def bitalign{{2,8,s}, 8, G} = {
def maketabs{k, is, i, ...ts} = {
def makevtabs{k, is, ...ts} = {
tab:*u8 = join{each{{s} => join{
each{{{E, t}} => t{s, range{k}} & 0xff, ts}
each{{{E, t}} => {
if (kgen{t}) t{s, range{k}} & 0xff
else t
}, ts}
}, is}}
def ctab = length{ts}*i + *[k]u8~~tab
each{{j, {E,_}} => re_el{E,load{ctab,j}}, inds{ts}, ts}
{i} => {
def ctab = length{ts}*i + *[k]u8~~tab
each{{j, {E,_}} => re_el{E,load{ctab,j}}, inds{ts}, ts}
}
}
def __shl{a:([16]u8), sh:([16]i8) if hasarch{'AARCH64'}} = a << [16]u8~~sh
def bitalign{{2,8,s}, 8, G if hasarch{'AARCH64'}} = G{s, {a:V=([16]u8)} => {
def {shuf1, shift1, shuf2, shift2} = maketabs{16, xrange{2,8}, s-2,
tup{u8, {s, r} => r *s>>3}, tup{i8, {s, r} => - r *s%8},
tup{u8, {s, r} => (r+1)*s>>3}, tup{i8, {s, r} => s - (r+1)*s%8},
def {shuf1, shift1, shuf2, shift2} = makevtabs{16, xrange{2,8},
tup{u8, {s,r} => r *s>>3}, tup{i8, {s,r} => - r *s%8},
tup{u8, {s,r} => (r+1)*s>>3}, tup{i8, {s,r} => s - (r+1)*s%8},
}{s-2}
def r0 = sel{[16]u8, a, shuf1} << shift1
def r1 = sel{[16]u8, a, shuf2} << shift2
(r0 | r1) & V**cast_i{u8, tail{s}}
}}
oper // ({a,b}=>floor{a/b}) infix left 40
def bitalign{8, {2,8,d}, G if hasarch{'AARCH64'}} = {
def props = memoize{{d} => {
def indz = range{16*d/8}
def shuf0 = (indz*8 ) // d
def shufE = (indz*8+7) // d
def count = shufE - shuf0 + 1
assert{all{count <= 4}}
def shift0 = shuf0*d - indz*8
def much = count>2
def shuf1 = replicate{much, shuf0 + 2}
def shift1 = replicate{much, shift0 + 2*d}
assert{length{shuf0}+length{shuf1} <= 16}
tup{
shiftright{merge{shuf0, shuf1}, 16**0},
shiftright{merge{shift0, shift1}, 16**0},
shiftright{(length{indz} + scan{+,much}) * much - 1, 16 ** -1},
}
}}
def irange = xrange{2,8}
def {shuf0, shift0} = makevtabs{16, irange,
tup{u8, {d,r} => select{props{d},0}},
tup{i8, {d,r} => select{props{d},1}},
}{d-2}
def needs_blender = each{{c} => not all{-1==select{props{c},2}}, irange}
def reverse_scan{G, v} = reverse{scan{{a,b}=>G{b,a}, reverse{v}}}
def blender = makevtabs{16, replicate{reverse_scan{|, needs_blender}, irange},
tup{i8, {d,r} => select{props{d},2}},
}
def b = sel{[16]u8, a, shuf1} << shift1
def c = sel{[16]u8, a, shuf2} << shift2
(b | c) & V**cast_i{u8, tail{s}}
}}
def run{do_blend}{a:V=([16]u8)} = {
def shuf1 = shuf0 + V**1
def shift1 = shift0 + [16]i8**cast_i{i8,d}
def b = a & V**cast_i{u8, tail{d}}
def r0 = sel{[16]u8, b, shuf0} << shift0
def r1 = sel{[16]u8, b, shuf1} << shift1
def r01 = r0 | r1
if (do_blend) r01 | sel{[16]u8, r01, ...blender{d-2}}
else r01
}
def bit_lut{bits, idx} = ((u64~~base{2,bits} >> idx) & 1) != 0
if (bit_lut{merge{2**0, needs_blender}, d}) G{d, run{1}}
else G{d, run{0}}
}

View File

@ -1,3 +1,4 @@
include 'debug/printf'
include './base'
include './cbqnDefs'
include './f64'
@ -29,7 +30,6 @@ fn bitwiden_n_8(src:*void, dst:*void, csz:ux, cam:ux) : void = {
assert{(csz>1) & (csz<8)}
def bulk = arch_defvw / 8
def V = [bulk]u8
def rbytes = cdiv{csz*cam, 8}
bitalign{tup{2,8,csz}, 8, {s, align} => {
@maskedLoop{bulk}(dst in tup{V,*u8~~dst} over cam) {
dst = align{load{*V~~src}}
@ -37,4 +37,31 @@ fn bitwiden_n_8(src:*void, dst:*void, csz:ux, cam:ux) : void = {
}
}}
}
export{'si_bitwiden_n_8', bitwiden_n_8}
export{'si_bitwiden_n_8', bitwiden_n_8}
(if (hasarch{'AARCH64'}) {
fn bitnarrow_8_n(src:*void, dst:*void, csz:ux, cam:ux) : void = {
assert{cam>0}
assert{(csz>1) & (csz<8)}
def bulk = arch_defvw / 8
def V = [bulk]u8
dstC:= *u8~~dst
dstE:= *u8~~dst + cdiv{csz*cam, 8}
bitalign{8, tup{2,8,csz}, {s, align} => {
def get{} = align{load{*V~~src}}
def next{} = {
padd{u8, src, bulk}
padd{u8, dstC, bulk*s/8}
}
while (dstC+bulk < dstE) {
store{*V~~dstC, 0, get{}}
next{}
}
while (dstC < dstE) {
storeBatch{dstC, 0, get{}, maskAfter{dstE - dstC}}
next{}
}
}}
}
export{'si_bitnarrow_8_n', bitnarrow_8_n}
})

View File

@ -225,6 +225,7 @@ B narrowWidenedBitArr(B x, ur axis, ur cr, usz* csh) { // for now assumes the bi
usz* rsh = arr_shAlloc(r, axis+cr);
shcpy(rsh, SH(x), axis);
shcpy(rsh+axis, csh, cr);
if (PIA(r)==0) goto decG_ret;
u8* xp = tyany_ptr(x);
// FILL_TO(rp, el_bit, 0, m_f64(1), PIA(r));
@ -251,7 +252,14 @@ B narrowWidenedBitArr(B x, ur axis, ur cr, usz* csh) { // for now assumes the bi
else for (ux i=0; i<cam; i++) ab_add(&ab, ((u64*)xp)[i], ocsz);
#else
switch(xcsz) { default: UD;
case 8: for (ux i=0; i<cam; i++) ab_add(&ab, ((u8* )xp)[i], ocsz); break; // all assume zero padding
case 8:
#if SINGELI_NEON
if (xcsz==8 && ocsz!=1) {
si_bitnarrow_8_n(xp, rp, ocsz, cam);
goto decG_ret;
}
#endif
for (ux i=0; i<cam; i++) ab_add(&ab, ((u8* )xp)[i], ocsz); break; // all assume zero padding
case 16: for (ux i=0; i<cam; i++) ab_add(&ab, ((u16*)xp)[i], ocsz); break;
case 32: for (ux i=0; i<cam; i++) ab_add(&ab, ((u32*)xp)[i], ocsz); break;
case 64: for (ux i=0; i<cam; i++) ab_add(&ab, ((u64*)xp)[i], ocsz); break;
@ -269,6 +277,7 @@ B narrowWidenedBitArr(B x, ur axis, ur cr, usz* csh) { // for now assumes the bi
}
}
ab_done(ab);
decG_ret:;
decG(x);
return taga(r);
}