use partial application

This commit is contained in:
dzaima 2023-07-22 18:19:31 +03:00
parent 033f3de6b9
commit 959614c785
17 changed files with 43 additions and 44 deletions

View File

@ -34,7 +34,7 @@ def broadcast{T, v & w256i{T, 64}} = emit{T, '_mm256_set1_epi64x',promote{eltype
def broadcast{T, v & w256f{T, 64}} = emit{T, '_mm256_set1_pd', v}
def broadcast{T, v & w256f{T, 32}} = emit{T, '_mm256_set1_ps', v}
local def makeGen{T,s,x} = emit{T, s, ...each{{c}=>promote{eltype{T},c}, x}}
local def makeGen{T,s,x} = emit{T, s, ...each{promote{eltype{T}, .}, x}}
def make{T, ...xs & w256f{T,64} & tuplen{xs}== 4} = makeGen{T, '_mm256_setr_pd', xs}
def make{T, ...xs & w256f{T,32} & tuplen{xs}== 8} = makeGen{T, '_mm256_setr_ps', xs}
def make{T, ...xs & w256i{T,64} & tuplen{xs}== 4} = makeGen{T, '_mm256_setr_epi64x', xs}

View File

@ -31,7 +31,7 @@ def loadu{p:T & elwidth{T}==8} = load{p}
def storeu{p:T, v:eltype{T} & elwidth{T}==8} = store{p, v}
def reinterpret{T, x:X & T==X} = x
def exportN{f, ...ns} = each{{n} => export{n, f}, ns}
def exportN{f, ...ns} = each{export{.,f}, ns}
def exportT{name, fs} = { v:*type{tupsel{0,fs}} = fs; export{name, v} }
@ -220,7 +220,7 @@ def inRangeExcl{x:T, start, end} = inRangeLen{x, start, end-start} # ∊ [start;
def broadcast{T, v & isprim{T}} = v
def iota{n & knum{n}} = range{n}
def collect{vars,begin,end,iter & knum{begin} & knum{end}} = {
each{{i} => iter{i, vars}, range{end-begin}+begin}
each{iter{., vars}, range{end-begin}+begin}
}
def broadcast{n, v & knum{n}} = each{{_}=>v, range{n}}

View File

@ -65,16 +65,15 @@ fn max_scan{T, up}(x:*T, len:u64) : void = {
def getsel{...x} = assert{'shuffling not supported', show{...x}}
if (hasarch{'AVX2'}) {
def getsel{h:H & lvec{H, 16, 8}} = {
v := pair{h,h}
{i} => sel{H, v, i}
sel{H, pair{h,h}, .}
}
def getsel{v:V & lvec{V, 32, 8}} = {
def H = n_h{V}
vtop := V**(vcount{V}/2)
hs := each{bind{shuf, [4]u64, v}, tup{4b3232, 4b1010}}
{i} => homBlend{...each{{h}=>sel{H,h,i}, hs}, V~~i<vtop}
hs := each{shuf{[4]u64, v, .}, tup{4b3232, 4b1010}}
{i} => homBlend{...each{sel{H,.,i}, hs}, V~~i<vtop}
}
def getsel{v:V & lvec{V, 8, 32}} = { {i} => sel{V, v, i} }
def getsel{v:V & lvec{V, 8, 32}} = sel{V, v, .}
}
# Move evens to half 0 and odds to half 1
@ -176,11 +175,11 @@ def bins_vectab_i8{up, w, wn, x, xn, rp, t0, t, done & hasarch{'AVX2'}} = {
# We'll subtract 1 when indexing so the initial 0 isn't needed
tui:*i8 = copy{maxu, 0}; i:T = 0
@for (tui over promote{u64,nu}) { i = load{t, load{w, i}}; tui = i }
def tv = bind{load, *V~~tui}
def tv = load{*V~~tui, .}
ui = tv{0}
if (nu > 16) ui1 = shuf{[4]u64, ui, 4b3232}
ui = shuf{[4]u64, ui, 4b1010}
if (nu > vl) ui2 = each{bind{shuf, [4]u64, tv{1}}, tup{4b1010, 4b3232}}
if (nu > vl) ui2 = each{shuf{[4]u64, tv{1}, .}, tup{4b1010, 4b3232}}
}
# Popcount on 8-bit values
def sums{n} = if (n==1) tup{0} else { def s=sums{n/2}; merge{s,s+1} }
@ -229,7 +228,7 @@ def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
def wd = width{T}
def I = if (wd<32) u8 else u32; def wi = width{I}
def lanes = hasarch{'AVX2'} & (I==u8)
def isub = wd/wi; def bb = bind{base,1<<wi}
def isub = wd/wi; def bb = base{1<<wi, .}
def vl = 256/wd; def svl = vl>>lanes
def V = [vl]T
def U = [vl](ty_u{T})
@ -250,7 +249,7 @@ def bin_search_vec{T, up, w:*T, wn, x:*T, xn, rp, maxwn & hasarch{'AVX2'}} = {
if (ex>=1 and wn >= svl) {
--gap # Allows subtracting < instead of adding <=
def un = uninterleave
def tr_half{a, b} = each{bind{shufHalves,a,b}, tup{16b20, 16b31}}
def tr_half{a, b} = each{shufHalves{a,b,.}, tup{16b20, 16b31}}
def un{{a,b}} = tr_half{un{a},un{b}}
if (not lanes) tupsel{1,wv} = load{wg, 1}
wv = un{wv}
@ -313,7 +312,7 @@ def bin_search_branchless{up, w, wn, x, n, res, rtype} = {
l0 := wn + 1
# Take a list of indices in x/res to allow unrolling
def search{inds} = {
xs:= each{bind{load,x}, inds} # Values
xs:= each{load{x,.}, inds} # Values
ss:= each{{_}=>ws, inds} # Initial lower bound
l := l0; h := undefined{u64} # Interval size l, same for all values
while ((h=l/2) > 0) {
@ -380,6 +379,6 @@ exportT{
'si_saturate',
each{{a}=>saturate{...a}, merge{
tup{tup{i16,i8}, tup{i32,i8}, tup{i32,i16}},
join{table{bind{tup,f64}, tup{i8,i16,i32}, tup{1,0}}}
join{table{tup{f64, ...}, tup{i8,i16,i32}, tup{1,0}}}
}}
}

View File

@ -99,5 +99,5 @@ def loadBatchBit{T, x:*u64, is & ktup{is}} = {
# assert{count*len <= 64}
# bits:= b_getBatchLo{count*len, x, tupsel{0,is}}
# @collect(i to len) spreadBits{T, truncBits{count, bits>>(i*count)}}
each{{i} => loadBatchBit{T, x, i}, is}
each{loadBatchBit{T, x, .}, is}
}

View File

@ -26,6 +26,6 @@ fn bitsel_i{VL,T}(r:*void, bits:*u64, e0:u64, e1:u64, len:u64) : void = {
bitsel{VL, T, *T~~r, bits, trunc{T,e0}, trunc{T,e1}, len}
}
def table{w} = each{{T} => bitsel_i{w, T}, tup{u8, u16, u32, u64}}
def table{w} = each{bitsel_i{w, .}, tup{u8, u16, u32, u64}}
exportT{'simd_bitsel', table{arch_defvw}}

View File

@ -90,13 +90,13 @@ def any2bit{VT, unr, op, wS, wV, xS, xV, dst:*u64, len:(Size)} = {
fn aa2bit{VT, unr, op}(dst:*u64, wr:*void, xr:*void, len:Size) : void = {
wv:= *VT~~wr; ws:= *eltype{VT}~~wr
xv:= *VT~~xr; xs:= *eltype{VT}~~xr
any2bit{VT, unr, op, {i}=>load{ws,i}, {i}=>load{wv,i}, {i}=>load{xs,i}, {i}=>load{xv,i}, dst, len}
any2bit{VT, unr, op, load{ws,.}, load{wv,.}, load{xs,.}, load{xv,.}, dst, len}
}
fn as2bit{VT, unr, op}(dst:*u64, wr:*void, x:u64, len:Size) : void = {
wv:= *VT~~wr; ws:= *eltype{VT}~~wr
xv:= VT**pathAS{dst, len, eltype{VT}, op, x}
any2bit{VT, unr, op, {i}=>load{ws,i}, {i}=>load{wv,i}, {i}=>x, {i}=>xv, dst, len}
any2bit{VT, unr, op, load{ws,.}, load{wv,.}, {i}=>x, {i}=>xv, dst, len}
}
fn bitAA{bitop}(dst:*u64, wr:*void, xr:*void, len:Size) : void = {

View File

@ -39,7 +39,7 @@ rcsh_data:*i8 = join{join{each{get_shuf_data, rcsh_vals}}}
# first 4 shuffle vectors for 11≤𝕨≤61; only uses the low half of the input
def rcsh4_dom = replicate{bind{>=,64}, replicate{fact_tab==1, fact_inds}}
rcsh4_dat:*i8 = join{join{each{{wv}=>get_shuf_data{wv, 4}, rcsh4_dom}}}
rcsh4_dat:*i8 = join{join{each{get_shuf_data{., 4}, rcsh4_dom}}}
rcsh4_lkup:*i8 = shiftright{0, scan{+, fold{|, table{==, rcsh4_dom, iota{64}}}}}
def read_shuf_vecs{l, ellw:u64, shp:P} = { # tuple of byte selectors in 1<<ellw
@ -57,7 +57,7 @@ def read_shuf_vecs{l, ellw:u64, shp:P} = { # tuple of byte selectors in 1<<ellw
def sh = each{{v}=>{r:=v}, l**V**0}
def tlen{e} = cdiv{l, e} # Length for e bytes, rounded up
def set{i} = { tupsel{i,sh} = each{bind{load,shp},i} }
def set{i} = { tupsel{i,sh} = each{load{shp,.}, i} }
def ext{e} = {
def m = tlen{2*e}; def n = tlen{e} # m<n
if (ellw <= lb{e}) set{slice{iota{n},m}}
@ -94,9 +94,9 @@ if (hasarch{'AVX2'}) {
def l = tuplen{sh}
def h = l>>1
def fs{v, s} = gen{sel{[16]i8, v, s}}
a := shuf{[4]u64, x, 4b1010}; each{bind{fs,a}, slice{sh,0,h}}
a := shuf{[4]u64, x, 4b1010}; each{fs{a,.}, slice{sh,0,h}}
if (l%2) fs{x, tupsel{h, sh}}
b := shuf{[4]u64, x, 4b3232}; each{bind{fs,b}, slice{sh,-h}}
b := shuf{[4]u64, x, 4b3232}; each{fs{b,.}, slice{sh,-h}}
}
def get_rep_iter{V, wv==2}{x, gen} = {
@ -105,7 +105,7 @@ if (hasarch{'AVX2'}) {
}
def get_rep_iter{V==[4]u64, wv} = {
def step = 4
def sh = each{bind{base,4}, get_shufs{step, wv, wv}}
def sh = each{base{4,.}, get_shufs{step, wv, wv}}
{x, gen} => each{{s}=>gen{shuf{V, x, s}}, sh}
}

View File

@ -48,7 +48,7 @@ fn copy{X, R}(x: *void, r: *void, l:u64, xRaw: *void) : void = {
@maskedLoop{vcount{V64}}(sr in tup{'g',rp}, x in tup{V64,xp} over cdiv{l,64}) sr{x}
} else if (X==u1) {
# show{'X==u1', X, R}
copyFromBits{[bulk]R, {T, i} => loadBatchBit{T, xp, i}, r, l}
copyFromBits{[bulk]R, loadBatchBit{., xp, .}, r, l}
} else if (R==u1) {
# show{'R==u1', X, R}
def XU = ty_u{XV}

View File

@ -116,7 +116,7 @@ def runner{u, R, F} = {
}
# homAny, topAny already give masked vals; anyne doesn't, and ~andAllZero assumes no masking
def runChecks_any{F, vals} = { F{tree_fold{|, each{{c}=>tupsel{1,c}, vals}}} }
def runChecks_any{F, vals} = { F{tree_fold{|, each{tupsel{1,.}, vals}}} }
def runChecks{type=='homAny', vals, M} = runChecks_any{homAny, vals}
def runChecks{type=='topAny', vals, M} = runChecks_any{topAny, vals}
def runChecks{type=='none', vals, M} = 0
@ -133,7 +133,7 @@ def runChecks{type=='anyne', vals, M} = {
def arithProcess{F, run, overflow, M, is, cw, cx, TY} = {
def {values, checks} = flip{each{{w1, x1} => run{F, M, w1, x1}, cw, cx}}
def ctype = oneVal{each{{c}=>tupsel{0,c}, checks}}
def ctype = oneVal{each{tupsel{0,.}, checks}}
if (rare{runChecks{ctype, checks, M}}) overflow{tupsel{0,is}*vcount{TY}}
each{{c} => TY~~c, values}
}

View File

@ -165,7 +165,7 @@ def broadcast{T, x & nvec{T}} = emit{T, ntyp{'vdup', '_n', T}, x}
def make{T, ...xs & nvec{T} & tuplen{xs}==vcount{T}} = {
def TE = eltype{T}
load{*T ~~ *TE ~~ each{{c}=>promote{eltype{T},c}, xs}, 0}
load{*T ~~ *TE ~~ each{promote{eltype{T},.}, xs}, 0}
}
def make{T, x & nvec{T} & istup{x}} = make{T, ...x}
def iota{T & nvec{T}} = make{T, ...iota{vcount{T}}}

View File

@ -89,7 +89,7 @@ def scan_assoc{op, a:T} = {
l:= (type{b}~~make{[8]i32,0,0,0,-1,0,0,0,0}) & spread{b}
op{b, sel{[8]i32, l, make{[8]i32,0,0,0,0, 3,3,3,3}}}
}
def scan_plus = bind{scan_assoc, +}
def scan_plus = scan_assoc{+, .}
# Associative scan
fn avx2_scan_assoc_0{T, op}(x:*T, r:*T, len:u64, init:T) : void = {

View File

@ -17,7 +17,7 @@ def ctzi{x} = promote{u64, ctz{x}} # Count trailing zeros, as index
def findFirst{C, M, F, ...v1} = {
def exit = makelabel{}
def args = undef{M{...each{{c}=>tupsel{0,c}, v1}}}
def args = undef{M{...each{tupsel{0, .}, v1}}}
def am = tuplen{tupsel{0,v1}}
each{{last, ...v2} => {
if (last or C{...v2}) {
@ -219,7 +219,7 @@ def do_bittab{x0:*void, n:u64, tab:*void, u:u8, t, mode, r0} = {
if ((m&(m-1)) != 0) { # More bits than one
# Filter out values equal to the previous, or first new
def pind = (iota{k}&15) - 1
prev:= make{VI, each{bind{max,0}, pind}}
prev:= make{VI, each{max{0, .}, pind}}
e:= ~homMask{v == VI**TI~~xi}
e&= base{2,pind<0} | ~homMask{v == sel{[16]i8, v, prev}}
if (rbit) rv&= e | -m # Don't remove first bit

View File

@ -72,27 +72,27 @@ def makeselx{VI, VD, nsel, xd, logv, cshuf} = {
def bs{b, c, x} = cshuf{x, c}
def bs{b, c, x & tuplen{b}>0} = {
tupsel{0,b}{each{bind{bs, slice{b,1}, c}, x}}
tupsel{0,b}{each{bs{slice{b,1}, c, .}, x}}
}
def i = iota{logv}
def vs = each{bind{broadcast,VI}, nsel<<reverse{i}}
def vs = each{broadcast{VI, .}, nsel<<reverse{i}}
{c} => VD~~bs{each{bb{c},i==0,vs}, c, xd}
}
def makeshuf{VI, VD, x0, logv} = {
x:= *VD~~x0
def halves{v} = each{bind{shuf, [4]u64, v}, tup{4b1010, 4b3232}}
def readx{l,o} = each{bind{readx,l-1}, o + iota{2}<<(l-2)}
def halves{v} = each{shuf{[4]u64, v, .}, tup{4b1010, 4b3232}}
def readx{l,o} = each{readx{l-1, .}, o + iota{2}<<(l-2)}
def readx{l==0,o} = shuf{[4]u64, load{x}, 4b1010}
def readx{l==1,o} = halves{load{x, o}}
xd:= readx{logv, 0}
makeselx{VI,VD,16,xd,logv, bind{sel,[16]i8}}
makeselx{VI,VD,16,xd,logv, sel{[16]i8, ...}}
}
def makeperm{VI, VD, x0, logv} = {
x:= *VD~~x0
def readx{l,o} = each{bind{readx,l-1}, o + iota{2}<<(l-1)}
def readx{l,o} = each{readx{l-1, .}, o + iota{2}<<(l-1)}
def readx{l==0,o} = load{x, o}
makeselx{[8]i32,VD,8, readx{logv, 0}, logv, bind{sel,[8]i32}}
makeselx{[8]i32,VD,8, readx{logv, 0}, logv, sel{[8]i32, ...}}
}
fn select{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = {

View File

@ -8,7 +8,7 @@ if (hasarch{'AVX512F'}) {
local def mti{s,T} = merge{'_mm512_',s,'_epi',fmtnat{elwidth{T}}}
def load{a:T, n & 512==width{eltype{T}}} = emit{eltype{T}, '_mm512_loadu_si512', a+n}
def make{T, xs & 512==width{T} & tuplen{xs}==vcount{T}} = {
def p = each{{c}=>promote{eltype{T},c},reverse{xs}}
def p = each{promote{eltype{T},.}, reverse{xs}}
emit{T, mti{'set',T}, ...p}
}
def iota{T & isvec{T} & 512==width{T}} = make{T, iota{vcount{T}}}
@ -31,7 +31,7 @@ def maketab{l,w,s,G} = {
reverse{iota{l}<<s}
}
# Store popcnt-1 in the high element
def top = (fold{bind{flat_table,+}, l**iota{2}} - 1)%(1<<(w-s))
def top = (fold{flat_table{+, ...}, l**iota{2}} - 1)%(1<<(w-s))
top<<(l*w-w+s) | bot # Overlaps for all-1 value only
}
def maketab{l,w,s} = maketab{l,w,s,{x}=>x}
@ -217,7 +217,7 @@ def thresh{c, T==i64 & hasarch{'AVX2'}} = 8
fn slash{c, T & hasarch{'AVX2'} & T>=i32}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def tw = width{T}
def V = [8]u32
expander := make{[32]u8, merge{...each{{i}=>tup{i, ... 3**128}, iota{8}>>lb{tw/32}}}}
expander := make{[32]u8, merge{...each{tup{., ... 3**128}, iota{8}>>lb{tw/32}}}}
def from_ind = if (c) {
i:u64 = 0
{j} => { v:=load{*V~~x, i}; ++i; sel{V, v, j} }

View File

@ -102,8 +102,8 @@ fn squeeze{vw, X, CHR, B}(x0:*void, len:Size) : u32 = {
def int = {
def {int, wdn} = {
if (hasarch{'AARCH64'} and tuplen{is}==2) {
def intp = narrowPair{...each{{v}=>cvt{i64,v}, v0}}
def wdn = each{{v}=>cvt{f64,v}, widen{intp}}
def intp = narrowPair{...each{cvt{i64,.}, v0}}
def wdn = each{cvt{f64,.}, widen{intp}}
tup{intp, wdn}
} else {
def ints = each{{v} => cvtNarrow{ty_s{E}, v}, v0}

View File

@ -39,7 +39,7 @@ def broadcast{T, v & w128f{T, 64}} = emit{T, '_mm_set1_pd', v}
def broadcast{T, v & w128f{T, 32}} = emit{T, '_mm_set1_ps', v}
# make from elements
local def makeGen{T,s,x} = emit{T, s, ...each{{c}=>promote{eltype{T},c}, x}}
local def makeGen{T,s,x} = emit{T, s, ...each{promote{eltype{T},.}, x}}
def make{T, ...xs & w128f{T,64} & tuplen{xs}== 2} = makeGen{T, '_mm_setr_pd', xs}
def make{T, ...xs & w128f{T,32} & tuplen{xs}== 4} = makeGen{T, '_mm_setr_ps', xs}
def make{T, ...xs & w128i{T,64} & tuplen{xs}== 2} = makeGen{T, '_mm_set_epi64x', tup{tupsel{1,xs}, tupsel{0,xs}}}

View File

@ -112,7 +112,7 @@ def transpose_with_kernel{T, k, kh, call_base, rp:*T, xp:*T, w, h, ws, hs} = {
assert{k == kh}
def VT = [k]T
def line_vecs = line_bytes / (width{VT}/8)
def store_line{p, vs} = each{bind{store,p}, iota{line_vecs}, vs}
def store_line{p, vs} = each{store{p, ...}, iota{line_vecs}, vs}
def get_lines{loadx} = {
def vt{i} = transpose_square{VT, k, each{loadx, k*i + iota{k}}}
each{tup, ...each{vt, iota{line_vecs}}}