faster 𝕨⊏𝕩, incl. aarch64 support

This commit is contained in:
dzaima 2024-07-25 17:33:44 +03:00
parent afa3353b35
commit f7dd900b3a
3 changed files with 269 additions and 98 deletions

View File

@ -44,6 +44,8 @@
#if SINGELI
#define SINGELI_FILE select
#include "../utils/includeSingeli.h"
typedef bool (*SimdSelectFn)(void* w0, void* x0, void* r0, u64 wl, u64 xl);
#define SIMD_SELECT(WE, XL) ({ AUTO we_=(WE); AUTO xl_=(XL); assert(we_>=el_i8 && we_<=el_i32 && xl_>=3 && xl_<=6); si_select_tab[4*(we_-el_i8)+xl_-3]; })
#endif
typedef void (*CFn)(void* r, ux rs, void* x, ux xs, ux data);
@ -182,9 +184,9 @@ B select_c2(B t, B w, B x) {
u8 we = TI(w,elType);
#if SINGELI_AVX2
#if SINGELI_AVX2 || SINGELI_NEON
#define CPUSEL(W, NEXT) /*assumes 3≤xl≤6*/ \
if (RARE(!avx2_select_tab[4*(we-el_i8)+xl-3](wp, xp, rp, wia, xn))) select_properError(w, x);
if (RARE(!SIMD_SELECT(we, xl)(wp, xp, rp, wia, xn))) select_properError(w, x);
#else
#define CASE(S, E) case S: for (usz i=i0; i<i1; i++) ((E*)rp)[i] = ((E*)xp+off)[ip[i]]; break
@ -211,7 +213,7 @@ B select_c2(B t, B w, B x) {
}
#endif
#if SINGELI_SIMD
#if SINGELI_AVX2 || SINGELI_NEON
bool bool_use_simd = we==el_i8 && xl==0 && xia<=128;
#define BOOL_SPECIAL(W) \

177
src/singeli/src/lut.singeli Normal file
View File

@ -0,0 +1,177 @@
def __shl{(u16)}{a:T, b} = T~~(re_el{u16,a}<<b) # for x86's lack of u8 shift
def broadcast{[(n*2)]E, x:[n]E} = pair{x, x}
def pow2_up{v, least} = max{least, 1<<ceil_log2{v}} # least ⌈ ⌈⌾(2⊸⋆⁼) v
# make a LUT of at least nt elements in tab, to be indexed by [ni_real≥ni]u8
# E must be unsigned
# mode is a hint on expected usage:
# mode=='i': same table is reused for many index batches (list ⊏ list)
# mode=='c': same index is reused across many LUTs (list⊸⊏˘ mat)
# mode=='o': neither table nor indices are reused (mat ⊏˘ mat)
# returns {nt_real≥nt, ni_real≥ni, {tab:*E} => {is:[ni_real]u8} => (tuple of vectors, totalling to ni_real elements)}
def lut_gen{mode, E, nt, ni} = 0
local def loader{G} = {
def proc{mem:*_} = {
def loader_mtg{V, i} = load{*V~~mem, i}
def loader_mtg{'offset', E, en} = proc{en + *E~~mem}
}
def proc{TG if kgen{TG}} = TG
def proc{vs if ktup{vs}} = {
def S = oneType{vs}
def loader_vtg{Q, i} = {
if (width{Q} == width{S}) Q~~select{vs,i}
else if (width{Q}*2 == width{S}) Q~~half{select{vs, i>>1}, i&1}
else assert{0, S, Q}
}
def loader_vtg{'offset', E, en} = {
def off = (width{E} * en) / width{S}
assert{off != length{vs}, vs, E, en}
if (off==0.5) proc{tup{half{select{vs,0},1}}}
else proc{slice{vs, off}}
}
}
def load_accepter{...vs} = G{proc{...vs}}
}
def unzip_load{E, n, TG} = each{merge, unzip_load{E, n/2, TG}, unzip_load{E, n/2, TG{'offset', E, n}}}
def unzip_load{E, n, TG if width{E}*n <= arch_defvw} = each{tup, unzip{TG{[n]E, 0}, TG{[n]E, 1}}}
def widen_tup{u32, is:([16]u8)} = tup{ # compiler will deduplicate all the repeated calls of this on the same is
widen{[8]u32, is},
widen{[8]u32, shuf{[4]u32, is, 4b3232}},
}
def blend_halves{mode, E, nt, ni} = tup{nt, ni, loader{{TG} => {
def nth = nt/2
def {(nth), (ni), prev} = lut_gen{mode, E, nth, ni}
def lo = prev{TG}
def hi = prev{TG{'offset', E, nth}}
def me{'raw'} = tup{lo, hi}
def me{'xor', it} = each{{a,b} => a{'xor',b}, me{'raw'}, it{'raw'}}
# def me{'is', is:[_](u8)} = lo{'is', is}
# def me{...is} = {
# def [_]IE = oneType{is}
# def shl = if (IE==u8) __shl{u16} else __shl
# def bm = shl{is, width{IE}-1 - lb{nth}}
# each{{l,h} => topBlend{l,h,bm}, lo{is}, hi{is}}
# }
def me{is:[_](u8) if hasarch{'X86_64'} and E==u32} = {
each{{l,h,m} => topBlend{l,h,m}, lo{is}, hi{is}, each{{c} => c << (31 - lb{nth}), widen_tup{E, is}}}
}
def me{is:[_](u8) if hasarch{'X86_64'} and E==u8} = {
def bm = is <<{u16} (7 - lb{nth}) # TODO for outermost bit could do a cmpgt, increasing port diversity
each{{l,h} => topBlend{l,h,bm}, lo{is}, hi{is}}
}
# def me{is:[_](u8) if hasarch{'AARCH64'} and E==u8} = { # only for one outermost blend
# def end = type{is}**(nth-1)
# def bm = is > end
# each{{l,h} => homBlend{l,h,bm}, lo{is & end}, hi{is & end}}
# }
# TODO xor-ing could still be worth it for lower repeated levels where the index transformation can be deduped; and outermost can do a cmpgt
# if (mode=='c') hi{'xor', lo}
# def me{is:I=[_](u8) if mode=='c'} = {
# def bm = (is & I**lb{nth})
# each{{l,h} => topBlend{l,h,bm}, lo{is}, hi{is}}
# }
}}}
def widen_inds{mode, E, nt0, ni0, sc} = match(lut_gen{mode, primtype{'u',width{E}/sc}, nt0*sc, ni0*sc}) { # e.g. sc==2: {a,b,c,d}[w,x,y,z] → {a0,a1, b0,b1, c0,c1, d0,d1}[w*2,w*2+1, x*2,x*2+1, y*2,y*2+1, z*2zw*2+1]
{{nt1, ni1, G}} => tup{nt1/sc, ni1/sc, loader{{TG} => {
def prev = G{TG}
def ni = ni1/sc
def WV = [ni]primtype{'u', 8*sc}
{is:([ni]u8)} => {
def isw = widen{WV, is} * WV**base{256, sc**sc} + WV**base{256, range{sc}}
each{re_el{E,.}, prev{re_el{u8, isw}}}
}
}}}
{x} => x
}
def zip_halves{mode, E, nt, ni} = match(lut_gen{mode, w_h{E}, nt, ni}) { # e.g. {a,b,c,d}[w,x,y,z] → zip({a0,b0,c0,d0}[w,x,y,z], {a1,b1,c1,d1}[w,x,y,z])
{{nt, ni, G}} => tup{nt, ni, loader{{TG} => {
# show{E, '→', w_h{E}}
def d = unzip_load{w_h{E}, nt, TG}
# lprintf{tup{'x0',d}}
def prevs = each{G, d}
def run_zip{zipper, is} = {
def {lo, hi} = each{{prev}=>prev{is}, prevs}
join{flip{each{zipper, lo, hi}}}
}
def me{is:([ni]u8)} = {
run_zip{mzip, is}
}
def me{is:([ni]u8) if hasarch{'AARCH64'} and E==u32 and ni==16} = {
def is2 = sel{[16]u8, is, make{[16]u8, 0,1,2,3, 8,9,10,11, 4,5,6,7, 12,13,14,15}}
run_zip{mzip128, is2}
}
def me{is:([ni]u8) if hasarch{'AARCH64'} and E==u64 and ni==16} = {
def is2 = sel{[16]u8, is, make{[16]u8, 0,1,8,9,2,3,10,11,4,5,12,13,6,7,14,15}}
run_zip{mzip128, is2}
}
def me{is:([ni]u8) if hasarch{'X86_64'} and E==u64 and ni==16} = {
def is2 = sel{[16]u8, is, make{[16]u8, 0,1,8,9,2,3,10,11,4,5,12,13,6,7,14,15}}
run_zip{mzip128, is2}
}
def me{is:([ni]u8) if hasarch{'X86_64'} and E==u16 and ni==32} = {
def is2 = shuf{[4]u64, is, 4b3120}
run_zip{mzip128, is2}
}
}}}
{x} => x
}
# lut_gen order is very important!
def lut_gen{mode, E==u8, nt, ni if hasarch{'AVX2'} and nt<=128 and ni<=32} = blend_halves{mode, E, 128, 32} # TODO probably don't
def lut_gen{mode, E==u8, nt, ni if hasarch{'AVX2'} and nt<=64 and ni<=32} = blend_halves{mode, E, 64, 32} # TODO probably don't
def lut_gen{mode, E==u8, nt, ni if hasarch{'AVX2'} and nt<=32 and ni<=32} = blend_halves{mode, E, 32, 32}
# generate inds to utilize top bit of pshufb zeroing to replace vpblendvb with with vpor
def lut_gen{'c', E==u8, nt, ni if hasarch{'AVX2'} and nt<=64 and ni<=32} = { def vn=pow2_up{nt,16}/16; tup{vn*16, 32, loader{{TG} => {
def luts = each{{i} => [32]u8**TG{[16]u8, i}, range{vn}}
{is:([32]u8)} => {
def bi = range{ceil_log2{vn}}
def bits = each{{o} => is <<{u16} (3-o), bi} # extract bits 0,1,2,3 (as many as needed) from 2b3210xxxx into top bit (x being bits used by pshufb)
def top = [32]u8**128
def isp = each{{i} => is | (top &~ tree_fold{&, each{{m, o} => if (bit{o,i}!=0) m else ~m, bits, bi}}), range{vn}}
tup{tree_fold{|, each{sel{[16]u8,.,.}, luts, isp}}}
}
}}}}
def lut_gen{mode, E==u8, nt, ni if hasarch{'AVX2'} and nt<=16 and ni<=32} = tup{16, 32, loader{{TG} => {
lut:[32]u8 = [32]u8**TG{[16]u8, 0}
{is:([32]u8)} => tup{sel{[16]u8, lut, is}}
}}}
def lut_gen{mode, E==u32, nt, ni if hasarch{'AVX2'} and nt<=32 and ni<=16} = blend_halves{mode, E, 32, 16} # TODO probably don't
def lut_gen{mode, E==u32, nt, ni if hasarch{'AVX2'} and nt<=16 and ni<=16} = blend_halves{mode, E, 16, 16}
def lut_gen{mode, E==u32, nt, ni if hasarch{'AVX2'} and nt<=8 and ni<=16} = tup{8, 16, loader{{TG} => {
def lut = TG{[8]u32, 0}
def me{'idxs', is:([16]u8)} = each{{wis} => tup{wis, sel{[8]u32, lut, wis}}, widen_tup{u32,is}} # TODO inline, or properly outline
def me{is:([16]u8)} = each{{{_,v}}=>v, me{'idxs', is}}
}}}
# def lut_gen{mode, E==u8, nt, ni if hasarch{'AARCH64'} and nt<=128 and ni<=16} = blend_halves{mode, E, 128, 16}
def lut_gen{mode, E==u8, nt, ni if hasarch{'AARCH64'} and nt<=16*4 and ni<=16} = { def vn=pow2_up{nt,16}/16; tup{vn*16, 16, loader{{TG} => { # TODO could maybe accept nt==48
def lut = each{TG{[16]u8, .}, range{vn}}
{is:([16]u8)} => tup{sel{lut, is}}
}}}}
def lut_gen{mode, E, nt, ni if (E==u16 or E==u64)} = zip_halves{mode, E, nt, ni}
# def lut_gen{mode, E, nt, ni if (E==u16 or E==u64) and mode=='c'} = widen_inds{mode, E, nt, max{ni,16}, 2}
def lut_gen{mode, E, nt, ni if E==u32 and hasarch{'AARCH64'}} = zip_halves{mode, E, nt, ni}
# def lut_gen{mode, E, nt, ni if E==u32 and hasarch{'AARCH64'}} = widen_inds{mode, E, nt, ni, 2}
def lut_gen{mode, E, nt, ni if E==u64 and hasarch{'AARCH64'}} = widen_inds{mode, E, nt, ni, 2}
def lut_gen{mode, E==u64, nt, ni if nt>16 and hasarch{'AVX2'}} = 0

View File

@ -2,9 +2,12 @@ include './base'
include './cbqnDefs'
include './mask'
include './bitops'
include './lut'
include 'util/tup'
def {wrapChk, gather}
def has_sel = hasarch{'AVX2'} or hasarch{'AARCH64'}
def gather
if_inline (hasarch{'AVX2'}) {
# def:T - masked original content
# b:B - pointer to data to index; if width{B}<elwidth{T}, padding bytes are garbage read after wanted position
@ -19,106 +22,63 @@ if_inline (hasarch{'AVX2'}) {
}
}
def wrapChk{cw0, VI,xlf, M} = {
def wrapChk{cw0:VI, xlf, M} = {
cw:= cw0 + (xlf & VI~~(cw0<VI**0))
if (homAny{M{ty_u{cw} >= ty_u{xlf}}}) return{0}
cw
}
def wrapChk{cw0:VI, xlf, M} = wrapChk{cw0, VI,xlf, M}
if_inline (hasarch{'AVX2'}) {
def storeExp{dst, ind, val, M, ext, rd, wl} = {
def s{M} = storeBatch{dst, ind, val, M}
if (ext==1 or not M{0}) s{M}
else if (ind*rd+rd <= wl) s{maskNone}
else { if (ind*rd < wl) s{maskAfter{wl & (rd-1)}}; return{1} }
}
def shuf_select{ri, rd, TI, w, r, wl, xl, selx} = {
def VI = [ri]TI
def ext = ri/rd
xlf:= VI**cast_i{TI, xl}
@maskedLoop{ri}(cw0 in tup{VI,w}, M in 'm' over i to wl) {
cw:= wrapChk{cw0, VI,xlf, M}
is:= (if (ext>1) i<<lb{ext} else i)
def se{e, c, o} = {
c2:= shuf{[4]u64, c+c, 4b3120}
each{
{c,o} => se{e*2, VI~~c, o},
mzip128{c2, c2 + VI**1},
2*o + iota{2}
}
def masked_multistore{r0, vs, M, end} = { # returns bumped-forwards r
r:= r0
def left = if (M{0}) { left:ux = M{'count'} } else 0
def lastMaskedStore = makeOptBranch{M{0}, tup{oneType{vs}}, {c} => {
storeBatch{r, 0, c, maskAfter{left}}
end{}
}}
each{{i, c: [k]_} => {
if (M{0}) {
if (i+1 == length{vs} or left<k) lastMaskedStore{c}
left-= k
}
def se{e==ext, c, o} = storeExp{r, is+o, selx{c}, M, ext, rd, wl}
se{1, cw, 0}
}
storeBatch{r, 0, c, maskNone}
r+= k
}, inds{vs}, vs}
r
}
def perm_select{ri, rd, TI, w, r, wl, xl, selx} = {
def VI = [ri]TI
def ext = ri/rd
xlf:= VI**cast_i{TI, xl}
@maskedLoop{ri}(cw0 in tup{VI,w}, M in 'm' over i to wl) {
cw:= wrapChk{cw0, VI,xlf, M}
is:= (if (ext>1) i<<lb{ext} else i)
def part{o} = widen{[8]i32, re_el{i8, shuf{[4]u64, cw, 4b3210+o}}}
def se{o} = storeExp{r, is+o, selx{part{o}}, M, ext, rd, wl}
each{se, iota{ext}}
}
}
def makeselx{VI, VD, nsel, xd, logv, cshuf} = {
def bblend {m}{{f,t}} = homBlend{f, t, type{f} ~~ m}
def bblendn{m}{{t,f}} = bblend{m}{tup{f,t}}
def bb{c}{f, v} = (if (f) bblendn{c<v} else bblend{(c&v)==v})
def bs{b, c, x} = cshuf{x, c}
def bs{b, c, x if length{b}>0} = {
select{b,0}{each{bs{slice{b,1}, c, .}, x}}
}
def i = iota{logv}
def vs = each{broadcast{VI, .}, nsel<<reverse{i}}
{c} => VD~~bs{each{bb{c},i==0,vs}, c, xd}
}
def makeshuf{VI, VD, x0, logv} = {
x:= *VD~~x0
def halves{v} = each{shuf{[4]u64, v, .}, tup{4b1010, 4b3232}}
def readx{l,o} = each{readx{l-1, .}, o + iota{2}<<(l-2)}
def readx{l==0,o} = shuf{[4]u64, load{x}, 4b1010}
def readx{l==1,o} = halves{load{x, o}}
xd:= readx{logv, 0}
makeselx{VI,VD,16,xd,logv, sel{[16]i8, ...}}
}
def makeperm{VI, VD, x0, logv} = {
x:= *VD~~x0
def readx{l,o} = each{readx{l-1, .}, o + iota{2}<<(l-1)}
def readx{l==0,o} = load{x, o}
makeselx{[8]i32,VD,8, readx{logv, 0}, logv, sel{[8]i32, ...}}
}
fn select_fn{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = {
fn select_fn{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = { # TODO don't require SIMD?
w:= *TI ~~ w0
x:= *TD ~~ x0
r:= *TD ~~ r0
def wd = width{TD}; def rd = rw/wd
def wi = width{TI}; def ri = rw/wi
def reg_select{sel,make}{l} = {
sel{ri, rd, TI, w, r, wl, xl, make{[ri]TI,[rd]TD, x,l}}
def wd = width{TD}
def wi = width{TI}
if (TI==i8) { # TODO some minimum bound on wl?
def trytab{nt} = match(lut_gen{'i', TD, nt, 2}) {
{{nt, ni, G}} => {
if (xl <= nt) {
def VI = [ni]TI
def xlf = VI**cast_i{TI, xl}
# show{TD, nt, ni, G}
# lprintf{'LUT of ', VI, ' ⊏ ', [nt]TD, ' with ', wl, ' ≡ ≠𝕨, ', xl, ' ≡ ≠𝕩'}
def lut = G{x}
@maskedLoop{ni}(w0 in tup{VI,w}, M in 'm' over wl) {
def w = wrapChk{w0, xlf, M}
def rs = lut{ty_u{w}}
r = masked_multistore{r, rs, M, {} => return{1}}
}
return{1}
}
trytab{nt+1}
}
{x} => {}
}
trytab{2}
}
def shuf_select = reg_select{shuf_select, makeshuf}
def perm_select = reg_select{perm_select, makeperm}
if (wi==8 and wd==32 and xl*wd<=256 ) perm_select{0}
else if (wi==8 and wd==32 and xl*wd<=256<<1) perm_select{1}
else if (wi==8 and wd<=16 and xl*wd<=128 ) shuf_select{0}
else if (wi==8 and wd<=16 and xl*wd<=128<<1) shuf_select{1}
else if (wi==8 and wd<=16 and xl*wd<=128<<2) shuf_select{2}
else if (wi==8 and wd<= 8 and xl*wd<=128<<3) shuf_select{3}
else {
if (hasarch{'AVX2'}) {
def TIE = i32
def TDE = tern{wd<32, u32, TD}
def bulk = rw / width{TDE}
@ -127,23 +87,52 @@ fn select_fn{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = {
def xlf = VI**cast_i{TIE, xl}
@maskedLoop{bulk}(cw0 in tup{VI,w}, sr in tup{'g',r}, M in 'm' over wl) {
cw:= wrapChk{cw0, VI,xlf, M}
cw:= wrapChk{cw0, xlf, M}
got:= gather{VD**0, x, cw, M}
if (TDE!=TD) got&= VD**((1<<wd)-1)
sr{got}
}
} else {
def ix = ty_s{ux}
def wrap{w0:T if quality{T}=='i'} = {
def w1 = promote{ix, w0}
tern{w1<0, xl + ux~~w1, ux~~w1}
}
if (wi>=32) {
@for (r, w0 in w over wl) {
def w2 = wrap{w0}
if (rare{w2>=xl}) return{0}
r = load{x, w2}
}
} else {
def block_size = (1<<14) / (wi/8)
@for_blocks{block_size}(bl to wl) {
def {s,e} = bl
def {ok, min, max} = get_range{w, s, e}
if (not ok) return{0}
if (rare{max >= i64~~xl}) return{0}
if (min < 0) {
if (rare{min < -i64~~xl}) return{0}
# TODO use wrap_inds
@for (w, r over _ from s to e) r = load{x, wrap{w}}
} else {
@for (w, r over _ from s to e) r = load{x, promote{ux,ty_u{w}}}
}
}
}
}
1
}
def select_fn{TI, TD} = select_fn{256, TI, TD}
exportT{'avx2_select_tab', join{table{select_fn,
tup{i8, i16, i32}, # indices
tup{u8, u16, u32, u64}}}} # values
def select_fn{TI, TD} = select_fn{arch_defvw, TI, TD}
exportT{'si_select_tab', join{table{select_fn,
tup{i8, i16, i32}, # indices
tup{u8, u16, u32, u64}}} # values
}
if_inline(hasarch{'AVX2'} or hasarch{'AARCH64'}) {
(if(has_sel) {
fn simd_select_bool128(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = {
def TI = i8
def VI = [arch_defvw/8]TI
@ -180,4 +169,7 @@ if_inline(hasarch{'AVX2'} or hasarch{'AARCH64'}) {
1
}
export{'simd_select_bool128', simd_select_bool128}
}
})