Strided ⌊⌈ scans with shuffles

This commit is contained in:
Marshall Lochbaum 2025-03-01 08:42:36 -05:00
parent 360e4e8320
commit 87a7d066c8
2 changed files with 74 additions and 0 deletions

View File

@ -273,6 +273,13 @@ B scan_c1(Md1D* d, B x) { B f = d->f;
if (neg) r = bit_negate(r);
decG(x); return r;
}
if (rtid==n_floor | rtid==n_ceil) {
// boolean was handled as CASE_N_AND
B r; void* rp = m_tyarrc(&r, elWidth(xe), x, el2t(xe));
void* xp = tyany_ptr(x);
si_scan_stride_minmax[4*(rtid==n_ceil) + xe-el_i8](xp, rp, ia, csz);
decG(x); return r;
}
#endif
goto base;
}}

View File

@ -98,6 +98,73 @@ def shift_first{c:V=[l]_, p:V} = {
else blend_first{c, rotate_right{p}}
}
# Strided scans
fn scan_stride_assoc{op, T}(xv:*void, rv:*void, ia:usz, l:usz) : void = {
def id = match (op) { {(min)} => maxvalue; {(max)} => minvalue }
def f = width{T}/8; def vl = 16/f
x:= *T~~xv; r:= *T~~rv
def has_shuf = hasarch{'SSSE3'} or hasarch{'AARCH64'}
if (has_shuf and T<=i32 and l<vl) {
def small{k} = {
def I = [16]i8
iv:= iota{I}; j:= I**cast_i{i8,l*f}
spr:= I**16 - j + iv
inds:= @collect (k) {
v:= iv - (j &~ I~~(iv<j))
j += j
spr = shuf{spr, v}
v
}
def V = [vl]T
c:= V**id{T}
@for_masked{vl} (x in tup{V, x}, r in tup{V, r} over ia) {
def sc{v, i} = op{shuf{i8, v, i}, v}
r = c = op{shuf{i8, c, spr}, fold{sc, x, inds}}
}
}
if (f==1 and l<4) small{3} else small{if (f<=2) 2 else 1}
} else if (has_simd and T==f64 and l==vl) {
def V = [vl]f64
p:= load{*V~~x}; store{*V~~r, 0, p}
@for (r in *V~~r, x in *V~~x over _ from 1 to ia/vl) {
r = p = op{p, x}
}
} else {
@for (r, x over l) r = x
if (has_shuf and T<=i32 and l<256/f) {
def I = [16]i8
q:= l%vl; fq:= cast_i{i8, f*q}
def rot = shuf{i8, ., (iota{I} - I**fq) & I**15}
bv:= iota{I} >= I**fq; def bl = blend_hom{..., bv}
def V = [vl]T
c:= V**id{T}
o:= l - q
if (l == 2*vl) { o = vl; bv = ~bv }
if (o == vl) {
p:= load{*V~~x}; store{*V~~r, 0, p}
@for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o} over ia-o) {
p = rot{p}
r = op{bl{c, p}, x}
c = p; p = r
}
} else {
@for_masked{vl} (x in tup{V, x+o}, r in tup{V, r+o}, p in tup{V, r} over ia-o) {
q:= rot{p}
r = op{bl{c, q}, x}
c = q
}
}
} else {
@for (r, x, p in r-l over _ from l to ia) r = op{p, x}
}
}
}
export_tab{'si_scan_stride_minmax',
flat_table{scan_stride_assoc, tup{min,max}, tup{i8,i16,i32,f64}}
}
# xor scan
def vec_prefix_byshift{op, sh} = {
def pre{v:V, k} = if (k < elwidth{V}) pre{op{v, sh{v,k}}, 2*k} else v