321 lines
9.8 KiB
Plaintext
321 lines
9.8 KiB
Plaintext
include './base'
|
|
include './sse'
|
|
include './clmul'
|
|
include './avx'
|
|
include './avx2'
|
|
include './mask'
|
|
include './f64'
|
|
|
|
# Initialized scan, generic (scalar) implementation.
# Walks x left to right, folding each element into the running value m
# (seeded by the caller) and writing the running value to r.
fn scan_scal{T, op}(x:*T, r:*T, len:u64, m:T) : void = {
  @for (x, r over len) {
    m = op{m, x}
    r = m
  }
}
|
|
|
|
# Byte-granularity select: the constant index tuple idx is materialised as a
# vector of i8 elements (re_el{i8,V}) and passed to sel with [16]u8 as the
# selection type.
def sel8{src:V, idx} = sel{[16]u8, src, make{re_el{i8,V}, idx}}
# 256-bit source with only 16 indices given: repeat the tuple so that it
# covers all 32 bytes.
def sel8{src:V, idx & w256{V} & istup{idx} & tuplen{idx}==16} = sel8{src, merge{idx,idx}}
|
|
|
|
# Accept the shuffle pattern as a tuple of indices; it is packed into a single
# number with base{4, ...} and forwarded to the numeric-pattern shuf.
def shuf{T, v, pat & istup{pat}} = shuf{T, v, base{4,pat}}
|
|
|
|
# Fill last 4 bytes with last element, in each lane.
# Only 8- and 16-bit elements need a shuffle; wider element types are
# returned unchanged.
def spread{a:VT} = {
  def ebits = elwidth{VT}
  def ebytes = ebits/8
  # Bytes 0..11 stay in place; bytes 12..15 cycle through the bytes of the
  # lane's last element.
  if (ebits<=16) sel8{a,merge{iota{12},(16-ebytes)+iota{4}%ebytes}}; else a
}
|
|
|
|
# Set all elements with the last element of the input (128-bit vectors).
# n: input vector; the result has the same type with every element equal to
# the last element of n.
def toLast{n:VT & hasarch{'X86_64'} & w128{VT}} = {
  # 8- and 16-bit elements: interleaving the high half with itself doubles the
  # number of bits at the top of the register occupied by the last element.
  def l{v, w} = l{zipHi{v,v}, 2*w}
  # The last element now fills (at least) the top 32-bit word: broadcast it.
  def l{v, w & w>=32} = shuf{[4]i32, v, 4**3}
  # 64-bit elements (reached by the f64 scan when vectors are 128 bits wide):
  # broadcasting a single 32-bit word would repeat only the upper half of the
  # element, so repeat the top two words instead (result words = 2,3,2,3).
  # Defined last so that it takes priority over the w>=32 case.
  def l{v, w & w>=64} = shuf{[4]i32, v, 4b3232}
  l{n, elwidth{VT}}
}
|
|
# 256-bit variant of toLast: broadcast the last element across the register.
def toLast{n:VT & hasarch{'AVX2'} & w256{VT}} = {
  # 64-bit elements: replicate the last 64-bit element directly.
  if (elwidth{VT}>32) shuf{[4]u64, n, 4**3}
  # Otherwise make the top 32-bit word consist of the last element (spread),
  # then select that word (index 7) for every position.
  else sel{[8]i32, spread{n}, [8]i32**7}
}
|
|
|
|
# Vector driver shared by the scans.
#   init       initial value, broadcast into the carry vector
#   x, r       input and output pointers, len elements each
#   scan       {vector, carry} => result; applied to every full vector and
#              allowed to update carry
#   scan_last  {vector, carry} => result; applied once to the trailing
#              partial vector, whose result is written with a masked store
def scan_loop{T, init, x:*T, r:*T, len:u64, scan, scan_last} = {
  def lanes = arch_defvw/width{T}
  def V = [lanes]T
  carry:= V**init
  src:= *V ~~ x
  dst:= *V ~~ r
  full:= len/lanes
  @for (src, dst over full) dst = scan{src,carry}
  # Elements left over after the full vectors (lanes is assumed to be a power
  # of two, as the mask arithmetic requires)
  rem:= len & (lanes-1)
  if (rem) homMaskStoreF{dst+full, maskOf{V, rem}, scan_last{load{src,full}, carry}}
}
|
|
# Scan built from an in-register prefix generator pre: each vector is
# prefix-scanned with pre, combined with the carry through op, and the last
# element of the result is broadcast to become the carry for the next vector.
def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = {
  # Result for one vector; does not touch the carry (used for the tail)
  def combine{v, carry} = op{pre{v}, carry}
  # Result for one full vector; also advances the carry
  def advance{v, carry} = {
    res:= combine{v, carry}
    carry = toLast{res}
    res
  }
  scan_loop{T, init, x, r, len, advance, combine}
}
|
|
|
|
# Make prefix scan from op and shifter by applying the operation
# at increasing power-of-two shifts.
# sh{v, k} is called with the shift distance k in bits. The distance starts
# at one element (or 1 for a scalar) and doubles until it reaches the total
# width of the value.
def prefix_byshift{op, sh} = {
  def go{acc:V, dist} = {
    if (dist < width{V}) go{op{acc, sh{acc,dist}}, 2*dist}
    else acc
  }
  {v:T} => {
    def first = if (isvec{T}) elwidth{T} else 1
    go{v, first}
  }
}
|
|
|
|
# Scan ?` for an associative operation ? that is also idempotent and commutative, i.e. a?b?a = a?b = b?a; used for ⌊⌈ (min and max)
|
|
# Default: scalar scan. The fn definitions that follow override it when their
# conditions hold.
def scan_idem = scan_scal
fn scan_idem{T, op & hasarch{'SSE4.1'}}(x:*T, r:*T, len:u64, init:T) : void = {
  # Within each lane, scan using shifts by powers of 2. When shifting by k
  # the first k elements don't need to change, so their indices are left
  # pointing at themselves.
  def keep_shift{k,l} = merge{iota{k},iota{l-k}}
  # k is the shift distance in bits: byte shuffle below 32 bits...
  def shifter{v, k} = sel8{v, keep_shift{k/8,16}}
  # ...and a 32-bit word shuffle from 32 bits upwards
  def shifter{v, k & k>=32} = shuf{[4]u32, v, keep_shift{k/32,4}}
  def shifter{v, k & k==128 & hasarch{'AVX2'}} = {
    # After lanewise scan, broadcast end of lane 0 to entire lane 1
    sel{[8]i32, spread{v}, make{[8]i32, 3*(3<iota{8})}}
  }
  scan_post{T, init, x, r, len, op, prefix_byshift{op, shifter}}
}
|
|
# f64 version of the idempotent scan, with a hand-written in-register prefix.
fn scan_idem{T==f64, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = {
  # Non-AVX2 case: combine each element with element 0
  def within{a} = op{a, zipLo{a,a}}
  # AVX2 case, four doubles: two shuffle-and-combine rounds
  def within{a & hasarch{'AVX2'}} = {
    def fold{pat, v} = op{v, shuf{[4]u64, v, pat}}
    # Round 1 pairs elements (0,0,2,2); round 2 brings in element 1 for
    # elements 1..3
    def paired = fold{4b2200, a}
    fold{4b1110, paired}
  }
  scan_post{T, init, x, r, len, op, within}
}
|
|
|
|
# Min/max scans with a caller-supplied initial value
export{'si_scan_min_init_i8', scan_idem{i8 , min}}
export{'si_scan_max_init_i8', scan_idem{i8 , max}}
export{'si_scan_min_init_i16', scan_idem{i16, min}}
export{'si_scan_max_init_i16', scan_idem{i16, max}}
export{'si_scan_min_init_i32', scan_idem{i32, min}}
export{'si_scan_max_init_i32', scan_idem{i32, max}}
export{'si_scan_min_init_f64', scan_idem{f64, min}}
export{'si_scan_max_init_f64', scan_idem{f64, max}}
|
|
|
|
# Min/max scan without an explicit initial value: seeds scan_idem with
# maxvalue{T} when op is min and with minvalue{T} otherwise.
fn scan_idem_id{T, op}(x:*T, r:*T, len:u64) : void = {
  def seed = if (same{op,min}) maxvalue{T} else minvalue{T}
  scan_idem{T, op}(x, r, len, seed)
}
export{'si_scan_min_i8', scan_idem_id{i8 , min}}
export{'si_scan_max_i8', scan_idem_id{i8 , max}}
export{'si_scan_min_i16', scan_idem_id{i16, min}}
export{'si_scan_max_i16', scan_idem_id{i16, max}}
export{'si_scan_min_i32', scan_idem_id{i32, min}}
export{'si_scan_max_i32', scan_idem_id{i32, max}}
|
|
|
|
# In-register prefix scan for an associative op, built with prefix_byshift.
# Assumes identity is 0: the shifter brings zeros into the vacated positions.
def scan_assoc{op} = {
  # k is the shift distance in bits; shl is given it in bytes
  def shift_in{v, k} = shl{[16]u8, v, k/8} # Lanewise
  def shift_in{v:V, k==128 & hasarch{'AVX2'}} = {
    # Broadcast end of lane 0 to entire lane 1.
    # Only 32-bit word 3 (the end of lane 0) survives the mask, so lane 0 of
    # the result is all zero.
    def end_of_lane0 = make{[8]i32,0,0,0,-1,0,0,0,0}
    kept:= V~~end_of_lane0 & spread{v}
    sel{[8]i32, kept, make{[8]i32, 3*(3<iota{8})}}
  }
  prefix_byshift{op, shift_in}
}
def scan_plus = scan_assoc{+}
|
|
|
|
# Associative scan for an op whose identity is 0 (see scan_assoc).
# Default: scalar scan; overridden by the vector version when available.
def scan_assoc_0 = scan_scal
fn scan_assoc_0{T, op & hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = {
  # Prefix op on an entire vector register. The in-register prefix is built
  # from the same op that combines it with the carry; for + this is exactly
  # scan_plus. Note that scan_post applies op as op{prefix, carry}, so op is
  # also expected to be commutative.
  scan_post{T, init, x, r, len, op, scan_assoc{op}}
}
export{'si_scan_pluswrap_u8', scan_assoc_0{u8 , +}}
export{'si_scan_pluswrap_u16', scan_assoc_0{u16, +}}
export{'si_scan_pluswrap_u32', scan_assoc_0{u32, +}}
|
|
|
|
# xor scan over nw 64-bit words, generic implementation.
# p is xor-ed into every output word; after each word it is replaced by that
# word's top bit repeated across all 64 bits.
fn scan_neq{}(p:u64, x:*u64, r:*u64, nw:u64) : void = {
  def xor_prefix = prefix_byshift{^, <<}
  @for (x, r over nw) {
    r = p ^ xor_prefix{x}
    p = -(r>>63) # repeat sign bit
  }
}
|
|
# Scan over 64-bit words using clmul against a broadcast mark value.
#   x, r    input and output, accessed as vectors of two u64
#   init    initial carry, broadcast to both halves
#   words   number of 64-bit words to process
#   mark    multiplier given to clmul
fn clmul_scan_ne_any{..._ & hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = {
  def V = [2]u64
  markv := V**mark
  # Process 64-bit half i of a; run is the 64-bit broadcasted current total
  # and is updated in place.
  def xor64{a, i, run} = {
    prod := clmul{a, markv, i}
    res := prod ^ run
    # Fold in the product shifted right by 8 bytes to get the next total
    run = res ^ shr{[16]u8, prod, 8}
    res
  }
  src := *V ~~ x
  dst := *V ~~ r
  pairs := words/2
  carry := V**init
  @for (dst, src over pairs) {
    # One result per 64-bit half; zipLo joins their low halves
    dst = apply{zipLo, (@collect (j to 2) xor64{src, j, carry})}
  }
  # Odd word count: handle the last word through 64-bit load and store
  if (words & 1) {
    storeLow{dst+pairs, 64, clmul{loadLow{src+pairs, 64}, markv, 0} ^ carry}
  }
}
|
|
# xor scan using the clmul implementation, with an all-ones mark
fn scan_neq{..._ & hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = {
  def all_ones = -(u64~~1)
  clmul_scan_ne_any{}(*void~~x, *void~~r, init, nw, all_ones)
}
export{'si_scan_ne', scan_neq{}}
|
|
|
|
# Boolean cumulative sum, generic implementation.
# x is a bit array packed in u64 words; r receives l running counts of type T.
fn bcs{T}(x:*u64, r:*T, l:u64) : void = {
  # Bit n of the packed array: word n>>6, bit position n&63
  def bit_at{arr, n} = (load{arr, n>>6} >> (n&63)) & 1
  total:T = 0
  @for (r over i to l) {
    total+= cast_i{T, bit_at{x,i}}
    r = total
  }
}
|
|
# Boolean cumulative sum, AVX2 implementation.
# Consumes the bit array 32 bits at a time; each 32-bit word yields 32 running
# counts, which are widened to T as needed, added to the carry and stored.
fn bcs{T & hasarch{'AVX2'}}(x:*u64, r:*T, l:u64) : void = {
  def U = ty_u{T}
  def w = width{T}
  def vl= 256 / w
  def V = [vl]U
  dst:= *V~~r
  words:= *u32~~x
  carry:= V**0   # running total carried between output vectors

  def ii32 = iota{32}
  # One-argument forms of bit and tail, applied to the byte indices 0..31
  def bit{k}=bit{k,ii32}
  def tail{k}=tail{k,ii32}
  # sums{n} has 2^n entries; entry i is the number of set bits in i
  def sums{n} = (if (n==0) tup{0}; else { def s=sums{n-1}; merge{s,s+1} })
  # Reorder the 64-bit quarters and unpack against zero to widen one step
  def widen{v:T} = unpackQ{shuf{[4]u64, v, 4b3120}, T**0}

  # Per-byte counts for one 32-bit input word, summed within lanes
  def sumlanes{word:u32} = {
    nib:= [8]u32**word >> make{[8]u32, 4*tail{1, iota{8}}}
    picked:= sel8{[32]u8~~nib, ii32>>3 + bit{2}}
    pre:= picked & make{[32]u8, (1<<(1+tail{2})) - 1} # Prefixes
    # Table lookup: bit count of each masked prefix
    cnt:= sel{[16]u8, make{[32]u8, merge{sums{4},sums{4}}}, [32]i8~~pre}
    cnt+= sel8{cnt, bit{2}*(1+bit{3}>>2)-1}
    cnt + sel8{cnt, bit{3}-1}
  }
  # Handle input word number i; put{ptr, index, vector} performs the stores
  def step{word:u32, i, put} = {
    cnt:= sumlanes{word}
    # Byte output only: extra cross-lane addition step
    if (w==8) cnt+= [32]u8~~shuf{[4]u64, [8]i32~~sel8{cnt, bit{3}<<4-1}, 4b1100}
    j0:= (w/8)*i
    # Keep widening (each step yields two vectors) until elements have type T
    def flush{v, k} = each{flush, widen{v}, 2*k+iota{2}}
    def flush{v0:[vl]T, k} = {
      v := V~~v0 + carry
      # Update carry at the lane boundary
      if (w!=32 or tail{1,k}) {
        carry = sel{[8]u32, spread{v}, make{[8]i32, 8**7}}
      }
      put{dst, j0+k, v}
    }
    flush{[32]i8~~cnt, 0}
  }

  # Full 32-bit input words
  nfull:= l/32
  @for (words over i to nfull) {
    step{words, i, store}
  }

  # Trailing partial word: full stores while they fit, then one masked store
  # if any elements remain, then return from the function
  if (nfull*32 != l) {
    def put_tail{p, j, v} = {
      jv := vl*j
      if (jv+vl <= l) {
        store{p, j, v}
      } else {
        if (jv < l) homMaskStoreF{dst+j, maskOf{V, l - jv}, v}
        return{}
      }
    }
    step{load{words, nfull}, nfull, put_tail}
  }
}
export{'si_bcs8', bcs{i8}}
export{'si_bcs16', bcs{i16}}
export{'si_bcs32', bcs{i32}}
|
|
|
|
|
|
|
|
# Checked addition: returns tup{overflowed, sum}.
# The sum is written by __builtin_add_overflow into a one-element temporary
# and loaded back from it.
def addChk{a:T, b:T} = {
  slot:*T = tup{a}
  def overflowed = emit{u1, '__builtin_add_overflow', a, b, slot}
  tup{overflowed, load{slot}}
}
# f64: no overflow check; the flag is the constant 0
def addChk{a:T, b:T & T==f64} = tup{0, a+b}
|
|
|
|
# Widen every vector in the tuple xs to element type E and return all the
# resulting vectors as one flat tuple. A vector whose widened form would be
# twice arch_defvw is split into halves that are widened separately.
def widenFull{E, xs} = {
  def widen1{x:X} = {
    def cnt = vcount{X}
    def bits = width{E} * cnt
    if (bits<=arch_defvw) tup{widen{[cnt]E, x}}
    else if (1) {
      assert{bits == 2*arch_defvw}
      tup{
        widen{[cnt/2]E, half{x,0}},
        widen{[cnt/2]E, half{x,1}}
      }
    }
  }
  merge{...each{widen1, xs}}
}
|
|
|
|
# floor for compile-time numbers
def floor{n & knum{n}} = n - (n%1)
# Largest absolute value of a signed type (magnitude of its minimum)
def maxabsval{T & issigned{T}} = -minvalue{T}
# Upper bound used for the accumulator range: the type's maximum, or 2^53 for f64
def maxsafeint{T & issigned{T}} = maxvalue{T}
def maxsafeint{T==f64} = 1<<53
|
|
|
|
# Plus scan of x (element type X) into r (element type R), starting from c.
# Returns the index at which addChk reported overflow, or len when all
# elements were processed. With AVX2 a vector pass handles as much as it can
# first and hands its position and running total to the scalar loop.
fn plus_scan{X, R, O}(x:*X, c:R, r:*R, len:u64) : O = {
  done:u64 = 0
  # Advances done and c
  if (hasarch{'AVX2'}) simd_plus_scan_part{X,R}{x, c, r, len, done}
  @forUnroll{1,1} (idxs from done to len) {
    def vals = eachx{load, x, idxs}
    each{{j, v} => {
      def {ovf,sum} = addChk{c, promote{R, v}}
      if (rare{ovf}) return{j}
      store{r, j, sum}
      c = sum
    }, idxs, vals}
  }
  len
}
|
|
# Sum as many vector registers as possible; modifies c and i.
# Stops (jumps to 'end') as soon as the accumulator or the input could leave
# the range in which the vector arithmetic is known to be safe, or when fewer
# than bulk elements remain.
def simd_plus_scan_part{X, R}{x:*X, c:(R), r:*R, len:u64, i:u64} = {
  def ebits = max{width{R}/2, width{X}}
  def bulk = arch_defvw/ebits   # elements consumed per iteration

  def wd = (X!=R) & (width{X}<32) # whether to widen the working copy one size
  def WE = tern{wd, w_d{X}, X} # working copy element type

  # maxFastE: max absolute value for vector elements; the per-element runtime
  #           check against it is only performed when ~wd
  # maxFastA: max absolute value for accumulator
  def maxFastE = if (wd) maxabsval{X} else maxabsval{X}/(bulk*tern{R==f64, 1, 4}) # 4 to give maxFastA some range
  def maxFastA = maxsafeint{R} - maxFastE*bulk

  # Compile-time sanity check of the chosen bounds for integer results
  if (R!=f64) { def bound = maxFastA + maxFastE*bulk; assert{bound<=maxvalue{R}}; assert{-bound>=minvalue{R}} }

  acc:= [arch_defvw/width{R}]R ** c   # running total, broadcast

  # f64 accumulator must start as an integer
  if (R==f64 and c != floor{c}) goto{'end'}
  while (1) {
    if (R==f64) { if (rare{abs{extract{acc,0}} >= maxFastA}) goto{'end'} }
    else { if (rare{extract{absu{half{acc,0}},0} > maxFastA}) goto{'end'} }

    next:= i+bulk
    if (next>len) goto{'end'}

    def raw = tup{load{*[bulk]X~~(x+i)}}
    def work = if(wd) widenFull{WE,raw} else raw
    if (~wd) { # within-vector overflow check; widening gives range space for this to not happen
      if (rare{homAny{tree_fold{|, each{{e:T} => absu{e} >= (ty_u{T}**maxFastE), work}}}}) goto{'end'}
    }

    # Prefix sum inside each working vector
    def pre = each{scan_plus, work}

    # With two working vectors, carry the first one's total into the second
    def chain{v0} = tup{v0}
    def chain{v0,v1} = tup{v0,v1+toLast{v0}}

    # Widen to R and add the running total
    def res = eachx{+, widenFull{R, chain{...pre}}, acc}
    acc = toLast{tupsel{-1, res}}

    assert{type{acc} == oneType{res}}
    assert{vcount{type{acc}} * tuplen{res} == bulk}

    each{{v:T, j} => store{*T~~(r+i), j, v}, res, iota{tuplen{res}}}
    i = next
  }

  setlabel{'end'}
  c = extract{acc, 0}
}
|
|
# plus_scan instantiated by return type: C returns u64, G returns void
def plus_scanC{X, R} = plus_scan{X, R, u64}
def plus_scanG{X, R} = plus_scan{X, R, void}

# i32 results (u64-returning variant)
export{'si_scan_plus_i8_i32', plus_scanC{i8, i32}}
export{'si_scan_plus_i16_i32', plus_scanC{i16, i32}}
export{'si_scan_plus_i32_i32', plus_scanC{i32, i32}}

# f64 results (void-returning variant)
export{'si_scan_plus_i16_f64', plus_scanG{i16, f64}}
export{'si_scan_plus_i32_f64', plus_scanG{i32, f64}}
|