Singeli n→8 bitwiden

This commit is contained in:
dzaima 2024-08-14 02:35:31 +03:00
parent f36cefc9ac
commit c72ed51149
3 changed files with 176 additions and 0 deletions

View File

@ -0,0 +1,150 @@
local include 'util/tup'
def xrange{s, e} = s + range{e-s}
def shuf_imm = shuf
def shuf_impl{rw, ...iw, data} = assert{0, 'shuffling failed', ...data{'info'}}
def type0{vs} = type{select{vs,0}}
# new_shuf{v0:[k]E, v1:[k]E, ..., indices} : [length{indices}]E; -1 for zero, -2 for arbitrary
def new_shuf{...vs0 if isvec{try_same_type{vs0,'!'}}, {...is}} = { # : [length{is}]E
def vs = each{ty_u, vs0}
def vn = length{vs}
def S = type0{vs}
def E = eltype{S}
def ni = length{is}
def data0{'info'} = tup{...vs, is}
assert{all{(is >= -2) & (is < vn*vcount{S})}, 'bad shuffle indices', ...vs, is}
def widen_inds{sc, is} = flat_table{+, sc*is, range{sc}}
def data0{(E)} = tup{...vs, is}
def data0{N if istype{N} and isunsigned{N} and N<E} = {
def sc = width{E} / width{N}
tup{...each{re_el{N,.}, vs}, widen_inds{sc, is}}
}
def data0{W if istype{W} and isunsigned{W} and W>E} = {
def sc = width{W} / width{E}
if (ni%sc == 0) {
def is2 = (select{is, sc*range{ni/sc}} / sc) >> 0
if (same{is, widen_inds{sc, is2}}) tup{...each{re_el{W,.}, vs}, is2}
else '!'
} else '!'
}
def data0{'nz', T} = match(data{T}) {
{r={_,is} if all{is != -1}} => r
{_} => '!'
}
def data0{'h', ...a} = not same{data{...a},'!'}
def data = memoize{data0}
re_el{E, shuf_impl{width{E}*ni, ...vn**width{S}, data}}
}
def new_shuf{w if istype{w} or knum{w}, ...vs if isvec{try_same_type{vs,'!'}} and (not isvec{w} or width{w} == width{type0{vs}}), {...is}} = {
def S = type0{vs}
def D = if (isvec{w}) w else re_el{if (isprim{w}) w else primtype{'u',w}, S}
S~~new_shuf{...each{{c} => reinterpret{D, c}, vs}, is}
}
local def tern{{...c}, t, f} = eachx{{c,t,f} => if (c) t else f, c, t, f}
local def perm_split{x:X, lanes, chunks} = {
def b = 16/chunks
def gr = lanes >> lb{b}
def u = each{{ok,g} => unique{replicate{ok,g}}, lanes>=0, gr}
if (all{{c} => length{c}<=chunks, u}) {
def si0 = join{each{index_of, u, gr}*b + lanes%b}
def si = tern{join{lanes>=0}, si0, 0xff}
# def si = join{each{{u,g,l} => l>=0) index_of{u,g}*b + lanes%b else 0xff, u, gr, lanes}}
sel{[16]u8,
new_shuf{re_el{ty_u{b*8},X}, x, join{each{shiftright{.,chunks**0}, u}}},
make{X, si}
}
} else '!'
}
local def perm_shufb{x, is} = {
def lanes = split{16, is}
def r1 = perm_split{x, lanes, 2}
if (same{r1,'!'}) perm_split{x, lanes, 4}
else r1
}
# TODO accept narrower inputs
def shuf_impl{256, 256, data if hasarch{'AVX2'} and not same{perm_shufb{...data{u8}},'!'}} = perm_shufb{...data{u8}}
def shuf_impl{256, 128, data if hasarch{'AVX2'} and data{'h',u8}} = { def {x,is} = data{u8}; sel{[16]u8, pair{x, x}, make{[32]i8, is}} }
def shuf_impl{256, 256, data if hasarch{'AVX2'} and data{'h','nz',u32}} = { def {x,is} = data{'nz',u32}; sel{[8]u32, x, make{[8]u32, is}} }
def shuf_impl{256, 256, data if hasarch{'AVX2'} and data{'h','nz',u64}} = { def {x,is} = data{'nz',u64}; shuf_imm{[4]u64, x, base{4,is}} }
def in_chunks{c, is} = { def ls = split{c, is}; all{{l, s} => all{(l<0) | ((l>=s) & (l<s+c))}, ls, inds{ls}*c} }
def shuf_impl{256, 256, data if hasarch{'AVX2'} and in_chunks{16,select{data{u8},1}}} = { def {x,is} = data{u8}; sel{[16]u8, x, make{[32]u8, is}} }
def shuf_impl{128, 128, data if hasarch{'SSSE3'}} = { def {x,is} = data{u8}; sel{[16]u8, x, make{[16]i8, is}} }
def shuf_impl{128, 128, data if hasarch{'AARCH64'}} = { def {x,is} = data{u8}; sel{[16]u8, x, make{[16]i8, is}} }
def __shl{(u64)}{a:T, b} = T~~(re_el{u64,a}<<b)
def __shr{(u64)}{a:T, b} = T~~(re_el{u64,a}>>b)
def bitalign{s, 8 if s<7, a:V=[k](u8) if hasarch{'X86_64'}} = {
def V16 = re_el{u16,V}
def s0 = (range{k/2}*2*s) >> 3
def b = new_shuf{a, flat_table{+, s0, tup{0,1}}}
def c = V16~~b * make{V16, 1<<(8*s0 + 16 - s*(2+2*range{k/2}))}
def d = (c >> (8 -s )) & V16**(tail{s}<<8)
def e = (c >> (16-s*2)) & V16**tail{s}
re_el{u8, d|e}
# def d = re_el{u8, c >> (8 -s)}
# def e = re_el{u8, c >> (16-s*2)}
# homBlend{e, d, make{V, cycle{k, tup{0, 0xff}}}} & V**tail{s}
}
def bitalign{7, 8, a:V=[k](u8) if hasarch{'X86_64'}} = {
def V16 = re_el{u16,V}
def b = new_shuf{a, range{k} - (((range{k}+2)/8)>>0)}
def c = blend{[8]u16, b, b <<{u64} 4, 2b01100110}
def d = (V16~~c * make{V16, cycle{k/2, 1<<tup{2,0,2,0}}}) >> 2 # TODO ofence constant for clang
homBlend{V~~d, V~~(d+d), make{V, cycle{k, tup{0, 0xff}}}} & V**0x7f
}
def switchall{selected, options, G} = {
def end = makelabel{}
each{{option} => {
if (selected == option) {
G{option}
goto{end}
}
}, options}
unreachable{}
setlabel{end}
}
def bitalign{{2,8,s}, 8, G} = {
switchall{s, xrange{2,8}, {s} => {
G{s, bitalign{s, 8, .}}
}}
}
def maketabs{k, is, i, ...ts} = {
tab:*u8 = join{each{{s} => join{
each{{{E, t}} => t{s, range{k}} & 0xff, ts}
}, is}}
def ctab = length{ts}*i + *[k]u8~~tab
each{{j, {E,_}} => re_el{E,load{ctab,j}}, inds{ts}, ts}
}
def __shl{a:([16]u8), sh:([16]i8) if hasarch{'AARCH64'}} = a << [16]u8~~sh
def bitalign{{2,8,s}, 8, G if hasarch{'AARCH64'}} = G{s, {a:V=([16]u8)} => {
def {shuf1, shift1, shuf2, shift2} = maketabs{16, xrange{2,8}, s-2,
tup{u8, {s, r} => r *s>>3}, tup{i8, {s, r} => - r *s%8},
tup{u8, {s, r} => (r+1)*s>>3}, tup{i8, {s, r} => s - (r+1)*s%8},
}
def b = sel{[16]u8, a, shuf1} << shift1
def c = sel{[16]u8, a, shuf2} << shift2
(b | c) & V**cast_i{u8, tail{s}}
}}

View File

@ -3,6 +3,7 @@ include './cbqnDefs'
include './f64'
include './bitops'
include './mask'
include './bitalign'
def bitsel{VL, T, r, bits, e0, e1, len} = {
def bulk = VL/width{T}
@ -20,3 +21,20 @@ fn bitsel_i{VL,T}(r:*void, bits:*u64, e0:u64, e1:u64, len:u64) : void = {
def table{w} = each{bitsel_i{w, .}, tup{u8, u16, u32, u64}}
exportT{'simd_bitsel', table{arch_defvw}}
def padd{E, ptr:P, am} = { ptr = P~~(am + *E~~ptr) }
fn bitwiden_n_8(src:*void, dst:*void, csz:ux, cam:ux) : void = {
assert{cam>0}
assert{(csz>1) & (csz<8)}
def bulk = arch_defvw / 8
def V = [bulk]u8
def rbytes = cdiv{csz*cam, 8}
bitalign{tup{2,8,csz}, 8, {s, align} => {
@maskedLoop{bulk}(dst in tup{V,*u8~~dst} over cam) {
dst = align{load{*V~~src}}
padd{u8, src, bulk*s/8}
}
}}
}
export{'si_bitwiden_n_8', bitwiden_n_8}

View File

@ -132,11 +132,18 @@ static NOINLINE B zeroPadToCellBits0(B x, usz lr, usz cam, usz pcsz, usz ncsz) {
usz* rsh = arr_shAlloc(r, lr+1);
shcpy(rsh, SH(x), lr);
rsh[lr] = ncsz;
if (PIA(r)==0) goto decG_ret;
u64* xp = tyany_ptr(x);
// TODO widen 8/16-bit cells to 16/32 via cpyC(16|32)Arr
if (ncsz<=64 && (ncsz&(ncsz-1)) == 0) {
u64 tmsk = (1ull<<pcsz)-1;
#if SINGELI_SIMD
if (ncsz==8) {
si_bitwiden_n_8(xp, rp, pcsz, cam);
goto decG_ret;
}
#endif
#if FAST_PDEP
if (ncsz<32) {
assert(ncsz==8 || ncsz==16);
@ -173,6 +180,7 @@ static NOINLINE B zeroPadToCellBits0(B x, usz lr, usz cam, usz pcsz, usz ncsz) {
rp+= ncsz>>6;
}
}
decG_ret:;
decG(x);
return taga(r);
}