Strided scan with AVX2 [8]i32 permute if possible
This commit is contained in:
parent
180c79e751
commit
4072bde806
@ -101,49 +101,69 @@ def shift_first{c:V=[l]_, p:V} = {
|
|||||||
|
|
||||||
# Strided scans
|
# Strided scans
|
||||||
fn scan_stride_assoc{op, T, Ret, check_over}(xv:*void, rv:*void, ia:usz, l:usz) : Ret = {
|
fn scan_stride_assoc{op, T, Ret, check_over}(xv:*void, rv:*void, ia:usz, l:usz) : Ret = {
|
||||||
|
def minvalue{(f64)} = -1/0; def maxvalue{(f64)} = 1/0
|
||||||
def id = match (op) {
|
def id = match (op) {
|
||||||
{(min)} => maxvalue; {(max)} => minvalue
|
{(min)} => maxvalue; {(max)} => minvalue
|
||||||
{(+)} => ({_}=>0)
|
{(+)} => ({_}=>0)
|
||||||
}
|
}
|
||||||
def f = width{T}/8; def vl = 16/f
|
|
||||||
x:= *T~~xv; r:= *T~~rv
|
x:= *T~~xv; r:= *T~~rv
|
||||||
|
# Architecture determination
|
||||||
|
# Use largest vector width with a full-width shuffle
|
||||||
def has_shuf = hasarch{'SSSE3'} or hasarch{'AARCH64'}
|
def has_shuf = hasarch{'SSSE3'} or hasarch{'AARCH64'}
|
||||||
if (has_shuf and T<=i32 and l<vl) {
|
def I = if (hasarch{'AVX2'} and T>=i32) [8]i32 else [16]i8
|
||||||
|
def [il]IE = I; def selI = shuf{IE, ...}
|
||||||
|
def wT = width{T}
|
||||||
|
def f = wT/width{IE}
|
||||||
|
def vl = width{I}/wT
|
||||||
|
def V = [vl]T
|
||||||
|
if (has_shuf and l < vl) {
|
||||||
|
# Small stride: power-of-two shifts
|
||||||
def small{k} = {
|
def small{k} = {
|
||||||
def I = [16]i8
|
iv:= iota{I}; j:= I**cast_i{IE,l*f}
|
||||||
iv:= iota{I}; j:= I**cast_i{i8,l*f}
|
spr:= I**il - j + iv
|
||||||
spr:= I**16 - j + iv
|
def inds = @collect (k) {
|
||||||
inds:= @collect (k) {
|
|
||||||
v:= iv - (j &~ I~~(iv<j))
|
v:= iv - (j &~ I~~(iv<j))
|
||||||
spr = shuf{spr, v}
|
spr = selI{spr, v}
|
||||||
if (same{op, +}) v = iv - j
|
js:= j; j+= j
|
||||||
j += j
|
if (not same{op, +}) selI{., v}
|
||||||
v
|
else if (same{IE,i8}) selI{., iv - js}
|
||||||
|
else { m:= V~~(iv >= js); {x} => selI{x, v} & m }
|
||||||
}
|
}
|
||||||
def V = [vl]T
|
|
||||||
c:= V**id{T}
|
c:= V**id{T}
|
||||||
@for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) {
|
@for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) {
|
||||||
def sc{v, i} = op{shuf{i8, v, i}, v}
|
xs:= fold{{v, i} => op{i{v}, v}, x, inds}
|
||||||
r = c = op{shuf{i8, c, spr}, fold{sc, x, inds}}
|
r = c = op{shuf{IE, c, spr}, xs}
|
||||||
check_over{x, r} # For +, infers other argument as r-x
|
check_over{x, r} # For +, infers other argument as r-x
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (f==1 and l<4) small{3} else small{if (f<=2) 2 else 1}
|
if (not (same{op,+} and V==[4]f64)) {
|
||||||
} else if (has_simd and T==f64 and l==vl) {
|
def max_k = lb{vl/2}
|
||||||
def V = [vl]f64
|
if (max_k<3 or l<4) small{max_k} else small{max_k-1}
|
||||||
p:= load{*V~~x}; store{*V~~r, 0, p}
|
} else { # Non-associative!
|
||||||
@for (r in *V~~r, x in *V~~x over _ from 1 to ia/vl) {
|
c:= V**0
|
||||||
r = p = op{p, x}
|
if (l==2) {
|
||||||
|
@for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) {
|
||||||
|
a:= c + shuf{x, 0,1,0,1}
|
||||||
|
c = a + shuf{x, 2,3,2,3}
|
||||||
|
r = blend{a, c, 0,0,1,1}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert{l==3}
|
||||||
|
@for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) {
|
||||||
|
a:= shuf{c, 1,1,2,3} + blend{x, V**0, 0,1,1,1}
|
||||||
|
r = c = x + shuf{a, 1,2,3,0}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
# Large stride: single shift, with saved register or memory
|
||||||
def op_chk{p, x} = { r:= op{p, x}; check_over{p, x, r}; r }
|
def op_chk{p, x} = { r:= op{p, x}; check_over{p, x, r}; r }
|
||||||
@for (r, x over l) r = x
|
@for (r, x over l) r = x
|
||||||
if (has_shuf and T<=i32 and l<256/f) {
|
if (has_shuf and l<256/(wT/8)) {
|
||||||
def I = [16]i8
|
def [il]IE = I
|
||||||
q:= l%vl; fq:= cast_i{i8, f*q}
|
q:= l%vl; fq:= cast_i{IE, q*f}
|
||||||
def rot = shuf{i8, ., (iota{I} - I**fq) & I**15}
|
def rot = shuf{IE, ., (iota{I} - I**fq) & I**(il-1)}
|
||||||
bv:= iota{I} >= I**fq; def bl = blend_hom{..., bv}
|
bv:= iota{I} >= I**fq; def bl = blend_hom{..., bv}
|
||||||
def V = [vl]T
|
|
||||||
c:= V**id{T}
|
c:= V**id{T}
|
||||||
o:= l - q
|
o:= l - q
|
||||||
if (l == 2*vl) { o = vl; bv = ~bv }
|
if (l == 2*vl) { o = vl; bv = ~bv }
|
||||||
@ -162,7 +182,7 @@ fn scan_stride_assoc{op, T, Ret, check_over}(xv:*void, rv:*void, ia:usz, l:usz)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (same{op, +} and T<=i32 and has_simd) {
|
} else if (same{op, +} and T<=i32 and has_simd) {
|
||||||
def vl = arch_defvw/width{T}; def V = [vl]T
|
def vl = arch_defvw/wT; def V = [vl]T
|
||||||
@for_masked{vl} (x in tup{V, x+l}, r in tup{V, r+l}, p in tup{V, r} over ia-l) {
|
@for_masked{vl} (x in tup{V, x+l}, r in tup{V, r+l}, p in tup{V, r} over ia-l) {
|
||||||
r = op_chk{p, x}
|
r = op_chk{p, x}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user