Selection with permutevar8x32

This commit is contained in:
Marshall Lochbaum 2022-11-29 17:48:24 -05:00
parent c17448ed71
commit 41f464795f
3 changed files with 32 additions and 12 deletions

View File

@ -102,6 +102,7 @@ def ceil{a:[4]f64} = emit{[4]f64, '_mm256_ceil_pd', a}
# conversion
def half{x:T, i & w256{T} & knum{i}} = [vcount{T}/2](eltype{T}) ~~ emit{[8]i16, '_mm256_extracti128_si256', v2i{x}, i}
def half{x:T, i==0 & w256{T}} = [vcount{T}/2](eltype{T}) ~~ emit{[8]i16, '_mm256_castsi256_si128', v2i{x}}
def pair{a:T,b:T & width{T}==128} = [vcount{T}*2](eltype{T}) ~~ emit{[8]i32, '_mm256_setr_m128i', a, b}
def pair{x} = pair{tupsel{0,x},tupsel{1,x}}

View File

@ -107,8 +107,8 @@ def maskstore{a:T, m:M, n, v & w256{eltype{T}, 64} & w256i{M, 64}} = emit{void,
def maskstoreF{p, m, n, x:T} = store{p, n, blendF{load{p,n}, x, m}}
def maskstoreF{p, m, n, x:T & width{eltype{T}}>=32} = maskstore{p,m,n,x}
def shl{S==[16]u8, x:T, n & w256{T}} = T ~~ emit{T, '_mm256_bslli_epi128', x, n}
def shr{S==[16]u8, x:T, n & w256{T}} = T ~~ emit{T, '_mm256_bsrli_epi128', x, n}
def shl{S==[16]u8, x:T, n & w256{T} & knum{n}} = T ~~ emit{T, '_mm256_bslli_epi128', x, n}
def shr{S==[16]u8, x:T, n & w256{T} & knum{n}} = T ~~ emit{T, '_mm256_bsrli_epi128', x, n}
def blend{L==[8]u16, a:T, b:T, m & w256{T} & knum{m}} = T ~~ emit{[16]i16, '_mm256_blend_epi16', v2i{a}, v2i{b}, m}
def blend{L==[8]u32, a:T, b:T, m & w256{T} & knum{m}} = T ~~ emit{[ 8]i32, '_mm256_blend_epi32', v2i{a}, v2i{b}, m}
@ -186,4 +186,4 @@ def ucvt{T, x:X & w256{X} & width{T}==width{eltype{X}}} = to_el{T, x} # TODO che
def cvt2{T, x:X & T==i32 & X==[4]f64} = emit{[4]i32, '_mm256_cvtpd_epi32', x}
def cvt2{T, x:X & T==f64 & X==[4]i32} = emit{[4]f64, '_mm256_cvtepi32_pd', x}
def cvt2{T, x:X & T==f64 & X==[4]i32} = emit{[4]f64, '_mm256_cvtepi32_pd', x}

View File

@ -24,6 +24,13 @@ def wrapChk{cw0, VI,xlf, M} = {
cw
}
def storeExp{dst, ind, val, M, ext, rd, wl} = {
def s{M} = storeBatch{dst, ind, val, M}
if (ext==1 or not M{0}) s{M}
else if (ind*rd+rd <= wl) s{maskNone}
else { if (ind*rd < wl) s{maskAfter{wl & (rd-1)}}; return{1} }
}
def shuf_select{ri, rd, TI, w, r, wl, xl, selx} = {
def VI = [ri]TI
def ext = ri/rd
@ -39,18 +46,24 @@ def shuf_select{ri, rd, TI, w, r, wl, xl, selx} = {
2*o + iota{2}
}
}
def se{e==ext, c, o} = {
io:= is+o
got:= selx{c}
def s{M} = storeBatch{r, io, got, M}
if (ext==1 or not M{0}) s{M}
else if (io*rd+rd <= wl) s{maskNone}
else { if (io*rd < wl) s{maskAfter{wl & (rd-1)}}; return{1} }
}
def se{e==ext, c, o} = storeExp{r, is+o, selx{c}, M, ext, rd, wl}
se{1, cw, 0}
}}
}
def perm_select{ri, rd, TI, w, r, wl, xl, selx} = {
def VI = [ri]TI
def ext = ri/rd
xlf:= broadcast{VI, cast_i{TI, xl}}
maskedLoop{ri, wl, {i, M} => {
cw:= wrapChk{loadBatch{w, i, VI}, VI,xlf, M}
is:= (if (ext>1) i<<lb{ext}; else i)
def part{o} = cvt{i8, [8]i32, shuf{[4]u64, cw, 4b3210+o}}
def se{o} = storeExp{r, is+o, selx{part{o}}, M, ext, rd, wl}
each{se, iota{ext}}
}}
}
def makesel{VI,VD, x0,logv} = {
x:= *VD~~x0
def halves{v} = each{bind{shuf, [4]u64, v}, tup{4b1010, 4b3232}}
@ -83,7 +96,13 @@ select{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = {
def shuf_select{l} = {
shuf_select{ri, rd, TI, w, r, wl, xl, makesel{[ri]TI,[rd]TD, x,l}}
}
if (wi==8 and wd<=32 and xl*wd<=128 ) { shuf_select{0} }
def perm_select{} = {
def VD = [rd]TD
xd:= load{*VD~~x}
perm_select{ri, rd, TI, w, r, wl, xl, {c}=>sel{VD, xd, c}}
}
if (wi==8 and wd==32 and xl*wd<=256 ) { perm_select{ } }
else if (wi==8 and wd<=16 and xl*wd<=128 ) { shuf_select{0} }
else if (wi==8 and wd<=16 and xl*wd<=128<<1) { shuf_select{1} }
else if (wi==8 and wd<=16 and xl*wd<=128<<2) { shuf_select{2} }
else if (wi==8 and wd<= 8 and xl*wd<=128<<3) { shuf_select{3} }