faster squeeze
This commit is contained in:
parent
d66be091be
commit
40a5c40bda
@ -4,6 +4,7 @@ def v2f{x:T if w256{T}} = [8]f32 ~~ x
|
||||
def v2d{x:T if w256{T}} = [4]f64 ~~ x
|
||||
|
||||
def undefPromote{T=[_]E, x:X=[_]E if w128{X} and w256{T}} = T~~emit{[32]u8, '_mm256_castsi128_si256', v2i{x}}
|
||||
def zeroPromote{T=[_]E, x:X=[_]E if w128{X} and w256{T}} = T~~emit{[32]u8, '_mm256_zextsi128_si256', v2i{x}}
|
||||
|
||||
# load & store
|
||||
def loadLow{ptr:*V, w if w256{V} and w<=128} = undefPromote{V, loadLow{*n_h{V} ~~ ptr, w}}
|
||||
|
||||
@ -24,7 +24,8 @@ def isptr {T} = istype{T} and same{typekind{T},'pointer'}
|
||||
def elwidth{T} = width{eltype{T}}
|
||||
|
||||
oper &~ andnot infix none 35
|
||||
def andnot{a, b if anyNum{a} and anyNum{b}} = a & ~b
|
||||
def andnot{a, b:T if anyNum{a} and isprim{T}} = a & ~b
|
||||
def andnot{a:T, b if isprim{T} and knum{b}} = a & ~T~~b
|
||||
|
||||
oper &- ({v:T,m:(u1)} => v & -promote{T,m}) infix left 35
|
||||
|
||||
@ -159,10 +160,10 @@ def lvec = match { {[n]T, n, (width{T})} => 1; {T, n, w} => 0 }
|
||||
# base cases
|
||||
def {
|
||||
absu,andAllZero,andnz,b_getBatch,blend,clmul,cvt,extract,fold_addw,half,
|
||||
homAll,homAny,homBlend,homMask,homMaskStore,homMaskStoreF,loadBatchBit,
|
||||
loadLow,make,maskStore,maskToHom,mulw,mulh,narrow,narrowPair,packHi,packLo,packQ,pair,pdep,
|
||||
homAll,homAny,bitAll,bitAny,homBlend,homMask,homMaskStore,homMaskStoreF,loadBatchBit,
|
||||
loadLow,make,maskStore,maskToHom,mulw,mulh,narrow,narrowTrunc,narrowPair,packHi,packLo,packQ,pair,pdep,
|
||||
pext,popcRand,sel,shl,shr,shuf,shuf16Hi,shuf16Lo,shufHalves,shufInd,storeLow,
|
||||
topBlend,topMask,topMaskStore,topMaskStoreF,unord,vfold,widen,
|
||||
topBlend,topMask,topMaskStore,topMaskStoreF,unord,vfold,widen,widenUpper,
|
||||
zipHi,zipLo
|
||||
}
|
||||
|
||||
@ -210,6 +211,7 @@ def pair{{a, b}} = pair{a, b}
|
||||
def widen{T, x:T} = x
|
||||
def narrow{T, x:[_]T} = x
|
||||
def undefPromote{T, x:T} = x
|
||||
def zeroPromote{T, x:T} = x
|
||||
def cvt{T, x:[_]T} = x
|
||||
|
||||
def broadcast{V=[_]T, v} = vec_broadcast{V, promote{T,v}}
|
||||
|
||||
@ -64,6 +64,9 @@ def loadBatch{ptr:*E0, i, [k]E1} = {
|
||||
def loadBatch {ptr:*E, {...ns}, T } = each{loadBatch {ptr, ., T }, ns}
|
||||
def storeBatch{ptr:*E, {...ns}, xs, M} = each{storeBatch{ptr, ., ., M}, ns, xs}
|
||||
|
||||
# TODO also similar homAny & use those more
|
||||
def homAll{(maskNone), ...xs} = homAll{...xs}
|
||||
def homAll{M, x:T if kgen{M}} = ~homAny{M{~x}} # TODO better
|
||||
|
||||
# "harmless" pointer cast that'll only cast void*
|
||||
def hCast{T,p} = assert{0, 'expected pointer with element',T,'or void but got ',p}
|
||||
|
||||
@ -73,10 +73,14 @@ def vshl{a:T, b:T, n if knum{n}} = emit{T, ntyp{'vext', T}, a, b, n}
|
||||
def zipLo{a:T, b:T if nvec{T}} = emit{T, ntyp{'vzip1', T}, a, b}
|
||||
def zipHi{a:T, b:T if nvec{T}} = emit{T, ntyp{'vzip2', T}, a, b}
|
||||
|
||||
def packLo{x:T, y:T if nvec{T}} = { def H=el_s{T}; emit{H, ntyp{'vuzp1', H}, H~~x, H~~y} }
|
||||
def packHi{x:T, y:T if nvec{T}} = { def H=el_s{T}; emit{H, ntyp{'vuzp2', H}, H~~x, H~~y} }
|
||||
def unzipLo{x:T, y:T if nvec{T}} = emit{T, ntyp{'vuzp1', T}, T~~x, T~~y}
|
||||
def unzipHi{x:T, y:T if nvec{T}} = emit{T, ntyp{'vuzp2', T}, T~~x, T~~y}
|
||||
def packLo{x:T, y:T if nvec{T}} = unzipLo{el_s{T}~~x, el_s{T}~~y}
|
||||
def packHi{x:T, y:T if nvec{T}} = unzipHi{el_s{T}~~x, el_s{T}~~y}
|
||||
def packLo{{x, y}} = packLo{x, y}
|
||||
def packHi{{x, y}} = packHi{x, y}
|
||||
def shufInd{x:T, y:T, {...is} if nvec{T,32} and same{is, 2*range{vcount{T}}}} = T~~unzipLo{x,y}
|
||||
def shufInd{x:T, y:T, {...is} if nvec{T,32} and same{is, 1+2*range{vcount{T}}}} = T~~unzipHi{x,y}
|
||||
|
||||
def trn1{x:T, y:T if nvec{T}} = emit{T, ntyp{'vtrn1', T}, x, y}
|
||||
def trn2{x:T, y:T if nvec{T}} = emit{T, ntyp{'vtrn2', T}, x, y}
|
||||
@ -85,7 +89,7 @@ def sel{L, x:T, i:I if lvec{L,16,8} and w128{T} and nvec{I, 8}} = re_el{eltype{T
|
||||
|
||||
|
||||
|
||||
local def eqqi{A, B} = isint{A} & (quality{A}==quality{B}) # equal quality integers
|
||||
local def eqqi{A, B} = isint{A} and isint{B} and quality{A}==quality{B} # equal quality integers
|
||||
|
||||
def cvt{T==f64, x:X=[k]_ if nveci{X,64}} = emit{[k]T, ntyp{'vcvt', '_f64', X}, x}
|
||||
def cvt{T==i64, x:X=[k]_ if nvecf{X,64}} = emit{[k]T, ntyp{'vcvt', '_s64', X}, x}
|
||||
@ -96,9 +100,10 @@ def widen{R=[_]RE, x:X=[_]XE if w64{X} and eqqi{RE,XE} and width{RE}> width{XE}*
|
||||
def widen{R=[rn]RE, x:X=[xn]XE if w64{X} and isfloat{RE}!=isfloat{XE} and width{RE}>width{XE}} = cvt{RE, widen{[rn]to_w{XE,width{RE}}, x}}
|
||||
def widen{R=[rn]RE, x:X=[xn]XE if w128{X} and xn>rn} = widen{R, half{x,0}}
|
||||
|
||||
def narrow{T, x:X=[_]E if w128{X} and eqqi{T,E} and width{T}*2< width{E}} = narrow{T, undefPromote{el_s{X}, narrow{w_h{E}, x}}}
|
||||
def narrow{T, x:X=[_]E if w128{X} and eqqi{T,E} and width{T}*2==width{E}} = emit{el_h{X}, ntyp0{'vmovn', X}, x}
|
||||
def narrow{T, x:X=[_]E if w128{X} and isfloat{T}!=isfloat{E} and width{T}<width{E}} = narrow{T, cvt{to_w{T, width{E}}, x}}
|
||||
def narrow {T, x:X=[_]E if w128{X} and eqqi{T,E} and width{T}*2< width{E}} = narrow{T, undefPromote{el_s{X}, narrow{w_h{E}, x}}}
|
||||
def narrowTrunc{T, x:X=[_]E if w128{X} and eqqi{T,E} and width{T}*2==width{E}} = emit{el_h{X}, ntyp0{'vmovn', X}, x}
|
||||
def narrow {T, x:X=[_]E if w128{X} and eqqi{T,E} and width{T}*2==width{E}} = narrowTrunc{T, x}
|
||||
def narrow {T, x:X=[_]E if w128{X} and isfloat{T}!=isfloat{E} and width{T}<width{E}} = narrow{T, cvt{to_w{T, width{E}}, x}}
|
||||
|
||||
def narrowUpper{lowRes:L=[k]E, x:X if w64i{L} and w128{X} and el_d{L}==X} = emit{[k*2]E, ntyp0{'vmovn_high', X}, lowRes, x}
|
||||
def narrowPair{a:T=[_]E, b:T} = narrowUpper{narrow{w_h{E}, a}, b}
|
||||
|
||||
@ -1,136 +1,130 @@
|
||||
include './debug'
|
||||
include './base'
|
||||
include './mask'
|
||||
include './cbqnDefs'
|
||||
include 'util/tup'
|
||||
include './vecfold'
|
||||
|
||||
def preserve_negative_zero = 0
|
||||
def is_sNaN{x:[_]u64} = inRangeLen{x<<1, (0xFFE<<52)+2, (1<<52)-2}
|
||||
def is_sNaN{x:[2]u64 if hasarch{'X86_64'} and not hasarch{'SSE4.2'}} = { # avoiding i64 comparisons
|
||||
def nan = unord{[2]f64~~x, [2]f64~~x}
|
||||
def qnan = [2]u64~~([4]u32**0xFFF8_0000 == ([4]u32~~x | [4]u32**0x8000_0000))
|
||||
nan &~ qnan
|
||||
}
|
||||
def any_sNaN{M, ...xs} = homAny{M{tree_fold{|, each{is_sNaN, xs}}}}
|
||||
|
||||
# SSE2 versions avoid any 64-bit integer comparsions
|
||||
def anySNaN{M, x:[_](u64)} = {
|
||||
homAny{inRangeLen{M{x}<<1, (0xFFE<<52)+2, (1<<52)-2}}
|
||||
}
|
||||
def anySNaN{M, x:T==[2]u64 if hasarch{'X86_64'} and not hasarch{'SSE4.2'}} = {
|
||||
topAny{M{andnot{unord{[2]f64~~x, [2]f64~~x}, [2]u64~~([4]u32**0xFFF8_0000 == ([4]u32~~x | [4]u32**0x8000_0000))}}}
|
||||
}
|
||||
def anyNonChar{M, x:[_](u64)} = homAny{M{~inRangeLen{x, cbqn_c32Tag{}<<48, 1<<48}}}
|
||||
def anyNonChar{M, x:T=[_]_ if hasarch{'X86_64'}} = {
|
||||
def any_nonC32{M, x:[_](u64)} = homAny{M{~inRangeLen{x, cbqn_c32Tag{}<<48, 1<<48}}}
|
||||
def any_nonC32{M, x:T=[_]_ if hasarch{'X86_64'}} = {
|
||||
def H = re_el{u32, T}
|
||||
def ne = H~~x != H**cast_i{u32, cbqn_c32Tag{}<<16}
|
||||
topAny{M{T~~ne}}
|
||||
}
|
||||
def any_nonC32{(maskNone), x:[k]u64, y:[k]u64} = {
|
||||
def T32 = [k*2]u32
|
||||
def hi = shufInd{T32~~x, T32~~y, match(k) {
|
||||
{2} => tup{1,3,5,7}
|
||||
{4} => tup{1,3,9,11,5,7,13,15} # all odd indices, in the order that vshufps can handle
|
||||
}}
|
||||
anyne{hi, T32**cast_i{u32, cbqn_c32Tag{}<<16}, maskNone}
|
||||
}
|
||||
|
||||
|
||||
def cvtNarrow{DE, x:[_]XE if width{DE}==width{XE}} = cvt{DE, x}
|
||||
def cvtNarrow{DE, x:[_]XE if width{DE}< width{XE}} = narrow{DE, x}
|
||||
def cvtWiden{ [_]DE, x:[_]XE if width{DE}==width{XE}} = cvt{DE, x}
|
||||
def cvtWiden{D=[_]DE, x:[_]XE if width{DE}> width{XE}} = widen{D, x}
|
||||
|
||||
fn squeeze{vw, X, CHR, B}(x0:*void, len:ux) : u32 = {
|
||||
assert{len>0}
|
||||
fn squeeze{vw, X, CHR, B if CHR or X==i32 or X==i16 or X==i8 or X==f64}(x0:*void, len:ux) : u32 = {
|
||||
def bulk = vw / width{X}
|
||||
def XV = [bulk]X
|
||||
def E = tern{X==f64, u32, ty_u{X}}
|
||||
def EV2 = [bulk*2]E
|
||||
def EV = tern{(width{E}*bulk == 64) & hasarch{'X86_64'}, EV2, [bulk]E}
|
||||
def xb = tup{XV,*X~~x0}
|
||||
|
||||
# fold with either Max or Bitwise Or, truncating/zero-extending to TE
|
||||
def foldTotal{TE, x:[_]T} = cast_i{TE, vfold{|, x}}
|
||||
def foldTotal{TE, x:[_]T if hasarch{'AARCH64'}} = {
|
||||
# fold with either Max or Bitwise Or, truncating/zero-extending elements to TE
|
||||
def fold_total{TE, x:[_]T} = cast_i{TE, vfold{|, x}}
|
||||
def fold_total{TE, x:[_]T if hasarch{'AARCH64'}} = {
|
||||
if (width{T}!=64) vfold{max, x}
|
||||
else if (width{TE}==64 and bulk==2) cast_i{TE, half{x,0} | half{x,1}}
|
||||
else vfold{max, narrow{TE, x}}
|
||||
}
|
||||
def int_acc{T} = {
|
||||
minv:= T**0
|
||||
maxv:= T**0
|
||||
def me{M, minc, maxc} = {
|
||||
minv = min{minv, M{minc}}
|
||||
maxv = max{maxv, M{maxc}}
|
||||
}
|
||||
def me{} = {
|
||||
mint:= ty_u{vfold{min, minv}}
|
||||
maxt:= ty_u{vfold{max, maxv}} &~ 1
|
||||
cast_i{u32, tern{mint==0, maxt, max{maxt, -mint-1} | 2}}
|
||||
}
|
||||
def me{M, vs} = {
|
||||
minc:= zeroPromote{T, tree_fold{min, vs}} # could pack pairs in v to low & high halves, but an extra min costs the same or less than an insert
|
||||
maxc:= zeroPromote{T, tree_fold{max, vs}}
|
||||
me{M, minc, maxc}
|
||||
}
|
||||
}
|
||||
|
||||
# show{XV, EV, CHR, B}
|
||||
xp:= *X~~x0
|
||||
r1:= EV**0
|
||||
if (CHR) { # c8, c16, c32
|
||||
def hw = width{E}/2
|
||||
@maskedLoop{bulk}(xv in tup{XV,xp}, M in 'm' over len) {
|
||||
c:= EV~~xv
|
||||
if (X!=u16) r1|= M{c} # for u64, just accept the garbage top 32 bits and deal with them at the end
|
||||
if (CHR) { # c16/c32/B → char
|
||||
mt:= XV**0
|
||||
@muLoop{bulk, 2}(xs in xb, M in 'm' over len) {
|
||||
def orx = M{tree_fold{|, xs}}
|
||||
if (B) {
|
||||
if (anyNonChar{M, c}) return{3}
|
||||
if (any_nonC32{M, ...xs}) return{3}
|
||||
} else {
|
||||
if (anynePositive{EV**((1<<hw-1)<<hw) & c, EV**0, M}) return{lb{hw}-2}
|
||||
def bad = if (hasarch{'AARCH64'}) bitAny{if (length{xs}==2) packHi{...xs} else packHi{orx,orx}}
|
||||
else ~andAllZero{orx, ~XV**maxvalue{w_h{X}}}
|
||||
if (bad) return{lb{width{X}}-3}
|
||||
}
|
||||
mt|= orx
|
||||
}
|
||||
r2:= foldTotal{u32, r1}
|
||||
if (X>u32 and r2>=65536) return{2}
|
||||
if (X>u16 and r2>=256) return{1}
|
||||
def tot_max = fold_total{u32, if (B) mt & XV**32w2b1 else mt}
|
||||
# lprintf{tup{'x0', XV & make{XV, cycle{vcount{MT}, tup{32w0xf, 0}}}}, tot_max}
|
||||
if (X>u32 and tot_max>=65536) return{2}
|
||||
if (X>u16 and tot_max>=256) return{1}
|
||||
0
|
||||
} else { # i8, i16, i32, f64
|
||||
if (X==i8) { # i8
|
||||
@maskedLoop{bulk}(v0 in tup{XV,xp}, M in 'm' over len) {
|
||||
if (anynePositive{EV**0xfe & EV~~v0, EV**0, M}) return{2}
|
||||
}
|
||||
0
|
||||
} else { # i16, i32, f64
|
||||
def case_B = makeOptBranch{B, tup{ux}, {iCont} => {
|
||||
def XU = [bulk]u64
|
||||
@maskedLoop{bulk, iCont}(xv in tup{XV,xp}, M in 'm' over len) {
|
||||
v:= XU ~~ xv
|
||||
if (anySNaN{M, v}) return{0xffff_fffe} # not even a number
|
||||
}
|
||||
return{0xffff_ffff} # float
|
||||
}}
|
||||
|
||||
def getAcc{EV=[_]E, x:[_]T} = {
|
||||
((EV ** ~E~~1) & EV~~x) ^ EV~~(x >> (width{T}-1))
|
||||
}
|
||||
|
||||
if (isint{X}) { # i16, i32
|
||||
@muLoop{bulk, 1}(v0 in tup{XV,xp}, M in 'm' over len) {
|
||||
r1|= M{tree_fold{|, each{{v} => getAcc{EV, v}, v0}}}
|
||||
}
|
||||
} else { # f64
|
||||
r2:= EV2**0
|
||||
@muLoop{
|
||||
bulk, hasarch{'AARCH64'}+1,
|
||||
{} => { r1 = half{r2,0}|half{r2,1} }
|
||||
}(v0 in tup{XV,xp}, M in 'm' over is to len) {
|
||||
def int = {
|
||||
def {int, wdn} = {
|
||||
if (hasarch{'AARCH64'} and length{is}==2) {
|
||||
def intp = narrowPair{...each{cvt{i64,.}, v0}}
|
||||
def wdn = each{cvt{f64,.}, widen{intp}}
|
||||
tup{intp, wdn}
|
||||
} else {
|
||||
def ints = each{{v} => cvtNarrow{ty_s{E}, v}, v0}
|
||||
def wdn = each{{v} => cvtWiden{XV, v}, ints}
|
||||
def intp = match (...ints) {
|
||||
{i:[(bulk)]_} => i
|
||||
{i if hasarch{'X86_64'} and not hasarch{'AVX2'}} => i
|
||||
{i:T=[(bulk)]_, j:T} => pair{ints}
|
||||
}
|
||||
tup{intp, wdn}
|
||||
}
|
||||
}
|
||||
|
||||
def conv{x} = tern{preserve_negative_zero, ty_u{x}, x}
|
||||
def as = each{conv, v0}
|
||||
def bs = each{conv, wdn}
|
||||
def cond = {
|
||||
if (length{is}==1) anynePositive{...as, ...bs, M}
|
||||
else ~homAll{tree_fold{&, each{==, as, bs}}}
|
||||
}
|
||||
if (cond) { # is any not an integer
|
||||
if (B) case_B{select{is, 0}} # if B, need to give an even more special result
|
||||
else return{0xffff_ffff} # else, not integer => float
|
||||
}
|
||||
int
|
||||
}
|
||||
def acc = if (length{is}==2) r2 else r1
|
||||
|
||||
acc|= M{getAcc{type{acc}, int}}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def f = foldTotal{E, r1}
|
||||
cast_i{u32, f}
|
||||
} else if (X==i32 or X==i16) {
|
||||
# TODO simpler path for len≤unr×bulk?
|
||||
# TODO aarch64 path?
|
||||
# TODO SSE2 i32 is extremely slow due to lack of min/max
|
||||
def EH = w_h{X}
|
||||
def acc = int_acc{XV}
|
||||
@muLoop{bulk, 4}(xs in xb, M in 'm' over len) {
|
||||
minc:= tree_fold{min, xs}
|
||||
maxc:= tree_fold{max, xs}
|
||||
if (homAny{M{(minc < XV**minvalue{EH}) | (maxc > XV**maxvalue{EH})}}) return{0xffff_ffff}
|
||||
acc{M, minc, maxc}
|
||||
}
|
||||
acc{}
|
||||
} else if (X==i8) {
|
||||
@muLoop{bulk, 2}(xs in xb, M in 'm' over len) {
|
||||
if (~andAllZero{M{tree_fold{|, xs}}, XV ** -2}) return{2}
|
||||
}
|
||||
0
|
||||
} else if (X==f64) {
|
||||
def case_B = makeBranch{tup{ux}, {bulkCont} => {
|
||||
def i0 = bulkCont*bulk
|
||||
x:= i0 + *u64~~x0
|
||||
if (B) @muLoop{bulk, 2}(xs in tup{[bulk]u64, x}, M in 'm' over len-i0) {
|
||||
if (any_sNaN{M, ...xs}) return{0xffff_fffe} # not even a number
|
||||
}
|
||||
return{0xffff_ffff} # float
|
||||
}}
|
||||
|
||||
def acc = int_acc{re_el{i32, XV}}
|
||||
@muLoop{bulk, 2}(xs in xb, M in 'm' over is to len) {
|
||||
if (hasarch{'X86_64'}) {
|
||||
def ns = each{narrow{i32,.}, xs}
|
||||
if (homAny{M{tree_fold{|, each{{ns,x} => widen{XV,ns}!=x, ns, xs}}}}) case_B{select{is, 0}}
|
||||
acc{M, ns}
|
||||
} else {
|
||||
assert{hasarch{'AARCH64'}}
|
||||
def k = length{xs}
|
||||
def int = each{cvt{i64,.}, xs}
|
||||
def int32 = if (k==2) packLo{...int} else packLo{...int, [bulk]i64**0}
|
||||
def wd = each{{G} => cvt{f64,G{int32}}, slice{tup{widen{[bulk]i64,.}, widenUpper}, 0, k}}
|
||||
|
||||
if (~homAll{M, tree_fold{&, each{==, wd, xs}}}) case_B{select{is, 0}}
|
||||
acc{M, tup{int32}}
|
||||
}
|
||||
}
|
||||
acc{}
|
||||
} else {
|
||||
assert{0}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user