From f7dd900b3afcab76e08c8821297d96a0b38d4033 Mon Sep 17 00:00:00 2001
From: dzaima
Date: Thu, 25 Jul 2024 17:33:44 +0300
Subject: [PATCH] =?UTF-8?q?faster=20=F0=9D=95=A8=E2=8A=8F=F0=9D=95=A9,=20i?=
 =?UTF-8?q?ncl.=20aarch64=20support?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/builtins/select.c          |   8 +-
 src/singeli/src/lut.singeli    | 177 ++++++++++++++++++++++++++++++++
 src/singeli/src/select.singeli | 182 ++++++++++++++++-----------------
 3 files changed, 269 insertions(+), 98 deletions(-)
 create mode 100644 src/singeli/src/lut.singeli

diff --git a/src/builtins/select.c b/src/builtins/select.c
index bbb13530..70477e60 100644
--- a/src/builtins/select.c
+++ b/src/builtins/select.c
@@ -44,6 +44,8 @@
 #if SINGELI
   #define SINGELI_FILE select
   #include "../utils/includeSingeli.h"
+  typedef bool (*SimdSelectFn)(void* w0, void* x0, void* r0, u64 wl, u64 xl);
+  #define SIMD_SELECT(WE, XL) ({ AUTO we_=(WE); AUTO xl_=(XL); assert(we_>=el_i8 && we_<=el_i32 && xl_>=3 && xl_<=6); si_select_tab[4*(we_-el_i8)+xl_-3]; })
 #endif
 
 typedef void (*CFn)(void* r, ux rs, void* x, ux xs, ux data);
@@ -182,9 +184,9 @@ B select_c2(B t, B w, B x) {
       u8 we = TI(w,elType);
-      #if SINGELI_AVX2
+      #if SINGELI_AVX2 || SINGELI_NEON
         #define CPUSEL(W, NEXT) /*assumes 3≤xl≤6*/ \
-          if (RARE(!avx2_select_tab[4*(we-el_i8)+xl-3](wp, xp, rp, wia, xn))) select_properError(w, x);
+          if (RARE(!SIMD_SELECT(we, xl)(wp, xp, rp, wia, xn))) select_properError(w, x);
       #else
         #define CASE(S, E) case S: for (usz i=i0; i
diff --git a/src/singeli/src/lut.singeli b/src/singeli/src/lut.singeli
new file mode 100644
--- /dev/null
+++ b/src/singeli/src/lut.singeli
@@ -0,0 +1,177 @@
+# {is:[ni_real]u8} => (tuple of vectors, totalling to ni_real elements)}
+def lut_gen{mode, E, nt, ni} = 0
+
+local def loader{G} = {
+  def proc{mem:*_} = {
+    def loader_mtg{V, i} = load{*V~~mem, i}
+    def loader_mtg{'offset', E, en} = proc{en + *E~~mem}
+  }
+  def proc{TG if kgen{TG}} = TG
+  def proc{vs if ktup{vs}} = {
+    def S = oneType{vs}
+    def loader_vtg{Q, i} = {
+      if (width{Q} == width{S}) Q~~select{vs,i}
+      else if (width{Q}*2 == width{S}) Q~~half{select{vs, i>>1}, i&1}
+      else assert{0, S, Q}
+    }
+    def loader_vtg{'offset', E, en} = {
+      def off = (width{E} * en) / width{S}
+      assert{off != length{vs}, vs, E, en}
+      if (off==0.5) proc{tup{half{select{vs,0},1}}}
+      else proc{slice{vs, off}}
+    }
+  }
+  def load_accepter{...vs} = G{proc{...vs}}
+}
+
+def unzip_load{E, n, TG} = each{merge, unzip_load{E, n/2, TG}, unzip_load{E, n/2, TG{'offset', E, n}}}
+def unzip_load{E, n, TG if width{E}*n <= arch_defvw} = each{tup, unzip{TG{[n]E, 0}, TG{[n]E, 1}}}
+
+def widen_tup{u32, is:([16]u8)} = tup{ # compiler will deduplicate all the repeated calls of this on the same is
+  widen{[8]u32, is},
+  widen{[8]u32, shuf{[4]u32, is, 4b3232}},
+}
+
+def blend_halves{mode, E, nt, ni} = tup{nt, ni, loader{{TG} => {
+  def nth = nt/2
+  def {(nth), (ni), prev} = lut_gen{mode, E, nth, ni}
+  def lo = prev{TG}
+  def hi = prev{TG{'offset', E, nth}}
+  def me{'raw'} = tup{lo, hi}
+  def me{'xor', it} = each{{a,b} => a{'xor',b}, me{'raw'}, it{'raw'}}
+
+  # def me{'is', is:[_](u8)} = lo{'is', is}
+  # def me{...is} = {
+  #   def [_]IE = oneType{is}
+  #   def shl = if (IE==u8) __shl{u16} else __shl
+  #   def bm = shl{is, width{IE}-1 - lb{nth}}
+  #   each{{l,h} => topBlend{l,h,bm}, lo{is}, hi{is}}
+  # }
+  def me{is:[_](u8) if hasarch{'X86_64'} and E==u32} = {
+    each{{l,h,m} => topBlend{l,h,m}, lo{is}, hi{is}, each{{c} => c << (31 - lb{nth}), widen_tup{E, is}}}
+  }
+
+  def me{is:[_](u8) if hasarch{'X86_64'} and E==u8} = {
+    def bm = is <<{u16} (7 - lb{nth}) # TODO for outermost bit could do a cmpgt, increasing port diversity
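+    # the index bit at position lb{nth} picks the lo/hi table half; the u16 shift moves it into bit 7, the only bit vpblendvb (topBlend) reads in each byte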
+    each{{l,h} => topBlend{l,h,bm}, lo{is}, hi{is}}
+  }
+  # def me{is:[_](u8) if hasarch{'AARCH64'} and E==u8} = { # only for one outermost blend
+  #   def end = type{is}**(nth-1)
+  #   def bm = is > end
+  #   each{{l,h} => homBlend{l,h,bm}, lo{is & end}, hi{is & end}}
+  # }
+
+  # TODO xor-ing could still be worth it for lower repeated levels where the index transformation can be deduped; and outermost can do a cmpgt
+  # if (mode=='c') hi{'xor', lo}
+  # def me{is:I=[_](u8) if mode=='c'} = {
+  #   def bm = (is & I**lb{nth})
+  #   each{{l,h} => topBlend{l,h,bm}, lo{is}, hi{is}}
+  # }
+}}}
+
+def widen_inds{mode, E, nt0, ni0, sc} = match(lut_gen{mode, primtype{'u',width{E}/sc}, nt0*sc, ni0*sc}) { # e.g. sc==2: {a,b,c,d}[w,x,y,z] → {a0,a1, b0,b1, c0,c1, d0,d1}[w*2,w*2+1, x*2,x*2+1, y*2,y*2+1, z*2,z*2+1]
+  {{nt1, ni1, G}} => tup{nt1/sc, ni1/sc, loader{{TG} => {
+    def prev = G{TG}
+    def ni = ni1/sc
+    def WV = [ni]primtype{'u', 8*sc}
+    {is:([ni]u8)} => {
+      def isw = widen{WV, is} * WV**base{256, sc**sc} + WV**base{256, range{sc}}
+      each{re_el{E,.}, prev{re_el{u8, isw}}}
+    }
+  }}}
+  {x} => x
+}
+
+def zip_halves{mode, E, nt, ni} = match(lut_gen{mode, w_h{E}, nt, ni}) { # e.g. {a,b,c,d}[w,x,y,z] → zip({a0,b0,c0,d0}[w,x,y,z], {a1,b1,c1,d1}[w,x,y,z])
+  {{nt, ni, G}} => tup{nt, ni, loader{{TG} => {
+    # show{E, '→', w_h{E}}
+    def d = unzip_load{w_h{E}, nt, TG}
+    # lprintf{tup{'x0',d}}
+    def prevs = each{G, d}
+    def run_zip{zipper, is} = {
+      def {lo, hi} = each{{prev}=>prev{is}, prevs}
+      join{flip{each{zipper, lo, hi}}}
+    }
+    def me{is:([ni]u8)} = {
+      run_zip{mzip, is}
+    }
+    def me{is:([ni]u8) if hasarch{'AARCH64'} and E==u32 and ni==16} = {
+      def is2 = sel{[16]u8, is, make{[16]u8, 0,1,2,3, 8,9,10,11, 4,5,6,7, 12,13,14,15}}
+      run_zip{mzip128, is2}
+    }
+    def me{is:([ni]u8) if hasarch{'AARCH64'} and E==u64 and ni==16} = {
+      def is2 = sel{[16]u8, is, make{[16]u8, 0,1,8,9,2,3,10,11,4,5,12,13,6,7,14,15}}
+      run_zip{mzip128, is2}
+    }
+    def me{is:([ni]u8) if hasarch{'X86_64'} and E==u64 and ni==16} = {
+      def is2 = sel{[16]u8, is, make{[16]u8, 0,1,8,9,2,3,10,11,4,5,12,13,6,7,14,15}}
+      run_zip{mzip128, is2}
+    }
+    def me{is:([ni]u8) if hasarch{'X86_64'} and E==u16 and ni==32} = {
+      def is2 = shuf{[4]u64, is, 4b3120}
+      run_zip{mzip128, is2}
+    }
+  }}}
+  {x} => x
+}
+
+# lut_gen order is very important!
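+# (cases are tried in reverse definition order, so the broad fallbacks below come first and the tightest applicable case wins)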
+def lut_gen{mode, E==u8, nt, ni if hasarch{'AVX2'} and nt<=128 and ni<=32} = blend_halves{mode, E, 128, 32} # TODO probably don't
+def lut_gen{mode, E==u8, nt, ni if hasarch{'AVX2'} and nt<=64 and ni<=32} = blend_halves{mode, E, 64, 32} # TODO probably don't
+def lut_gen{mode, E==u8, nt, ni if hasarch{'AVX2'} and nt<=32 and ni<=32} = blend_halves{mode, E, 32, 32}
+
+# generate inds to utilize top bit of pshufb zeroing to replace vpblendvb with vpor
+def lut_gen{'c', E==u8, nt, ni if hasarch{'AVX2'} and nt<=64 and ni<=32} = { def vn=pow2_up{nt,16}/16; tup{vn*16, 32, loader{{TG} => {
+  def luts = each{{i} => [32]u8**TG{[16]u8, i}, range{vn}}
+  {is:([32]u8)} => {
+
+    def bi = range{ceil_log2{vn}}
+    def bits = each{{o} => is <<{u16} (3-o), bi} # extract bits 0,1,2,3 (as many as needed) from 2b3210xxxx into top bit (x being bits used by pshufb)
+
+    def top = [32]u8**128
+    def isp = each{{i} => is | (top &~ tree_fold{&, each{{m, o} => if (bit{o,i}!=0) m else ~m, bits, bi}}), range{vn}}
+
+    tup{tree_fold{|, each{sel{[16]u8,.,.}, luts, isp}}}
+  }
+}}}}
+
+def lut_gen{mode, E==u8, nt, ni if hasarch{'AVX2'} and nt<=16 and ni<=32} = tup{16, 32, loader{{TG} => {
+  lut:[32]u8 = [32]u8**TG{[16]u8, 0}
+  {is:([32]u8)} => tup{sel{[16]u8, lut, is}}
+}}}
+
+def lut_gen{mode, E==u32, nt, ni if hasarch{'AVX2'} and nt<=32 and ni<=16} = blend_halves{mode, E, 32, 16} # TODO probably don't
+def lut_gen{mode, E==u32, nt, ni if hasarch{'AVX2'} and nt<=16 and ni<=16} = blend_halves{mode, E, 16, 16}
+
+def lut_gen{mode, E==u32, nt, ni if hasarch{'AVX2'} and nt<=8 and ni<=16} = tup{8, 16, loader{{TG} => {
+  def lut = TG{[8]u32, 0}
+  def me{'idxs', is:([16]u8)} = each{{wis} => tup{wis, sel{[8]u32, lut, wis}}, widen_tup{u32,is}} # TODO inline, or properly outline
+  def me{is:([16]u8)} = each{{{_,v}}=>v, me{'idxs', is}}
+}}}
+
+
+
+# def lut_gen{mode, E==u8, nt, ni if hasarch{'AARCH64'} and nt<=128 and ni<=16} = blend_halves{mode, E, 128, 16}
+def lut_gen{mode, E==u8, nt, ni if hasarch{'AARCH64'} and nt<=16*4 and ni<=16} = { def vn=pow2_up{nt,16}/16; tup{vn*16, 16, loader{{TG} => { # TODO could maybe accept nt==48
+  def lut = each{TG{[16]u8, .}, range{vn}}
+  {is:([16]u8)} => tup{sel{lut, is}}
+}}}}
+
+def lut_gen{mode, E, nt, ni if (E==u16 or E==u64)} = zip_halves{mode, E, nt, ni}
+# def lut_gen{mode, E, nt, ni if (E==u16 or E==u64) and mode=='c'} = widen_inds{mode, E, nt, max{ni,16}, 2}
+def lut_gen{mode, E, nt, ni if E==u32 and hasarch{'AARCH64'}} = zip_halves{mode, E, nt, ni}
+# def lut_gen{mode, E, nt, ni if E==u32 and hasarch{'AARCH64'}} = widen_inds{mode, E, nt, ni, 2}
+def lut_gen{mode, E, nt, ni if E==u64 and hasarch{'AARCH64'}} = widen_inds{mode, E, nt, ni, 2}
+
+def lut_gen{mode, E==u64, nt, ni if nt>16 and hasarch{'AVX2'}} = 0
diff --git a/src/singeli/src/select.singeli b/src/singeli/src/select.singeli
index 1c471c55..902edc14 100644
--- a/src/singeli/src/select.singeli
+++ b/src/singeli/src/select.singeli
@@ -2,9 +2,12 @@
 include './base'
 include './cbqnDefs'
 include './mask'
 include './bitops'
+include './lut'
 include 'util/tup'
 
-def {wrapChk, gather}
+def has_sel = hasarch{'AVX2'} or hasarch{'AARCH64'}
+
+def gather
 
 if_inline (hasarch{'AVX2'}) {
 # def:T - masked original content
 # b:B - pointer to data to index; if width{B}= ty_u{xlf}}}) return{0} cw }
-def wrapChk{cw0:VI, xlf, M} = wrapChk{cw0, VI,xlf, M}
-
-
-if_inline (hasarch{'AVX2'}) {
-
-def storeExp{dst, ind, val, M, ext, rd, wl} = {
-  def s{M} = storeBatch{dst, ind, val, M}
-  if (ext==1 or not M{0}) s{M}
-  else if (ind*rd+rd <= wl) s{maskNone}
-  else { if (ind*rd < wl) s{maskAfter{wl & (rd-1)}}; return{1} }
-}
-
-def shuf_select{ri, rd, TI, w, r, wl, xl, selx} = {
-  def VI = [ri]TI
-  def ext = ri/rd
-  xlf:= VI**cast_i{TI, xl}
-  @maskedLoop{ri}(cw0 in tup{VI,w}, M in 'm' over i to wl) {
-    cw:= wrapChk{cw0, VI,xlf, M}
-    is:= (if (ext>1) i< se{e*2, VI~~c, o},
-    mzip128{c2, c2 + VI**1},
-    2*o + iota{2}
-  }
+def masked_multistore{r0, vs, M, end} = { # returns bumped-forwards r
+  r:= r0
+  def left = if (M{0}) { left:ux = M{'count'} } else 0
+  def lastMaskedStore = makeOptBranch{M{0}, tup{oneType{vs}}, {c} => {
+    storeBatch{r, 0, c, maskAfter{left}}
+    end{}
+  }}
+
+  each{{i, c: [k]_} => {
+    if (M{0}) {
+      if (i+1 == length{vs} or left1) i<0} = {
-    select{b,0}{each{bs{slice{b,1}, c, .}, x}}
-  }
-
-  def i = iota{logv}
-  def vs = each{broadcast{VI, .}, nsel< VD~~bs{each{bb{c},i==0,vs}, c, xd}
-}
-def makeshuf{VI, VD, x0, logv} = {
-  x:= *VD~~x0
-  def halves{v} = each{shuf{[4]u64, v, .}, tup{4b1010, 4b3232}}
-  def readx{l,o} = each{readx{l-1, .}, o + iota{2}<<(l-2)}
-  def readx{l==0,o} = shuf{[4]u64, load{x}, 4b1010}
-  def readx{l==1,o} = halves{load{x, o}}
-  xd:= readx{logv, 0}
-  makeselx{VI,VD,16,xd,logv, sel{[16]i8, ...}}
-}
-def makeperm{VI, VD, x0, logv} = {
-  x:= *VD~~x0
-  def readx{l,o} = each{readx{l-1, .}, o + iota{2}<<(l-1)}
-  def readx{l==0,o} = load{x, o}
-  makeselx{[8]i32,VD,8, readx{logv, 0}, logv, sel{[8]i32, ...}}
-}
-fn select_fn{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = {
+fn select_fn{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = { # TODO don't require SIMD?
   w:= *TI ~~ w0
   x:= *TD ~~ x0
   r:= *TD ~~ r0
-
-  def wd = width{TD}; def rd = rw/wd
-  def wi = width{TI}; def ri = rw/wi
-  def reg_select{sel,make}{l} = {
-    sel{ri, rd, TI, w, r, wl, xl, make{[ri]TI,[rd]TD, x,l}}
+  def wd = width{TD}
+  def wi = width{TI}
+  if (TI==i8) { # TODO some minimum bound on wl?
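+    # try the smallest LUT kernel first; if xl doesn't fit, retry just past the size lut_gen actually produced (nt is rebound to the real table size below), until no case matches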
+    def trytab{nt} = match(lut_gen{'i', TD, nt, 2}) {
+      {{nt, ni, G}} => {
+        if (xl <= nt) {
+          def VI = [ni]TI
+          def xlf = VI**cast_i{TI, xl}
+          # show{TD, nt, ni, G}
+          # lprintf{'LUT of ', VI, ' ⊏ ', [nt]TD, ' with ', wl, ' ≡ ≠𝕨, ', xl, ' ≡ ≠𝕩'}
+          def lut = G{x}
+          @maskedLoop{ni}(w0 in tup{VI,w}, M in 'm' over wl) {
+            def w = wrapChk{w0, xlf, M}
+            def rs = lut{ty_u{w}}
+            r = masked_multistore{r, rs, M, {} => return{1}}
+          }
+          return{1}
+        }
+        trytab{nt+1}
+      }
+      {x} => {}
+    }
+    trytab{2}
   }
-  def shuf_select = reg_select{shuf_select, makeshuf}
-  def perm_select = reg_select{perm_select, makeperm}
-  if      (wi==8 and wd==32 and xl*wd<=256   ) perm_select{0}
-  else if (wi==8 and wd==32 and xl*wd<=256<<1) perm_select{1}
-  else if (wi==8 and wd<=16 and xl*wd<=128   ) shuf_select{0}
-  else if (wi==8 and wd<=16 and xl*wd<=128<<1) shuf_select{1}
-  else if (wi==8 and wd<=16 and xl*wd<=128<<2) shuf_select{2}
-  else if (wi==8 and wd<= 8 and xl*wd<=128<<3) shuf_select{3}
-  else {
+
+  if (hasarch{'AVX2'}) {
     def TIE = i32
     def TDE = tern{wd<32, u32, TD}
     def bulk = rw / width{TDE}
@@ -127,23 +87,52 @@ fn select_fn{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = {
     def xlf = VI**cast_i{TIE, xl}
     @maskedLoop{bulk}(cw0 in tup{VI,w}, sr in tup{'g',r}, M in 'm' over wl) {
-      cw:= wrapChk{cw0, VI,xlf, M}
+      cw:= wrapChk{cw0, xlf, M}
       got:= gather{VD**0, x, cw, M}
       if (TDE!=TD) got&= VD**((1<=32) {
+      @for (r, w0 in w over wl) {
+        def w2 = wrap{w0}
+        if (rare{w2>=xl}) return{0}
+        r = load{x, w2}
+      }
+    } else {
+      def block_size = (1<<14) / (wi/8)
+      @for_blocks{block_size}(bl to wl) {
+        def {s,e} = bl
+        def {ok, min, max} = get_range{w, s, e}
+        if (not ok) return{0}
+        if (rare{max >= i64~~xl}) return{0}
+        if (min < 0) {
+          if (rare{min < -i64~~xl}) return{0}
+          # TODO use wrap_inds
+          @for (w, r over _ from s to e) r = load{x, wrap{w}}
+        } else {
+          @for (w, r over _ from s to e) r = load{x, promote{ux,ty_u{w}}}
+        }
+      }
+    }
   }
   1
 }
 
-def select_fn{TI, TD} = select_fn{256, TI, TD}
-
-exportT{'avx2_select_tab', join{table{select_fn,
-  tup{i8, i16, i32},        # indices
-  tup{u8, u16, u32, u64}}}} # values
+def select_fn{TI, TD} = select_fn{arch_defvw, TI, TD}
+exportT{'si_select_tab', join{table{select_fn,
+  tup{i8, i16, i32},       # indices
+  tup{u8, u16, u32, u64}}} # values
 }
 
-if_inline(hasarch{'AVX2'} or hasarch{'AARCH64'}) {
+
+
+(if(has_sel) {
 fn simd_select_bool128(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = {
   def TI = i8
   def VI = [arch_defvw/8]TI
@@ -180,4 +169,7 @@ if_inline(hasarch{'AVX2'} or hasarch{'AARCH64'}) {
   1
 }
 export{'simd_select_bool128', simd_select_bool128}
-}
+})
+
+
+
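For reference, the u8 kernel that blend_halves builds for a 32-entry table on AVX2 corresponds to roughly the following C intrinsics. This is an illustrative sketch rather than code from the patch; lut32_u8 and all names in it are ad-hoc:

  #include <immintrin.h>
  #include <stdint.h>

  // 32 parallel byte lookups r[i] = tab[is[i]], for indices 0..31
  static __m256i lut32_u8(const uint8_t tab[32], __m256i is) {
    // each 128-bit lane of vpshufb indexes its own 16-byte table, so
    // broadcast both 16-byte halves of the table across both lanes
    __m256i lo = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*)tab));
    __m256i hi = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*)(tab+16)));
    __m256i l = _mm256_shuffle_epi8(lo, is); // vpshufb reads index bits 0-3 (and 7)
    __m256i h = _mm256_shuffle_epi8(hi, is);
    // index bit 4 selects the half; `is <<{u16} (7 - lb{nth})` in the patch
    // moves it to bit 7, the only bit vpblendvb examines per byte (there is
    // no 8-bit shift, so a 16-bit shift is used; bits leaking across byte
    // boundaries land below bit 7 and are ignored)
    __m256i m = _mm256_slli_epi16(is, 3);
    return _mm256_blendv_epi8(l, h, m);
  }

The 'c'-mode generator above avoids the blend entirely by baking the half-selection into the index bytes, using pshufb's bit-7 zeroing so that the per-half results can simply be vpor-ed together.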