Extend select_rows_byte to any architecture with (shuffle and) blend

This commit is contained in:
Marshall Lochbaum 2024-10-25 15:31:29 -04:00
parent 4669418f1c
commit 6e4859e17f

View File

@ -76,35 +76,41 @@ export{'si_sum_f64', fold_assoc_0{f64,+}}
fn select_rows_byte(x0:*void, r0:*void, n:usz, l:usz, e:u8) : usz = { fn select_rows_byte(x0:*void, r0:*void, n:usz, l:usz, e:u8) : usz = {
n <<= e n <<= e
if (hasarch{'AVX2'} and n>=32 and l < usz~~20>>e and (l&1)!=0) { def vl = arch_defvw / 8
def V = [32]u8 def vh = vl / 2
def swap_by{x, m} = homBlend{x, shuf{[4]u64, x, 2,3,0,1}, m} def thr = min{vl+2, 20}
def has_blend = hasarch{'SSE4.1'} or hasarch{'AARCH64'}
if (has_blend and n>=vl and l < usz~~thr>>e and (l&1)!=0) {
def V = [vl]u8; def H = [vh]u16
l8 := cast_i{u8, l} l8 := cast_i{u8, l}
li := cast_i{u8, l + 2 * ((l-1) + (l&2))} # Inverse mod 32 li := cast_i{u8, l + 2 * ((l-1) + (l&2))} # Inverse mod vl
elo:= V**(u8~~1<<e - 1) elo:= V**(u8~~1<<e - 1)
ie := iota{V} & elo ie := iota{V} & elo
kmul := make{[16]u16, 2*iota{16}} &~ [16]u16~~elo kmul := make{H, 2*iota{vh}} &~ H~~elo
def mu16{k} = { def mu16{k} = {
k16 := [16]u16 ** k k16 := H ** k
prd := shuf{V~~(kmul * k16), 0,0} prd := shuf{V~~(kmul * k16), 0,0}
if (e == 0) prd += V~~(k16 << 8) if (e == 0) prd += V~~(k16 << 8)
(prd & V**31) + ie (prd & V**(vl-1)) + ie
} }
sii := mu16{li}
ms := (V**16 & sii) == (V**16 &~ iota{V})
si := mu16{l8} si := mu16{l8}
sii := mu16{li}
def swap_ms = if (vl == 16) ({x}=>x) else {
ms := (V**16 & sii) == (V**16 &~ iota{V})
{x} => homBlend{x, shuf{[4]u64, x, 2,3,0,1}, ms}
}
# Blend masks # Blend masks
def mg = { # Iteration i should select where mg == V**i def mg = { # Iteration i should select where mg == V**i
ss := (si < V**(l8<<e)) & (ie == V**0) ss := (si < V**(l8<<e)) & (ie == V**0)
vs := V**0xff - scan_assoc_id0{+}{ss} vs := V**0xff - scan_assoc_id0{+}{ss}
swap_by{shuf{[16]u8, vs, sii}, ms} swap_ms{shuf{[16]u8, vs, sii}}
} }
mgo := mg - V**(l8 & 3) mgo := mg - V**(l8 & 3)
mgm := (mgo - V**1) & V**3 mgm := (mgo - V**1) & V**3
m4s := @collect (i to 3) mgm == V**i m4s := @collect (i to 3) mgm == V**i
# Main loop # Main loop
xv := *V~~x0 xv := *V~~x0
nv := n / 32 nv := n / vl
@for (r in *V~~r0 over i to nv) { @for (r in *V~~r0 over i to nv) {
r = load{xv,0}; ++xv r = load{xv,0}; ++xv
if ((l & 2) != 0) { if ((l & 2) != 0) {
@ -120,9 +126,9 @@ fn select_rows_byte(x0:*void, r0:*void, n:usz, l:usz, e:u8) : usz = {
r = homBlend{r, r4, mh < V**4} r = homBlend{r, r4, mh < V**4}
mh -= V**4; xv += 4 mh -= V**4; xv += 4
} }
r = shuf{[16]u8, swap_by{r, ms}, si} r = shuf{[16]u8, swap_ms{r}, si}
} }
return{(usz~~32>>e) * nv} return{(usz~~vl>>e) * nv}
} }
0 0
} }