uCBQN/src/singeli/src/neon.singeli
dzaima d161f93a38 more Singeli refactoring
makes some conditions more strict
removes wrong el_m & el_s definitions
2024-05-26 04:25:54 +03:00

152 lines
7.7 KiB
Plaintext

# Predicates recognising NEON vector types.
# nvec{V}: V is a vector type whose total width is 64 or 128 bits.
def nvec{V} = isvec{V} and (width{V}==128 or width{V}==64)
# nvec{V,ew}: as above, and every element of V is exactly ew bits wide.
def nvec{V,ew} = nvec{V} and elwidth{V}==ew
# Element-kind-specific variants. genchk is defined outside this section; presumably it
# combines the shape check (nvec) with the given element-type check - TODO confirm.
def nveci = genchk{nvec, isint}
def nvecs = genchk{nvec, issigned}
def nvecu = genchk{nvec, isunsigned}
def nvecf = genchk{nvec, isfloat}
# nty{T}: the type suffix used in intrinsic names - the quality letter of T
# (with 'i' replaced by 's') followed by the width of T in decimal.
def nty{T} = {
  def kind = quality{T}
  def bits = fmtnat{width{T}}
  merge{if (kind=='i') 's' else kind, bits}
}
# For a vector type, the suffix is that of its element type.
def nty{[_]E} = nty{E}
# ntyp{base, ...mid, V}: assemble an intrinsic name for vector type V.
#   128-bit V: base, 'q', mid..., '_', nty{V}
#    64-bit V: base,      mid..., '_', nty{V}
def ntyp{base, ...mid, V if w128{V}} = merge{base, 'q', ...mid, '_', nty{V}}
def ntyp{base, ...mid, V if w64{V}} = merge{base, ...mid, '_', nty{V}}
# ntyp0{base, V}: same suffixing, but never inserts the 'q' marker.
def ntyp0{base, V} = merge{base, '_', nty{V}}
# Long (widening) add/subtract/multiply intrinsics.
# The *Lo forms accept 64-bit integer vectors; the result type is el_d{V} (defined elsewhere).
def addwLo{x:V, y:V if w64i{V}} = { def R = el_d{V}; emit{R, ntyp{'vaddl', V}, x, y} }
def subwLo{x:V, y:V if w64i{V}} = { def R = el_d{V}; emit{R, ntyp{'vsubl', V}, x, y} }
def mulwLo{x:V, y:V if w64i{V}} = { def R = el_d{V}; emit{R, ntyp{'vmull', V}, x, y} }
# mulwHi accepts 128-bit integer vectors; the vmull_high_* names carry no 'q', hence ntyp0.
def mulwHi{x:V, y:V if w128i{V}} = { def R = el_m{V}; emit{R, ntyp0{'vmull_high', V}, x, y} }
# mulw: tup{mulwLo of the two low halves, mulwHi of the full vectors}.
def mulw{x:V, y:V if w128{V}} = tup{mulwLo{half{x,0}, half{y,0}}, mulwHi{x,y}}
# shrn{a, s}: a>>s, narrowed. Requires a 128-bit integer vector with elements wider than 8 bits;
# the result type is el_h{V}.
def shrn{a:V, s if w128i{V} and elwidth{V}>8} = emit{el_h{V}, ntyp0{'vshrn_n', V}, a, s}
# shrm{a, s, d}: (a>>s) | (d & (mask of new zeroes)). Note the intrinsic takes d first.
def shrm{a:V, s, d:V if nvecu{V}} = { def op = ntyp{'vsri', '_n', V}; emit{V, op, d, a, s} }
# shlm{a, s, d}: (a<<s) | (d & (mask of new zeroes)). Note the intrinsic takes d first.
def shlm{a:V, s, d:V if nvecu{V}} = { def op = ntyp{'vsli', '_n', V}; emit{V, op, d, a, s} }
# bitBlend{f, t, m}: bitwise select via vbsl, which is given the mask first, then t, then f.
# m must be an unsigned NEON vector with the same total width and element width as f and t.
def bitBlend{f:V, t:V, m:M if nvec{V} and nvecu{M,elwidth{V}} and width{V}==width{M}} = {
  def op = ntyp{'vbsl', V}
  emit{V, op, m, t, f}
}
# homBlend{f, t, m}: accepts any NEON vector as the mask and forwards to bitBlend unchanged.
def homBlend{f:V, t:V, m:M if nvec{M}} = bitBlend{f, t, m}
# addpw{x}: add pairwise widening (vpaddl). Integer elements of at most 32 bits.
def addpw{x:V if nveci{V} and elwidth{V}<=32} = { def R = el_m{V}; emit{R, ntyp{'vpaddl', V}, x} }
# addpwa{a, x}: add pairwise widening + accumulate (vpadal). a must have type el_m{type of x}.
def addpwa{a:R==el_m{V}, x:V if nveci{V} and elwidth{V}<=32} = emit{R, ntyp{'vpadal', V}, a, x}
# mla{a, x, y}: a + x*y (vmla).  mls{a, x, y}: a - x*y (vmls).
# arm_neon.h provides no vmla/vmls variants for 64-bit integer lanes, so such vectors are
# excluded by the guard (the same exclusion fold_min/fold_max use) rather than being
# accepted here and producing a call to a nonexistent intrinsic in the generated C.
def mla{a:T, x:T, y:T if nvec{T} and ~nveci{T,64}} = emit{T, ntyp{'vmla', T}, a, x, y} # a + x*y
def mls{a:T, x:T, y:T if nvec{T} and ~nveci{T,64}} = emit{T, ntyp{'vmls', T}, a, x, y} # a - x*y
# rbit{x}: vrbit, accepted only for vectors of u8 elements.
def rbit{x:V if nvecu{V,8}} = { def op = ntyp{'vrbit', V}; emit{V, op, x} }
# rev{w, x}: reverse the order of elements in each w-bit window.
# When w equals the element width there is nothing to reorder, so x is returned as-is.
def rev{(width{E}), x:[_]E} = x
def rev{16, x:V if nveci{V} and elwidth{V}<16} = { def op = ntyp{'vrev16', V}; emit{V, op, x} }
def rev{32, x:V if nveci{V} and elwidth{V}<32} = { def op = ntyp{'vrev32', V}; emit{V, op, x} }
def rev{64, x:V if nveci{V} and elwidth{V}<64} = { def op = ntyp{'vrev64', V}; emit{V, op, x} }
# popc{x}: vcnt, accepted only for vectors of u8 elements.
def popc{x:V if nvecu{V,8}} = { def op = ntyp{'vcnt', V}; emit{V, op, x} }
# clz{x}: vclz on unsigned vectors with elements of at most 32 bits.
def clz{x:V if nvecu{V} and elwidth{V}<=32} = { def op = ntyp{'vclz', V}; emit{V, op, x} }
# cls{x}: vcls on integer vectors with elements of at most 32 bits.
# The intrinsic is emitted with the signed result type ty_s{V}; the value is then reinterpreted as ty_u{V}.
def cls{x:V if nveci{V} and elwidth{V}<=32} = {
  def r = emit{ty_s{V}, ntyp{'vcls', V}, x}
  ty_u{V}~~r
}
# fold_add{a}: reduce all lanes of a with vaddv; the result is a scalar of the element type.
def fold_add{a:V=[_]E if nvec{V}} = { def op = ntyp{'vaddv', V}; emit{E, op, a} }
# fold_addw{a}: widening reduction of all lanes of a with vaddlv; the result is a scalar of
# type w_d{E} (w_d is defined elsewhere; presumably the double-width scalar type - TODO confirm).
# vaddlv exists only for 8/16/32-bit integer lanes, so 64-bit lanes are rejected by the guard,
# matching the elwidth<=32 restriction on addpw/addpwa, instead of emitting a nonexistent intrinsic.
def fold_addw{a:T=[_]E if nveci{T} and elwidth{T}<=32} = emit{w_d{E}, ntyp{'vaddlv', T}, a}
# fold_min{a} / fold_max{a}: reduce all lanes with vminv / vmaxv to a scalar of the element type.
# Vectors of 64-bit integers are excluded by the guard.
def fold_min{a:V=[_]E if nvec{V} and ~nveci{V,64}} = { def op = ntyp{'vminv', V}; emit{E, op, a} }
def fold_max{a:V=[_]E if nvec{V} and ~nveci{V,64}} = { def op = ntyp{'vmaxv', V}; emit{E, op, a} }
# vfold{F, x}: dispatch a fold over the lanes of x on the operation F.
#   F is min -> fold_min, F is max -> fold_max (neither for 64-bit integer lanes), F is + -> fold_add.
def vfold{F, x:V if same{F, min} and nvec{V} and ~nveci{V,64}} = fold_min{x}
def vfold{F, x:V if same{F, max} and nvec{V} and ~nveci{V,64}} = fold_max{x}
def vfold{F, x:V if same{F, +} and nvec{V}} = fold_add{x}
# storeLow{ptr, w, x}: store the low w bits of x to ptr.
# w<=64: view x as lanes of the unsigned w-bit type, extract lane 0, and store it with storeu.
def storeLow{ptr:*E, w, x:V=[_]E if nvec{V} and w<=64} = {
  def U = ty_u{w}
  storeu{*U~~ptr, extract{re_el{U,V}~~x, 0}}
}
# w equal to the full vector width: a plain whole-vector store at index 0.
def storeLow{ptr:*E, w, x:V=[_]E if nvec{V} and w==width{V}} = store{*V~~ptr, 0, x}
# loadLow{ptr, w}: load w bits from ptr into a vector of type V.
# w<=64: implemented via a broadcast load (vld1*_dup) of one unsigned w-bit element,
# whose result is reinterpreted back to V.
def loadLow{ptr:*V, w if w<=64} = {
  def U = ty_u{w}
  def B = re_el{U, V}
  V ~~ emit{B, ntyp{'vld1', '_dup', B}, *U~~ptr}
}
# w equal to the full vector width: an ordinary load.
def loadLow{ptr:*V, (width{V})} = load{ptr}
# undefPromote{R, x}: promote the 64-bit vector x to the 128-bit type R with the same element type.
# arm_neon.h doesn't actually provide a way to do this in a 0-instruction way, so x is combined with itself. ¯\_(ツ)_/¯
def undefPromote{R=[_]E, x:H=[_]E if w64{H} and w128{R}} = { def op = ntyp{'vcombine', H}; emit{R, op, x, x} }
# half{x, 0} / half{x, 1}: low / high half of a 128-bit vector (vget_low / vget_high); result type n_h{V}.
def half{x:V, (0) if w128{V}} = { def R = n_h{V}; emit{R, ntyp0{'vget_low', V}, x} }
def half{x:V, (1) if w128{V}} = { def R = n_h{V}; emit{R, ntyp0{'vget_high', V}, x} }
# pair{a, b}: join two 64-bit vectors with vcombine; result type n_d{V}.
def pair{a:V, b:V if w64{V}} = { def R = n_d{V}; emit{R, ntyp0{'vcombine', V}, a, b} }
# copyLane{dst, di, src, si}: emit a vcopy*_lane* intrinsic with arguments (dst, di, src, si).
# The name gets 'q' after 'vcopy' when dst is 128-bit, and (through ntyp) 'q' after 'lane' when src is 128-bit.
def copyLane{dst:D=[_]E, di, src:S=[_]E, si if w64{D} and nvec{S}} = {
  def op = ntyp{'vcopy_lane', S}
  emit{D, op, dst, di, src, si}
}
def copyLane{dst:D=[_]E, di, src:S=[_]E, si if w128{D} and nvec{S}} = {
  def op = ntyp{'vcopyq_lane', S}
  emit{D, op, dst, di, src, si}
}
# broadcastSel{x, i}: vdup from lane i of x; the '_laneq' form is used for 128-bit x, '_lane' for 64-bit.
def broadcastSel{x:V, i if nvec{V}} = {
  def lane = tern{w128{V},'_laneq','_lane'}
  emit{V, ntyp{'vdup', lane, V}, x, i}
}
# vshl{a, b, n}: emit vext on a and b with the compile-time constant n.
def vshl{a:V, b:V, n if knum{n}} = { def op = ntyp{'vext', V}; emit{V, op, a, b, n} }
# zipLo / zipHi: the vzip1 / vzip2 intrinsics on two vectors of the same type.
def zipLo{a:V, b:V if nvec{V}} = { def op = ntyp{'vzip1', V}; emit{V, op, a, b} }
def zipHi{a:V, b:V if nvec{V}} = { def op = ntyp{'vzip2', V}; emit{V, op, a, b} }
# zip: both results as tup{zipLo, zipHi}.
def zip{a:V, b:V if nvec{V}} = tup{zipLo{a,b}, zipHi{a,b}}
# packLo / packHi: reinterpret both inputs as el_s{V} (defined elsewhere) and apply vuzp1 / vuzp2
# on that type; the result has type el_s{V}.
def packLo{x:V, y:V if nvec{V}} = {
  def S = el_s{V}
  def op = ntyp{'vuzp1', S}
  emit{S, op, S~~x, S~~y}
}
def packHi{x:V, y:V if nvec{V}} = {
  def S = el_s{V}
  def op = ntyp{'vuzp2', S}
  emit{S, op, S~~x, S~~y}
}
# Convenience forms taking the two inputs as a single 2-element tuple.
def packLo{{x, y}} = packLo{x, y}
def packHi{{x, y}} = packHi{x, y}
# trn1 / trn2: the vtrn1 / vtrn2 intrinsics on two vectors of the same type.
def trn1{x:V, y:V if nvec{V}} = { def op = ntyp{'vtrn1', V}; emit{V, op, x, y} }
def trn2{x:V, y:V if nvec{V}} = { def op = ntyp{'vtrn2', V}; emit{V, op, x, y} }
# sel{L, x, i}: table lookup (vqtbl1) into the 128-bit vector x using the 8-bit-element index vector i.
# x is viewed with i's element type for the lookup, i is passed as unsigned, and the result is
# viewed with x's element type again. L must satisfy lvec{L,16,8} (lvec is defined elsewhere).
def sel{L, x:V, i:I if lvec{L,16,8} and w128{V} and nvec{I, 8}} = {
  def tbl = re_el{eltype{I}, x}
  def r = emit{I, ntyp{'vqtbl1',I}, tbl, ty_u{i}}
  re_el{eltype{V}, r}
}
# eqqi{P, Q}: equal quality integers - P is an integer type and P, Q have the same quality.
local def eqqi{P, Q} = (quality{P}==quality{Q}) & isint{P}
# cvt{T, x}: same-width conversion between 64-bit integer and 64-bit float lanes (vcvt*).
# T is the target element type; the result is a vector of k elements of T.
def cvt{T==f64, x:X=[k]_ if nveci{X,64}} = { def R = [k]T; emit{R, ntyp{'vcvt', '_f64', X}, x} } # integer -> f64
def cvt{T==i64, x:X=[k]_ if nvecf{X,64}} = { def R = [k]T; emit{R, ntyp{'vcvt', '_s64', X}, x} } # f64 -> i64
def cvt{T==u64, x:X=[k]_ if nvecf{X,64}} = { def R = [k]T; emit{R, ntyp{'vcvt', '_u64', X}, x} } # f64 -> u64
# widen{R, x}: convert vector x to the vector type R, whose elements are wider.
# Case 1: 64-bit input, equal-quality integer elements, exactly double width - a single vmovl.
def widen{R=[_]RE, x:X=[_]XE if w64{X} and eqqi{RE,XE} and width{RE}==width{XE}*2} = emit{R, ntyp{'vmovl', X}, x}
# Case 2: as case 1 but more than double width - widen to the intermediate type el_s{R} first, then widen that result to R.
def widen{R=[_]RE, x:X=[_]XE if w64{X} and eqqi{RE,XE} and width{RE}> width{XE}*2} = widen{R, widen{el_s{R}, x}}
# Case 3: 64-bit input where exactly one of the two element types is float and the result is wider -
# widen to rn elements of to_w{XE,width{RE}} (to_w is defined elsewhere; presumably XE resized to that width), then cvt to RE.
def widen{R=[rn]RE, x:X=[xn]XE if w64{X} and isfloat{RE}!=isfloat{XE} and width{RE}>width{XE}} = cvt{RE, widen{[rn]to_w{XE,width{RE}}, x}}
# Case 4: 128-bit input holding more elements than the result - only the low half of x is widened.
def widen{R=[rn]RE, x:X=[xn]XE if w128{X} and xn>rn} = widen{R, half{x,0}}
# narrow{T, x}: convert the 128-bit vector x to elements of the narrower scalar type T.
# Case 1: equal-quality integers, T less than half as wide as E - narrow one step to w_h{E},
# re-promote that 64-bit result to the 128-bit type el_s{X} with undefPromote, then recurse.
def narrow{T, x:X=[_]E if w128{X} and eqqi{T,E} and width{T}*2< width{E}} = narrow{T, undefPromote{el_s{X}, narrow{w_h{E}, x}}}
# Case 2: equal-quality integers, T exactly half as wide as E - a single vmovn; result type el_h{X}.
def narrow{T, x:X=[_]E if w128{X} and eqqi{T,E} and width{T}*2==width{E}} = emit{el_h{X}, ntyp0{'vmovn', X}, x}
# Case 3: exactly one of T and E is float and T is narrower - first cvt at E's width
# (to to_w{T, width{E}}; to_w is defined elsewhere), then narrow that result to T.
def narrow{T, x:X=[_]E if w128{X} and isfloat{T}!=isfloat{E} and width{T}<width{E}} = narrow{T, cvt{to_w{T, width{E}}, x}}
# narrowUpper{lowRes, x}: vmovn_high with arguments (lowRes, x). lowRes is a 64-bit integer vector of
# k elements, x is the 128-bit vector of type el_d{L}; the result has k*2 elements of lowRes's element type.
def narrowUpper{lowRes:L=[k]E, x:X if w64i{L} and w128{X} and el_d{L}==X} = emit{[k*2]E, ntyp0{'vmovn_high', X}, lowRes, x}
# narrowPair{a, b}: narrow two vectors of the same type into one result.
# Unconditional form: narrow a to w_h{E}, then pass that result and b to narrowUpper.
def narrowPair{a:T=[_]E, b:T} = narrowUpper{narrow{w_h{E}, a}, b}
# Integer-element form: packLo{a, b}.
# NOTE(review): both forms match integer inputs; which one is chosen depends on Singeli's overload
# priority (presumably the later definition wins) - confirm.
def narrowPair{a:T=[_]E, b:T if isint{E}} = packLo{a, b}
# widenUpper{x}: vmovl_high on a 128-bit integer vector; result type el_m{V}.
def widenUpper{x:V if w128i{V}} = { def R = el_m{V}; emit{R, ntyp0{'vmovl_high', V}, x} }
# widen{x}: widen a whole 128-bit vector, returned as tup{widen{el_m{V}, x}, widenUpper{x}}.
def widen{x:V if w128{V}} = {
  def R = el_m{V}
  tup{widen{R, x}, widenUpper{x}}
}
# bitAny{x}: whether any bit of x is set - the maximum over u32 lanes is nonzero.
def bitAny{x:V} = { def lanes = re_el{u32, x}; fold_max{lanes}!=0 }
# bitAll{x}: whether every bit of x is set - the minimum over u32 lanes is all ones.
def bitAll{x:V} = { def lanes = re_el{u32, x}; fold_min{lanes}==0xffff_ffff }
# topAny{x}: whether any lane is negative when viewed as signed - the signed minimum is below zero.
def topAny{x:V if nvec{V}} = { def s = ty_s{x}; fold_min{s}<0 }
# topAll{x}: whether every lane is negative when viewed as signed - the signed maximum is below zero.
def topAll{x:V if nvec{V}} = { def s = ty_s{x}; fold_max{s}<0 }
# homAny / homAll: forward to bitAny / bitAll for NEON vectors.
def homAny{x:V if nvec{V}} = bitAny{x}
def homAll{x:V if nvec{V}} = bitAll{x}
# homMask{x}, general single-vector form for unsigned vectors whose element width is at least the
# element count k: lane i is ANDed with 1<<i, the lanes are summed, and the sum is passed through
# truncBits{k, ...} (truncBits is defined elsewhere).
def homMask{x:V=[k]E if nvecu{V} and width{E}>=k} = {
  def weights = make{V, 1<<iota{k}}
  truncBits{k, fold_add{x & weights}}
}
# homMask{x} for [16]u8: sel interleaves the bytes as 0,8,1,9,..., so u16 lane j holds byte j in its
# low half and byte j+8 in its high half. Masking lane j with (1<<j)*0x0101 keeps bit j of each half,
# and summing the lanes gathers the 16 bits.
def homMask{x:T==[16]u8} = {
  def idx = make{[16]u8, 0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15}
  paired:= [8]u16~~sel{[16]u8, x, idx}
  fold_add{paired & make{[8]u16, (1<<iota{8})*0x0101}}
}
# homMask{a, b} for two [16]u8 vectors: each byte is masked with 1<<(index&7), then adjacent values
# are added three times (addp, then addpw twice; addp is defined elsewhere, presumably a pairwise add)
# to give [4]u32. Lane n is shifted left by n*8 and the lanes are summed.
def homMask{a:V,b:V==[16]u8} = {
  bits:= make{[16]u8, 1<<(iota{16}&7)}
  def sums = addpw{addpw{addp{a&bits, b&bits}}}
  fold_add{sums<<make{[4]u32,iota{4}*8}}
}
# homMask{a, b, c, d} for four [16]u8 vectors: each byte is masked with 1<<(index&7), then three
# rounds of addp (defined elsewhere, presumably a pairwise add) combine the inputs.
# The result is lane 0 of the final vector viewed as [2]u64.
def homMask{a:V,b:V,c:V,d:V==[16]u8} = {
  bits:= make{[16]u8, 1<<(iota{16}&7)}
  ab:= addp{a&bits, b&bits}
  cd:= addp{c&bits, d&bits}
  abcd:= addp{ab, cd}
  extract{[2]u64~~addp{abcd,abcd},0}
}
# homMask{...as} for two or more vectors with elements of at least 32 bits: adjacent arguments
# (0,1), (2,3), ... are merged with narrowPair, and homMask is applied to the narrowed vectors.
def homMask{...as={a0:[_]E, _, ..._} if width{E}>=32} = {
  def pairs = iota{length{as}/2}
  def narrowed = each{{i}=>narrowPair{select{as,i*2},select{as,i*2+1}}, pairs}
  homMask{...narrowed}
}
# homMask{a, b} for two vectors of k elements each at least 2*k bits wide: shrm merges
# a>>(width-k) with the bits of b that the shift vacated, so each lane carries a in its low k bits
# and b above them. Lane i then keeps bit i (from a) and bit i+k (from b); the lanes are summed and
# passed through truncBits{k*2, ...} (truncBits is defined elsewhere).
def homMask{a:V,b:V=[k]E if k*2<=width{E}} = {
  def merged = shrm{a,width{E}-k,b}
  def weights = make{V, (1<<iota{k}) | (1<<(iota{k}+k))}
  truncBits{k*2, fold_add{merged & weights}}
}
# andAllZero{x, y}: whether x&y has no bit set, for integer vectors.
def andAllZero{x:V, y:V if nveci{V}} = { def both = x&y; ~bitAny{both} }
# homMaskX{a}: returns tup{h, bits} with h = half the element width.
# a is viewed as el_m{V}, shifted right by h with narrowing (shrn), and lane 0 of the result viewed
# as [1]u64 is passed through truncBits{k*h, ...} (truncBits is defined elsewhere).
# Not available for u64 elements.
def homMaskX{a:V=[k]E if E!=u64} = {
  def h = width{E}/2
  def wide = el_m{V}~~a
  def packed = shrn{wide, h}
  tup{h, truncBits{k*h, extract{[1]u64~~packed, 0}}}
}
# homMaskStoreF{p, m, v}: masked store done as load / blend / store - loads the vector at p,
# blends it with v under mask m via homBlend{loaded, v, m}, and stores the result back at index 0.
# m must be an integer vector and v a NEON vector with the same element width as m.
def homMaskStoreF{p:*V, m:M, v:V if nveci{M} and nvec{V,elwidth{M}}} = {
  def old = load{p}
  store{p, 0, homBlend{old, v, m}}
}