Merge 3 to 7 replicate float with other types: shuffle instead of permute

This commit is contained in:
Marshall Lochbaum 2022-09-29 20:12:45 -04:00
parent 29886f355a
commit 582afe33c9

View File

@ -62,7 +62,7 @@ def makefact{divisor, range} = {
def t = table{{a,b}=>0==b%a, divisor, range}
fold{+, 1, reverse{scan{|, reverse{t}}}}
}
def basic_rep = incl{2,7}
def basic_rep = incl{2, 7}
def fact_size = 128
def fact_inds = slice{iota{fact_size},8}
def fact_tab = makefact{basic_rep, fact_inds}
@ -108,7 +108,7 @@ def read_shuf_vecs{l, elbytes:u64, shp:*[32]i8} = {
if (elbytes<=e) set{slice{iota{n},m}}
else slice{sh,0,n} = doubles{n,slice{sh,0,m}}
}
set{iota{tlen{4}}}; ext{2}; ext{1}
set{iota{tlen{8}}}; ext{4}; ext{2}; ext{1}
sh
}
@ -135,7 +135,7 @@ def rep_const_shuffle{V, wv, onreps, xv:*V, rv:*V, n:u64} = {
}
def rep_const_shuffle{V, wv, xv:*V, rv:*V, n:u64} = rep_const_shuffle{V, wv, get_rep_iter{V, wv}, xv, rv, n}
def rcsh_vals = incl{3, 7}
def rcsh_vals = slice{basic_rep, 1} # Handle 2 specially
rcsh_offs:*u8 = shiftright{0, scan{+,rcsh_vals}}
rcsh_data:*i8 = join{join{each{get_shuf_data, rcsh_vals}}}
rcsh_sub{wv}(elbytes:u64, x:*i8, r:*i8, n:u64, sh:*[32]i8) : void = {
@ -143,10 +143,6 @@ rcsh_sub{wv}(elbytes:u64, x:*i8, r:*i8, n:u64, sh:*[32]i8) : void = {
def st = read_shuf_vecs{wv, elbytes, sh}
rep_const_shuffle{V, wv, rep_iter_from_sh{st}, *V~~x, *V~~r, n}
}
rep_const_shuffle_full(wv:i32, eb:u64, x:*i8, r:*i8, n:u64, sh:*[32]i8) : void = {
def try{k} = { if (wv==k) rcsh_sub{k}(eb, x, r, n, sh) }
each{try, rcsh_vals}
}
def rcsh4_dom = replicate{bind{>=,64}, replicate{fact_tab==1, fact_inds}}
rcsh4_dat:*i8 = join{join{each{{wv}=>get_shuf_data{wv, 4}, rcsh4_dom}}}
@ -181,6 +177,17 @@ rep_const_shuffle_partial4(wv:u64, elbytes:u64, x:*i8, r:*i8, n:u64) : void = {
if (q) maskstoreF{*V~~r, maskOf{V, q}, 0, s}
}
rep_const_shuffle_any(wv:i32, elbytes:u64, x:*i8, r:*i8, n:u64) : void = {
if (wv > tupsel{-1,rcsh_vals}) {
return{rep_const_shuffle_partial4(wv, elbytes, x, r, n)}
}
n *= elbytes
ri := wv - tupsel{0,rcsh_vals}
sh := *[32]i8~~rcsh_data + load{rcsh_offs,ri}
def try{k} = { if (wv==k) rcsh_sub{k}(elbytes, x, r, n, sh) }
each{try, rcsh_vals}
}
def rep_const_broadcast{T, kv, loop, wv:u64, x:*T, r:*T, n:u64} = {
assert{kv > 0}
def V = [256/width{T}]T
@ -215,20 +222,8 @@ rep_const{T}(wv:i32, x:*void, r:*void, n:u64) : void = {
def specialize{k} = {
if (wv==k) return{rep_const_shuffle{V, k, *V~~x, *V~~r, n}}
}
if (wT<=32) {
def elbytes = wT/8
specialize{2}
if (wv <= tupsel{-1,rcsh_vals}) {
ri := wv - tupsel{0,rcsh_vals}
shp:= *[32]i8~~rcsh_data + load{rcsh_offs,ri}
rep_const_shuffle_full(wv, elbytes, x, r, n*elbytes, shp)
} else {
rep_const_shuffle_partial4(wv, elbytes, x, r, n)
}
} else {
assert{max_shuffle <= tupsel{0, fact_inds}}
each{specialize, basic_rep}
}
specialize{2}
rep_const_shuffle_any(wv, wT/8, x, r, n)
} else {
kv := wv / vn
@unroll (k from (max_shuffle/vn) to 4) {