Even number handling for modular permutation ⊣˝˘

This commit is contained in:
Marshall Lochbaum 2024-10-29 16:17:15 -04:00
parent 4c55eab740
commit 0bdc43cc0f

View File

@ -74,7 +74,7 @@ fn fold_assoc_0{T==f64, op if has_simd}(x:*T, len:u64) : T = {
export{'si_sum_f64', fold_assoc_0{f64,+}}
def select_rows_pow2{T, x0, r0, nv, k} = {
def extract_column_pow2{T, x0, r0, nv, k} = {
def V = [arch_defvw / width{T}]T
xv := *V~~x0
@for (r in *V~~r0 over i to nv) {
@ -109,66 +109,55 @@ def select_rows_pow2{T, x0, r0, nv, k} = {
}
}
fn select_rows_byte(x0:*void, r0:*void, n:usz, l:usz, e:u8) : usz = {
n <<= e
def vl = arch_defvw / 8
def vh = vl / 2
def thr = min{vl+2, 20}
if ((not has_simd) or n < vl or l > usz~~thr>>e) return{0}
if (has_simd and (l & (l-1)) == 0) {
nv := n / vl
def try_unzip{T, k} = if (k < thr and l == k) {
select_rows_pow2{T, x0, r0, nv, k}
goto{'ret'}
}
# 10 loops: i8 2,4,8,16; i16 2,4,8; i32 2,4; i64 2
@unroll (ek to 4) if (e == ek) {
def T = ty_s{8<<ek}
@unroll (p from 1 to 5-ek) try_unzip{T, 1<<p}
}
return{0}
setlabel{'ret'}; return{(usz~~vl>>e) * nv}
def extract_column_modperm{x0, r0, nv, l, el, vl} = {
# Build modular permutations
def V = [vl]u8
def H = [vl/2]u16
p2 := ctz{l}; l >>= p2 # Decompose into l<<p2 with odd l
e := p2 + promote{ux,el} # Absorb into element size for most computation
l8 := cast_i{u8, l}
li := cast_i{u8, l + 2 * ((l-1) + (l&2))} # Inverse mod vl
elo:= V**(u8~~1<<e - 1)
ie := iota{V} & elo
kmul := make{H, 2*iota{vl/2}} &~ H~~elo
def mu16{k} = {
k16 := H ** k
prd := shuf{V~~(kmul * k16), 0,0}
if (e == 0) prd += V~~(k16 << 8)
(prd & V**(vl-1)) + ie
}
def has_blend = hasarch{'SSE4.1'} or hasarch{'AARCH64'}
if (has_blend and (l&1)!=0) {
def V = [vl]u8; def H = [vh]u16
l8 := cast_i{u8, l}
li := cast_i{u8, l + 2 * ((l-1) + (l&2))} # Inverse mod vl
elo:= V**(u8~~1<<e - 1)
ie := iota{V} & elo
kmul := make{H, 2*iota{vh}} &~ H~~elo
def mu16{k} = {
k16 := H ** k
prd := shuf{V~~(kmul * k16), 0,0}
if (e == 0) prd += V~~(k16 << 8)
(prd & V**(vl-1)) + ie
}
si := mu16{l8}
sii := mu16{li}
def swap_ms = if (vl == 16) ({x}=>x) else {
ms := (V**16 & sii) == (V**16 &~ iota{V})
{x} => homBlend{x, shuf{[4]u64, x, 2,3,0,1}, ms}
}
# Blend masks
def mg = { # Iteration i should select where mg == V**i
ss := (si < V**(l8<<e)) & (ie == V**0)
vs := V**0xff - scan_assoc_id0{+}{ss}
swap_ms{shuf{[16]u8, vs, sii}}
}
mgo := mg - V**(l8 & 3)
mgm := (mgo - V**1) & V**3
m4s := @collect (i to 3) mgm == V**i
# Main loop
si := mu16{l8}
sii := mu16{li}
def swap_ms = if (vl == 16) ({x}=>x) else {
ms := (V**16 & sii) == (V**16 &~ iota{V})
{x} => homBlend{x, shuf{[4]u64, x, 2,3,0,1}, ms}
}
# Blend masks
def mg = { # Iteration i should select where mg == V**i
ss := (si < V**(l8<<e)) & (ie == V**0)
vs := V**0xff - scan_assoc_id0{+}{ss}
swap_ms{shuf{[16]u8, vs, sii}}
}
mgo := mg - V**(l8 & 3)
mgm := (mgo - V**1) & V**3
m4s := @collect (i to 3) mgm == V**i
# Main loop
def loop{output} = {
xv := *V~~x0
nv := n / vl
@for (r in *V~~r0 over i to nv) {
r = load{xv,0}; ++xv
rv := *V~~r0
# Each iteration handles l vectors from x
@for (i to nv<<p2) {
# 1 or 3 initial vectors
r := load{xv,0}; ++xv
if ((l & 2) != 0) {
def {m0, _, m2} = m4s
re := homBlend{load{xv,0}, load{xv,1}, m2}
r = homBlend{re, r, m0}
xv += 2
}
# Then the rest in groups of 4
mh := mgo
@for (l/4) {
{l0, ...ls} := each{load{xv,.}, iota{4}}
@ -176,11 +165,84 @@ fn select_rows_byte(x0:*void, r0:*void, n:usz, l:usz, e:u8) : usz = {
r = homBlend{r, r4, mh < V**4}
mh -= V**4; xv += 4
}
r = shuf{[16]u8, swap_ms{r}, si}
def write{r} = {
store{rv, 0, shuf{[16]u8, r, si}}
++rv
}
output{r, i, write}
}
return{(usz~~vl>>e) * nv}
}
0
# Handle odd and even strides separately
if (p2 == 0) {
loop{{r, i, write} => write{swap_ms{r}}}
} else {
# Store results
def add_res = {
ra := V**0 # Accumulator
i := make{V, iota{vl}%16}
o := V**(u8~~1<<el)
bl := i & elo < o
sh := i - o
{r} => { ra = homBlend{shuf{[16]u8, ra, sh}, r, bl} }
}
il := make{V, iota{vl} % 16}
# Shuffle to undo interleaving of add_res
def __shr{x:(V), sh if hasarch{'X86_64'}} = V~~(H~~x >> sh)
def __shl{x:(V), sh if hasarch{'X86_64'}} = V~~(H~~x << sh)
def __shr{x:(V), sh:T if hasarch{'AARCH64'} and not isvec{T}} = x << V**cast_i{u8,-sh}
def __shl{x:(V), sh:T if hasarch{'AARCH64'} and not isvec{T}} = x << V**cast_i{u8, sh}
def uz_lane = {
l := V**(u8~~1<<el - 1)
h := V**(16 - u8~~16>>p2)
dz := (il & l) | (h &~ il)>>(4-e) # low, high->middle
dz |= (il &~ (l | h))<<p2 # middle->high
shuf{[16]u8, ., dz}
}
# Adjust modular permutation to apply after unzipping
si = uz_lane{(si & V**(vl - u8~~1<<e)) >> p2}
si ^= il &~ V**((16 - u8~~1<<e) >> p2)
# Cross-lane follow-up
def cross = if (vl == 16) { {x}=>x } else {
si = shuf{[4]u64, si, 0,1,0,1}
assert{p2 <= 2}
cr := make{[8]u32, tr_iota{0,2,1}}
if (p2 > 1) cr = make{[8]u32, tr_iota{2,0,1}}
shuf{[8]u32, ., cr}
}
# Run, writing every 1<<p2 steps
plo := usz~~1<<p2 - 1
loop{{r, i, write} => {
def ra = add_res{r}
if ((plo &~ i) == 0) write{cross{uz_lane{ra}}}
}}
}
}
# Select one element out of every l, element width 1<<el bytes
# Maximum of n result values, return actual number written
fn extract_column(x0:*void, r0:*void, n:usz, l:usz, el:u8) : usz = {
n <<= el
def vl = arch_defvw / 8
def thr = min{vl+2, 20}
if ((not has_simd) or n < vl or l > usz~~thr>>el or l<<el >= thr) return{0}
nv := n / vl
if (has_simd and (l & (l-1)) == 0) {
def try_unzip{T, k} = if (k < thr and l == k) {
extract_column_pow2{T, x0, r0, nv, k}
goto{'ret'}
}
# 10 loops: i8 2,4,8,16; i16 2,4,8; i32 2,4; i64 2
@unroll (ek to 4) if (el == ek) {
def T = ty_s{8<<ek}
@unroll (p from 1 to 5-ek) try_unzip{T, 1<<p}
}
return{0}
setlabel{'ret'}
} else if (hasarch{'SSE4.1'} or hasarch{'AARCH64'}) {
extract_column_modperm{x0, r0, nv, l, el, vl}
} else return{0}
(usz~~vl>>el) * nv
}
@ -294,7 +356,7 @@ def fold_rows_bit_lt64{
}
}
fn select_rows_bit_lt64(xp:*u64, rp:*u64, n:usz, l:usz, o:usz) : void = {
fn extract_column_bit_lt64(xp:*u64, rp:*u64, n:usz, l:usz, o:usz) : void = {
assert{l < 64}; assert{o < l} # Row length, and offset within row
def run_loop2{loop} = loop{{a,b} => a>>o}
def run_loop4{m, t, loop} = loop{{x} => x<<(l-1-o)}
@ -439,5 +501,5 @@ fn or_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, op_and:u1) : void = {
}
export{'si_xor_rows_bit', xor_rows_bit}
export{'si_or_rows_bit', or_rows_bit}
export{'si_select_cells_bit_lt64', select_rows_bit_lt64}
export{'si_select_cells_byte', select_rows_byte}
export{'si_select_cells_bit_lt64', extract_column_bit_lt64}
export{'si_select_cells_byte', extract_column}