use partial application

This commit is contained in:
dzaima 2023-07-22 18:19:31 +03:00
parent 033f3de6b9
commit 959614c785
17 changed files with 43 additions and 44 deletions

View File

@ -34,7 +34,7 @@ def broadcast{T, v & w256i{T, 64}} = emit{T, '_mm256_set1_epi64x',promote{eltype
def broadcast{T, v & w256f{T, 64}} = emit{T, '_mm256_set1_pd', v} def broadcast{T, v & w256f{T, 64}} = emit{T, '_mm256_set1_pd', v}
def broadcast{T, v & w256f{T, 32}} = emit{T, '_mm256_set1_ps', v} def broadcast{T, v & w256f{T, 32}} = emit{T, '_mm256_set1_ps', v}
local def makeGen{T,s,x} = emit{T, s, ...each{{c}=>promote{eltype{T},c}, x}} local def makeGen{T,s,x} = emit{T, s, ...each{promote{eltype{T}, .}, x}}
def make{T, ...xs & w256f{T,64} & tuplen{xs}== 4} = makeGen{T, '_mm256_setr_pd', xs} def make{T, ...xs & w256f{T,64} & tuplen{xs}== 4} = makeGen{T, '_mm256_setr_pd', xs}
def make{T, ...xs & w256f{T,32} & tuplen{xs}== 8} = makeGen{T, '_mm256_setr_ps', xs} def make{T, ...xs & w256f{T,32} & tuplen{xs}== 8} = makeGen{T, '_mm256_setr_ps', xs}
def make{T, ...xs & w256i{T,64} & tuplen{xs}== 4} = makeGen{T, '_mm256_setr_epi64x', xs} def make{T, ...xs & w256i{T,64} & tuplen{xs}== 4} = makeGen{T, '_mm256_setr_epi64x', xs}

View File

@ -31,7 +31,7 @@ def loadu{p:T & elwidth{T}==8} = load{p}
def storeu{p:T, v:eltype{T} & elwidth{T}==8} = store{p, v} def storeu{p:T, v:eltype{T} & elwidth{T}==8} = store{p, v}
def reinterpret{T, x:X & T==X} = x def reinterpret{T, x:X & T==X} = x
def exportN{f, ...ns} = each{{n} => export{n, f}, ns} def exportN{f, ...ns} = each{export{.,f}, ns}
def exportT{name, fs} = { v:*type{tupsel{0,fs}} = fs; export{name, v} } def exportT{name, fs} = { v:*type{tupsel{0,fs}} = fs; export{name, v} }
@ -220,7 +220,7 @@ def inRangeExcl{x:T, start, end} = inRangeLen{x, start, end-start} # ∊ [start;
def broadcast{T, v & isprim{T}} = v def broadcast{T, v & isprim{T}} = v
def iota{n & knum{n}} = range{n} def iota{n & knum{n}} = range{n}
def collect{vars,begin,end,iter & knum{begin} & knum{end}} = { def collect{vars,begin,end,iter & knum{begin} & knum{end}} = {
each{{i} => iter{i, vars}, range{end-begin}+begin} each{iter{., vars}, range{end-begin}+begin}
} }
def broadcast{n, v & knum{n}} = each{{_}=>v, range{n}} def broadcast{n, v & knum{n}} = each{{_}=>v, range{n}}

View File

@ -65,16 +65,15 @@ fn max_scan{T, up}(x:*T, len:u64) : void = {
def getsel{...x} = assert{'shuffling not supported', show{...x}} def getsel{...x} = assert{'shuffling not supported', show{...x}}
if (hasarch{'AVX2'}) { if (hasarch{'AVX2'}) {
def getsel{h:H & lvec{H, 16, 8}} = { def getsel{h:H & lvec{H, 16, 8}} = {
v := pair{h,h} sel{H, pair{h,h}, .}
{i} => sel{H, v, i}
} }
def getsel{v:V & lvec{V, 32, 8}} = { def getsel{v:V & lvec{V, 32, 8}} = {
def H = n_h{V} def H = n_h{V}
vtop := V**(vcount{V}/2) vtop := V**(vcount{V}/2)
hs := each{bind{shuf, [4]u64, v}, tup{4b3232, 4b1010}} hs := each{shuf{[4]u64, v, .}, tup{4b3232, 4b1010}}
{i} => homBlend{...each{{h}=>sel{H,h,i}, hs}, V~~i<vtop} {i} => homBlend{...each{sel{H,.,i}, hs}, V~~i<vtop}
} }
def getsel{v:V & lvec{V, 8, 32}} = { {i} => sel{V, v, i} } def getsel{v:V & lvec{V, 8, 32}} = sel{V, v, .}
} }
# Move evens to half 0 and odds to half 1 # Move evens to half 0 and odds to half 1
@ -176,11 +175,11 @@ def bins_vectab_i8{up, w, wn, x, xn, rp, t0, t, done & hasarch{'AVX2'}} = {
# We'll subtract 1 when indexing so the initial 0 isn't needed # We'll subtract 1 when indexing so the initial 0 isn't needed
tui:*i8 = copy{maxu, 0}; i:T = 0 tui:*i8 = copy{maxu, 0}; i:T = 0
@for (tui over promote{u64,nu}) { i = load{t, load{w, i}}; tui = i } @for (tui over promote{u64,nu}) { i = load{t, load{w, i}}; tui = i }
def tv = bind{load, *V~~tui} def tv = load{*V~~tui, .}
ui = tv{0} ui = tv{0}
if (nu > 16) ui1 = shuf{[4]u64, ui, 4b3232} if (nu > 16) ui1 = shuf{[4]u64, ui, 4b3232}
ui = shuf{[4]u64, ui, 4b1010} ui = shuf{[4]u64, ui, 4b1010}
if (nu > vl) ui2 = each{bind{shuf, [4]u64, tv{1}}, tup{4b1010, 4b3232}} if (nu > vl) ui2 = each{shuf{[4]u64, tv{1}, .}, tup{4b1010, 4b3232}}
} }
# Popcount on 8-bit values # Popcount on 8-bit values
def sums{n} = if (n==1) tup{0} else { def s=sums{n/2}; merge{s,s+1} } def sums{n} = if (n==1) tup{0} else { def s=sums{n/2}; merge{s,s+1} }
@ -229,7 +228,7 @@ def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
def wd = width{T} def wd = width{T}
def I = if (wd<32) u8 else u32; def wi = width{I} def I = if (wd<32) u8 else u32; def wi = width{I}
def lanes = hasarch{'AVX2'} & (I==u8) def lanes = hasarch{'AVX2'} & (I==u8)
def isub = wd/wi; def bb = bind{base,1<<wi} def isub = wd/wi; def bb = base{1<<wi, .}
def vl = 256/wd; def svl = vl>>lanes def vl = 256/wd; def svl = vl>>lanes
def V = [vl]T def V = [vl]T
def U = [vl](ty_u{T}) def U = [vl](ty_u{T})
@ -250,7 +249,7 @@ def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
if (ex>=1 and wn >= svl) { if (ex>=1 and wn >= svl) {
--gap # Allows subtracting < instead of adding <= --gap # Allows subtracting < instead of adding <=
def un = uninterleave def un = uninterleave
def tr_half{a, b} = each{bind{shufHalves,a,b}, tup{16b20, 16b31}} def tr_half{a, b} = each{shufHalves{a,b,.}, tup{16b20, 16b31}}
def un{{a,b}} = tr_half{un{a},un{b}} def un{{a,b}} = tr_half{un{a},un{b}}
if (not lanes) tupsel{1,wv} = load{wg, 1} if (not lanes) tupsel{1,wv} = load{wg, 1}
wv = un{wv} wv = un{wv}
@ -313,7 +312,7 @@ def bin_search_branchless{up, w, wn, x, n, res, rtype} = {
l0 := wn + 1 l0 := wn + 1
# Take a list of indices in x/res to allow unrolling # Take a list of indices in x/res to allow unrolling
def search{inds} = { def search{inds} = {
xs:= each{bind{load,x}, inds} # Values xs:= each{load{x,.}, inds} # Values
ss:= each{{_}=>ws, inds} # Initial lower bound ss:= each{{_}=>ws, inds} # Initial lower bound
l := l0; h := undefined{u64} # Interval size l, same for all values l := l0; h := undefined{u64} # Interval size l, same for all values
while ((h=l/2) > 0) { while ((h=l/2) > 0) {
@ -380,6 +379,6 @@ exportT{
'si_saturate', 'si_saturate',
each{{a}=>saturate{...a}, merge{ each{{a}=>saturate{...a}, merge{
tup{tup{i16,i8}, tup{i32,i8}, tup{i32,i16}}, tup{tup{i16,i8}, tup{i32,i8}, tup{i32,i16}},
join{table{bind{tup,f64}, tup{i8,i16,i32}, tup{1,0}}} join{table{tup{f64, ...}, tup{i8,i16,i32}, tup{1,0}}}
}} }}
} }

View File

@ -99,5 +99,5 @@ def loadBatchBit{T, x:*u64, is & ktup{is}} = {
# assert{count*len <= 64} # assert{count*len <= 64}
# bits:= b_getBatchLo{count*len, x, tupsel{0,is}} # bits:= b_getBatchLo{count*len, x, tupsel{0,is}}
# @collect(i to len) spreadBits{T, truncBits{count, bits>>(i*count)}} # @collect(i to len) spreadBits{T, truncBits{count, bits>>(i*count)}}
each{{i} => loadBatchBit{T, x, i}, is} each{loadBatchBit{T, x, .}, is}
} }

View File

@ -26,6 +26,6 @@ fn bitsel_i{VL,T}(r:*void, bits:*u64, e0:u64, e1:u64, len:u64) : void = {
bitsel{VL, T, *T~~r, bits, trunc{T,e0}, trunc{T,e1}, len} bitsel{VL, T, *T~~r, bits, trunc{T,e0}, trunc{T,e1}, len}
} }
def table{w} = each{{T} => bitsel_i{w, T}, tup{u8, u16, u32, u64}} def table{w} = each{bitsel_i{w, .}, tup{u8, u16, u32, u64}}
exportT{'simd_bitsel', table{arch_defvw}} exportT{'simd_bitsel', table{arch_defvw}}

View File

@ -90,13 +90,13 @@ def any2bit{VT, unr, op, wS, wV, xS, xV, dst:*u64, len:(Size)} = {
fn aa2bit{VT, unr, op}(dst:*u64, wr:*void, xr:*void, len:Size) : void = { fn aa2bit{VT, unr, op}(dst:*u64, wr:*void, xr:*void, len:Size) : void = {
wv:= *VT~~wr; ws:= *eltype{VT}~~wr wv:= *VT~~wr; ws:= *eltype{VT}~~wr
xv:= *VT~~xr; xs:= *eltype{VT}~~xr xv:= *VT~~xr; xs:= *eltype{VT}~~xr
any2bit{VT, unr, op, {i}=>load{ws,i}, {i}=>load{wv,i}, {i}=>load{xs,i}, {i}=>load{xv,i}, dst, len} any2bit{VT, unr, op, load{ws,.}, load{wv,.}, load{xs,.}, load{xv,.}, dst, len}
} }
fn as2bit{VT, unr, op}(dst:*u64, wr:*void, x:u64, len:Size) : void = { fn as2bit{VT, unr, op}(dst:*u64, wr:*void, x:u64, len:Size) : void = {
wv:= *VT~~wr; ws:= *eltype{VT}~~wr wv:= *VT~~wr; ws:= *eltype{VT}~~wr
xv:= VT**pathAS{dst, len, eltype{VT}, op, x} xv:= VT**pathAS{dst, len, eltype{VT}, op, x}
any2bit{VT, unr, op, {i}=>load{ws,i}, {i}=>load{wv,i}, {i}=>x, {i}=>xv, dst, len} any2bit{VT, unr, op, load{ws,.}, load{wv,.}, {i}=>x, {i}=>xv, dst, len}
} }
fn bitAA{bitop}(dst:*u64, wr:*void, xr:*void, len:Size) : void = { fn bitAA{bitop}(dst:*u64, wr:*void, xr:*void, len:Size) : void = {

View File

@ -39,7 +39,7 @@ rcsh_data:*i8 = join{join{each{get_shuf_data, rcsh_vals}}}
# first 4 shuffle vectors for 11≤𝕨≤61; only uses the low half of the input # first 4 shuffle vectors for 11≤𝕨≤61; only uses the low half of the input
def rcsh4_dom = replicate{bind{>=,64}, replicate{fact_tab==1, fact_inds}} def rcsh4_dom = replicate{bind{>=,64}, replicate{fact_tab==1, fact_inds}}
rcsh4_dat:*i8 = join{join{each{{wv}=>get_shuf_data{wv, 4}, rcsh4_dom}}} rcsh4_dat:*i8 = join{join{each{get_shuf_data{., 4}, rcsh4_dom}}}
rcsh4_lkup:*i8 = shiftright{0, scan{+, fold{|, table{==, rcsh4_dom, iota{64}}}}} rcsh4_lkup:*i8 = shiftright{0, scan{+, fold{|, table{==, rcsh4_dom, iota{64}}}}}
def read_shuf_vecs{l, ellw:u64, shp:P} = { # tuple of byte selectors in 1<<ellw def read_shuf_vecs{l, ellw:u64, shp:P} = { # tuple of byte selectors in 1<<ellw
@ -57,7 +57,7 @@ def read_shuf_vecs{l, ellw:u64, shp:P} = { # tuple of byte selectors in 1<<ellw
def sh = each{{v}=>{r:=v}, l**V**0} def sh = each{{v}=>{r:=v}, l**V**0}
def tlen{e} = cdiv{l, e} # Length for e bytes, rounded up def tlen{e} = cdiv{l, e} # Length for e bytes, rounded up
def set{i} = { tupsel{i,sh} = each{bind{load,shp},i} } def set{i} = { tupsel{i,sh} = each{load{shp,.}, i} }
def ext{e} = { def ext{e} = {
def m = tlen{2*e}; def n = tlen{e} # m<n def m = tlen{2*e}; def n = tlen{e} # m<n
if (ellw <= lb{e}) set{slice{iota{n},m}} if (ellw <= lb{e}) set{slice{iota{n},m}}
@ -94,9 +94,9 @@ if (hasarch{'AVX2'}) {
def l = tuplen{sh} def l = tuplen{sh}
def h = l>>1 def h = l>>1
def fs{v, s} = gen{sel{[16]i8, v, s}} def fs{v, s} = gen{sel{[16]i8, v, s}}
a := shuf{[4]u64, x, 4b1010}; each{bind{fs,a}, slice{sh,0,h}} a := shuf{[4]u64, x, 4b1010}; each{fs{a,.}, slice{sh,0,h}}
if (l%2) fs{x, tupsel{h, sh}} if (l%2) fs{x, tupsel{h, sh}}
b := shuf{[4]u64, x, 4b3232}; each{bind{fs,b}, slice{sh,-h}} b := shuf{[4]u64, x, 4b3232}; each{fs{b,.}, slice{sh,-h}}
} }
def get_rep_iter{V, wv==2}{x, gen} = { def get_rep_iter{V, wv==2}{x, gen} = {
@ -105,7 +105,7 @@ if (hasarch{'AVX2'}) {
} }
def get_rep_iter{V==[4]u64, wv} = { def get_rep_iter{V==[4]u64, wv} = {
def step = 4 def step = 4
def sh = each{bind{base,4}, get_shufs{step, wv, wv}} def sh = each{base{4,.}, get_shufs{step, wv, wv}}
{x, gen} => each{{s}=>gen{shuf{V, x, s}}, sh} {x, gen} => each{{s}=>gen{shuf{V, x, s}}, sh}
} }

View File

@ -48,7 +48,7 @@ fn copy{X, R}(x: *void, r: *void, l:u64, xRaw: *void) : void = {
@maskedLoop{vcount{V64}}(sr in tup{'g',rp}, x in tup{V64,xp} over cdiv{l,64}) sr{x} @maskedLoop{vcount{V64}}(sr in tup{'g',rp}, x in tup{V64,xp} over cdiv{l,64}) sr{x}
} else if (X==u1) { } else if (X==u1) {
# show{'X==u1', X, R} # show{'X==u1', X, R}
copyFromBits{[bulk]R, {T, i} => loadBatchBit{T, xp, i}, r, l} copyFromBits{[bulk]R, loadBatchBit{., xp, .}, r, l}
} else if (R==u1) { } else if (R==u1) {
# show{'R==u1', X, R} # show{'R==u1', X, R}
def XU = ty_u{XV} def XU = ty_u{XV}

View File

@ -116,7 +116,7 @@ def runner{u, R, F} = {
} }
# homAny, topAny already give masked vals; anyne doesn't, and ~andAllZero assumes no masking # homAny, topAny already give masked vals; anyne doesn't, and ~andAllZero assumes no masking
def runChecks_any{F, vals} = { F{tree_fold{|, each{{c}=>tupsel{1,c}, vals}}} } def runChecks_any{F, vals} = { F{tree_fold{|, each{tupsel{1,.}, vals}}} }
def runChecks{type=='homAny', vals, M} = runChecks_any{homAny, vals} def runChecks{type=='homAny', vals, M} = runChecks_any{homAny, vals}
def runChecks{type=='topAny', vals, M} = runChecks_any{topAny, vals} def runChecks{type=='topAny', vals, M} = runChecks_any{topAny, vals}
def runChecks{type=='none', vals, M} = 0 def runChecks{type=='none', vals, M} = 0
@ -133,7 +133,7 @@ def runChecks{type=='anyne', vals, M} = {
def arithProcess{F, run, overflow, M, is, cw, cx, TY} = { def arithProcess{F, run, overflow, M, is, cw, cx, TY} = {
def {values, checks} = flip{each{{w1, x1} => run{F, M, w1, x1}, cw, cx}} def {values, checks} = flip{each{{w1, x1} => run{F, M, w1, x1}, cw, cx}}
def ctype = oneVal{each{{c}=>tupsel{0,c}, checks}} def ctype = oneVal{each{tupsel{0,.}, checks}}
if (rare{runChecks{ctype, checks, M}}) overflow{tupsel{0,is}*vcount{TY}} if (rare{runChecks{ctype, checks, M}}) overflow{tupsel{0,is}*vcount{TY}}
each{{c} => TY~~c, values} each{{c} => TY~~c, values}
} }

View File

@ -165,7 +165,7 @@ def broadcast{T, x & nvec{T}} = emit{T, ntyp{'vdup', '_n', T}, x}
def make{T, ...xs & nvec{T} & tuplen{xs}==vcount{T}} = { def make{T, ...xs & nvec{T} & tuplen{xs}==vcount{T}} = {
def TE = eltype{T} def TE = eltype{T}
load{*T ~~ *TE ~~ each{{c}=>promote{eltype{T},c}, xs}, 0} load{*T ~~ *TE ~~ each{promote{eltype{T},.}, xs}, 0}
} }
def make{T, x & nvec{T} & istup{x}} = make{T, ...x} def make{T, x & nvec{T} & istup{x}} = make{T, ...x}
def iota{T & nvec{T}} = make{T, ...iota{vcount{T}}} def iota{T & nvec{T}} = make{T, ...iota{vcount{T}}}

View File

@ -89,7 +89,7 @@ def scan_assoc{op, a:T} = {
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b} l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}} op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
} }
def scan_plus = bind{scan_assoc, +} def scan_plus = scan_assoc{+, .}
# Associative scan # Associative scan
fn avx2_scan_assoc_0{T, op}(x:*T, r:*T, len:u64, init:T) : void = { fn avx2_scan_assoc_0{T, op}(x:*T, r:*T, len:u64, init:T) : void = {

View File

@ -17,7 +17,7 @@ def ctzi{x} = promote{u64, ctz{x}} # Count trailing zeros, as index
def findFirst{C, M, F, ...v1} = { def findFirst{C, M, F, ...v1} = {
def exit = makelabel{} def exit = makelabel{}
def args = undef{M{...each{{c}=>tupsel{0,c}, v1}}} def args = undef{M{...each{tupsel{0, .}, v1}}}
def am = tuplen{tupsel{0,v1}} def am = tuplen{tupsel{0,v1}}
each{{last, ...v2} => { each{{last, ...v2} => {
if (last or C{...v2}) { if (last or C{...v2}) {
@ -219,7 +219,7 @@ def do_bittab{x0:*void, n:u64, tab:*void, u:u8, t, mode, r0} = {
if ((m&(m-1)) != 0) { # More bits than one if ((m&(m-1)) != 0) { # More bits than one
# Filter out values equal to the previous, or first new # Filter out values equal to the previous, or first new
def pind = (iota{k}&15) - 1 def pind = (iota{k}&15) - 1
prev:= make{VI, each{bind{max,0}, pind}} prev:= make{VI, each{max{0, .}, pind}}
e:= ~homMask{v == VI**TI~~xi} e:= ~homMask{v == VI**TI~~xi}
e&= base{2,pind<0} | ~homMask{v == sel{[16]i8, v, prev}} e&= base{2,pind<0} | ~homMask{v == sel{[16]i8, v, prev}}
if (rbit) rv&= e | -m # Don't remove first bit if (rbit) rv&= e | -m # Don't remove first bit

View File

@ -72,27 +72,27 @@ def makeselx{VI, VD, nsel, xd, logv, cshuf} = {
def bs{b, c, x} = cshuf{x, c} def bs{b, c, x} = cshuf{x, c}
def bs{b, c, x & tuplen{b}>0} = { def bs{b, c, x & tuplen{b}>0} = {
tupsel{0,b}{each{bind{bs, slice{b,1}, c}, x}} tupsel{0,b}{each{bs{slice{b,1}, c, .}, x}}
} }
def i = iota{logv} def i = iota{logv}
def vs = each{bind{broadcast,VI}, nsel<<reverse{i}} def vs = each{broadcast{VI, .}, nsel<<reverse{i}}
{c} => VD~~bs{each{bb{c},i==0,vs}, c, xd} {c} => VD~~bs{each{bb{c},i==0,vs}, c, xd}
} }
def makeshuf{VI, VD, x0, logv} = { def makeshuf{VI, VD, x0, logv} = {
x:= *VD~~x0 x:= *VD~~x0
def halves{v} = each{bind{shuf, [4]u64, v}, tup{4b1010, 4b3232}} def halves{v} = each{shuf{[4]u64, v, .}, tup{4b1010, 4b3232}}
def readx{l,o} = each{bind{readx,l-1}, o + iota{2}<<(l-2)} def readx{l,o} = each{readx{l-1, .}, o + iota{2}<<(l-2)}
def readx{l==0,o} = shuf{[4]u64, load{x}, 4b1010} def readx{l==0,o} = shuf{[4]u64, load{x}, 4b1010}
def readx{l==1,o} = halves{load{x, o}} def readx{l==1,o} = halves{load{x, o}}
xd:= readx{logv, 0} xd:= readx{logv, 0}
makeselx{VI,VD,16,xd,logv, bind{sel,[16]i8}} makeselx{VI,VD,16,xd,logv, sel{[16]i8, ...}}
} }
def makeperm{VI, VD, x0, logv} = { def makeperm{VI, VD, x0, logv} = {
x:= *VD~~x0 x:= *VD~~x0
def readx{l,o} = each{bind{readx,l-1}, o + iota{2}<<(l-1)} def readx{l,o} = each{readx{l-1, .}, o + iota{2}<<(l-1)}
def readx{l==0,o} = load{x, o} def readx{l==0,o} = load{x, o}
makeselx{[8]i32,VD,8, readx{logv, 0}, logv, bind{sel,[8]i32}} makeselx{[8]i32,VD,8, readx{logv, 0}, logv, sel{[8]i32, ...}}
} }
fn select{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = { fn select{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = {

View File

@ -8,7 +8,7 @@ if (hasarch{'AVX512F'}) {
local def mti{s,T} = merge{'_mm512_',s,'_epi',fmtnat{elwidth{T}}} local def mti{s,T} = merge{'_mm512_',s,'_epi',fmtnat{elwidth{T}}}
def load{a:T, n & 512==width{eltype{T}}} = emit{eltype{T}, '_mm512_loadu_si512', a+n} def load{a:T, n & 512==width{eltype{T}}} = emit{eltype{T}, '_mm512_loadu_si512', a+n}
def make{T, xs & 512==width{T} & tuplen{xs}==vcount{T}} = { def make{T, xs & 512==width{T} & tuplen{xs}==vcount{T}} = {
def p = each{{c}=>promote{eltype{T},c},reverse{xs}} def p = each{promote{eltype{T},.}, reverse{xs}}
emit{T, mti{'set',T}, ...p} emit{T, mti{'set',T}, ...p}
} }
def iota{T & isvec{T} & 512==width{T}} = make{T, iota{vcount{T}}} def iota{T & isvec{T} & 512==width{T}} = make{T, iota{vcount{T}}}
@ -31,7 +31,7 @@ def maketab{l,w,s,G} = {
reverse{iota{l}<<s} reverse{iota{l}<<s}
} }
# Store popcnt-1 in the high element # Store popcnt-1 in the high element
def top = (fold{bind{flat_table,+}, l**iota{2}} - 1)%(1<<(w-s)) def top = (fold{flat_table{+, ...}, l**iota{2}} - 1)%(1<<(w-s))
top<<(l*w-w+s) | bot # Overlaps for all-1 value only top<<(l*w-w+s) | bot # Overlaps for all-1 value only
} }
def maketab{l,w,s} = maketab{l,w,s,{x}=>x} def maketab{l,w,s} = maketab{l,w,s,{x}=>x}
@ -217,7 +217,7 @@ def thresh{c, T==i64 & hasarch{'AVX2'}} = 8
fn slash{c, T & hasarch{'AVX2'} & T>=i32}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { fn slash{c, T & hasarch{'AVX2'} & T>=i32}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def tw = width{T} def tw = width{T}
def V = [8]u32 def V = [8]u32
expander := make{[32]u8, merge{...each{{i}=>tup{i, ... 3**128}, iota{8}>>lb{tw/32}}}} expander := make{[32]u8, merge{...each{tup{., ... 3**128}, iota{8}>>lb{tw/32}}}}
def from_ind = if (c) { def from_ind = if (c) {
i:u64 = 0 i:u64 = 0
{j} => { v:=load{*V~~x, i}; ++i; sel{V, v, j} } {j} => { v:=load{*V~~x, i}; ++i; sel{V, v, j} }

View File

@ -102,8 +102,8 @@ fn squeeze{vw, X, CHR, B}(x0:*void, len:Size) : u32 = {
def int = { def int = {
def {int, wdn} = { def {int, wdn} = {
if (hasarch{'AARCH64'} and tuplen{is}==2) { if (hasarch{'AARCH64'} and tuplen{is}==2) {
def intp = narrowPair{...each{{v}=>cvt{i64,v}, v0}} def intp = narrowPair{...each{cvt{i64,.}, v0}}
def wdn = each{{v}=>cvt{f64,v}, widen{intp}} def wdn = each{cvt{f64,.}, widen{intp}}
tup{intp, wdn} tup{intp, wdn}
} else { } else {
def ints = each{{v} => cvtNarrow{ty_s{E}, v}, v0} def ints = each{{v} => cvtNarrow{ty_s{E}, v}, v0}

View File

@ -39,7 +39,7 @@ def broadcast{T, v & w128f{T, 64}} = emit{T, '_mm_set1_pd', v}
def broadcast{T, v & w128f{T, 32}} = emit{T, '_mm_set1_ps', v} def broadcast{T, v & w128f{T, 32}} = emit{T, '_mm_set1_ps', v}
# make from elements # make from elements
local def makeGen{T,s,x} = emit{T, s, ...each{{c}=>promote{eltype{T},c}, x}} local def makeGen{T,s,x} = emit{T, s, ...each{promote{eltype{T},.}, x}}
def make{T, ...xs & w128f{T,64} & tuplen{xs}== 2} = makeGen{T, '_mm_setr_pd', xs} def make{T, ...xs & w128f{T,64} & tuplen{xs}== 2} = makeGen{T, '_mm_setr_pd', xs}
def make{T, ...xs & w128f{T,32} & tuplen{xs}== 4} = makeGen{T, '_mm_setr_ps', xs} def make{T, ...xs & w128f{T,32} & tuplen{xs}== 4} = makeGen{T, '_mm_setr_ps', xs}
def make{T, ...xs & w128i{T,64} & tuplen{xs}== 2} = makeGen{T, '_mm_set_epi64x', tup{tupsel{1,xs}, tupsel{0,xs}}} def make{T, ...xs & w128i{T,64} & tuplen{xs}== 2} = makeGen{T, '_mm_set_epi64x', tup{tupsel{1,xs}, tupsel{0,xs}}}

View File

@ -112,7 +112,7 @@ def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h, ws, hs} = {
assert{k == kh} assert{k == kh}
def VT = [k]T def VT = [k]T
def line_vecs = line_bytes / (width{VT}/8) def line_vecs = line_bytes / (width{VT}/8)
def store_line{p, vs} = each{bind{store,p}, iota{line_vecs}, vs} def store_line{p, vs} = each{store{p, ...}, iota{line_vecs}, vs}
def get_lines{loadx} = { def get_lines{loadx} = {
def vt{i} = transpose_square{VT, k, each{loadx, k*i + iota{k}}} def vt{i} = transpose_square{VT, k, each{loadx, k*i + iota{k}}}
each{tup, ...each{vt, iota{line_vecs}}} each{tup, ...each{vt, iota{line_vecs}}}