Even number handling for modular permutation ⊣˝˘
This commit is contained in:
parent
4c55eab740
commit
0bdc43cc0f
@ -74,7 +74,7 @@ fn fold_assoc_0{T==f64, op if has_simd}(x:*T, len:u64) : T = {
|
||||
export{'si_sum_f64', fold_assoc_0{f64,+}}
|
||||
|
||||
|
||||
def select_rows_pow2{T, x0, r0, nv, k} = {
|
||||
def extract_column_pow2{T, x0, r0, nv, k} = {
|
||||
def V = [arch_defvw / width{T}]T
|
||||
xv := *V~~x0
|
||||
@for (r in *V~~r0 over i to nv) {
|
||||
@ -109,66 +109,55 @@ def select_rows_pow2{T, x0, r0, nv, k} = {
|
||||
}
|
||||
}
|
||||
|
||||
fn select_rows_byte(x0:*void, r0:*void, n:usz, l:usz, e:u8) : usz = {
|
||||
n <<= e
|
||||
def vl = arch_defvw / 8
|
||||
def vh = vl / 2
|
||||
def thr = min{vl+2, 20}
|
||||
if ((not has_simd) or n < vl or l > usz~~thr>>e) return{0}
|
||||
if (has_simd and (l & (l-1)) == 0) {
|
||||
nv := n / vl
|
||||
def try_unzip{T, k} = if (k < thr and l == k) {
|
||||
select_rows_pow2{T, x0, r0, nv, k}
|
||||
goto{'ret'}
|
||||
}
|
||||
# 10 loops: i8 2,4,8,16; i16 2,4,8; i32 2,4; i64 2
|
||||
@unroll (ek to 4) if (e == ek) {
|
||||
def T = ty_s{8<<ek}
|
||||
@unroll (p from 1 to 5-ek) try_unzip{T, 1<<p}
|
||||
}
|
||||
return{0}
|
||||
setlabel{'ret'}; return{(usz~~vl>>e) * nv}
|
||||
def extract_column_modperm{x0, r0, nv, l, el, vl} = {
|
||||
# Build modular permutations
|
||||
def V = [vl]u8
|
||||
def H = [vl/2]u16
|
||||
p2 := ctz{l}; l >>= p2 # Decompose into l<<p2 with odd l
|
||||
e := p2 + promote{ux,el} # Absorb into element size for most computation
|
||||
l8 := cast_i{u8, l}
|
||||
li := cast_i{u8, l + 2 * ((l-1) + (l&2))} # Inverse mod vl
|
||||
elo:= V**(u8~~1<<e - 1)
|
||||
ie := iota{V} & elo
|
||||
kmul := make{H, 2*iota{vl/2}} &~ H~~elo
|
||||
def mu16{k} = {
|
||||
k16 := H ** k
|
||||
prd := shuf{V~~(kmul * k16), 0,0}
|
||||
if (e == 0) prd += V~~(k16 << 8)
|
||||
(prd & V**(vl-1)) + ie
|
||||
}
|
||||
def has_blend = hasarch{'SSE4.1'} or hasarch{'AARCH64'}
|
||||
if (has_blend and (l&1)!=0) {
|
||||
def V = [vl]u8; def H = [vh]u16
|
||||
l8 := cast_i{u8, l}
|
||||
li := cast_i{u8, l + 2 * ((l-1) + (l&2))} # Inverse mod vl
|
||||
elo:= V**(u8~~1<<e - 1)
|
||||
ie := iota{V} & elo
|
||||
kmul := make{H, 2*iota{vh}} &~ H~~elo
|
||||
def mu16{k} = {
|
||||
k16 := H ** k
|
||||
prd := shuf{V~~(kmul * k16), 0,0}
|
||||
if (e == 0) prd += V~~(k16 << 8)
|
||||
(prd & V**(vl-1)) + ie
|
||||
}
|
||||
si := mu16{l8}
|
||||
sii := mu16{li}
|
||||
def swap_ms = if (vl == 16) ({x}=>x) else {
|
||||
ms := (V**16 & sii) == (V**16 &~ iota{V})
|
||||
{x} => homBlend{x, shuf{[4]u64, x, 2,3,0,1}, ms}
|
||||
}
|
||||
# Blend masks
|
||||
def mg = { # Iteration i should select where mg == V**i
|
||||
ss := (si < V**(l8<<e)) & (ie == V**0)
|
||||
vs := V**0xff - scan_assoc_id0{+}{ss}
|
||||
swap_ms{shuf{[16]u8, vs, sii}}
|
||||
}
|
||||
mgo := mg - V**(l8 & 3)
|
||||
mgm := (mgo - V**1) & V**3
|
||||
m4s := @collect (i to 3) mgm == V**i
|
||||
# Main loop
|
||||
si := mu16{l8}
|
||||
sii := mu16{li}
|
||||
def swap_ms = if (vl == 16) ({x}=>x) else {
|
||||
ms := (V**16 & sii) == (V**16 &~ iota{V})
|
||||
{x} => homBlend{x, shuf{[4]u64, x, 2,3,0,1}, ms}
|
||||
}
|
||||
|
||||
# Blend masks
|
||||
def mg = { # Iteration i should select where mg == V**i
|
||||
ss := (si < V**(l8<<e)) & (ie == V**0)
|
||||
vs := V**0xff - scan_assoc_id0{+}{ss}
|
||||
swap_ms{shuf{[16]u8, vs, sii}}
|
||||
}
|
||||
mgo := mg - V**(l8 & 3)
|
||||
mgm := (mgo - V**1) & V**3
|
||||
m4s := @collect (i to 3) mgm == V**i
|
||||
|
||||
# Main loop
|
||||
def loop{output} = {
|
||||
xv := *V~~x0
|
||||
nv := n / vl
|
||||
@for (r in *V~~r0 over i to nv) {
|
||||
r = load{xv,0}; ++xv
|
||||
rv := *V~~r0
|
||||
# Each iteration handles l vectors from x
|
||||
@for (i to nv<<p2) {
|
||||
# 1 or 3 initial vectors
|
||||
r := load{xv,0}; ++xv
|
||||
if ((l & 2) != 0) {
|
||||
def {m0, _, m2} = m4s
|
||||
re := homBlend{load{xv,0}, load{xv,1}, m2}
|
||||
r = homBlend{re, r, m0}
|
||||
xv += 2
|
||||
}
|
||||
# Then the rest in groups of 4
|
||||
mh := mgo
|
||||
@for (l/4) {
|
||||
{l0, ...ls} := each{load{xv,.}, iota{4}}
|
||||
@ -176,11 +165,84 @@ fn select_rows_byte(x0:*void, r0:*void, n:usz, l:usz, e:u8) : usz = {
|
||||
r = homBlend{r, r4, mh < V**4}
|
||||
mh -= V**4; xv += 4
|
||||
}
|
||||
r = shuf{[16]u8, swap_ms{r}, si}
|
||||
def write{r} = {
|
||||
store{rv, 0, shuf{[16]u8, r, si}}
|
||||
++rv
|
||||
}
|
||||
output{r, i, write}
|
||||
}
|
||||
return{(usz~~vl>>e) * nv}
|
||||
}
|
||||
0
|
||||
|
||||
# Handle odd and even strides separately
|
||||
if (p2 == 0) {
|
||||
loop{{r, i, write} => write{swap_ms{r}}}
|
||||
} else {
|
||||
# Store results
|
||||
def add_res = {
|
||||
ra := V**0 # Accumulator
|
||||
i := make{V, iota{vl}%16}
|
||||
o := V**(u8~~1<<el)
|
||||
bl := i & elo < o
|
||||
sh := i - o
|
||||
{r} => { ra = homBlend{shuf{[16]u8, ra, sh}, r, bl} }
|
||||
}
|
||||
il := make{V, iota{vl} % 16}
|
||||
# Shuffle to undo interleaving of add_res
|
||||
def __shr{x:(V), sh if hasarch{'X86_64'}} = V~~(H~~x >> sh)
|
||||
def __shl{x:(V), sh if hasarch{'X86_64'}} = V~~(H~~x << sh)
|
||||
def __shr{x:(V), sh:T if hasarch{'AARCH64'} and not isvec{T}} = x << V**cast_i{u8,-sh}
|
||||
def __shl{x:(V), sh:T if hasarch{'AARCH64'} and not isvec{T}} = x << V**cast_i{u8, sh}
|
||||
def uz_lane = {
|
||||
l := V**(u8~~1<<el - 1)
|
||||
h := V**(16 - u8~~16>>p2)
|
||||
dz := (il & l) | (h &~ il)>>(4-e) # low, high->middle
|
||||
dz |= (il &~ (l | h))<<p2 # middle->high
|
||||
shuf{[16]u8, ., dz}
|
||||
}
|
||||
# Adjust modular permutation to apply after unzipping
|
||||
si = uz_lane{(si & V**(vl - u8~~1<<e)) >> p2}
|
||||
si ^= il &~ V**((16 - u8~~1<<e) >> p2)
|
||||
# Cross-lane follow-up
|
||||
def cross = if (vl == 16) { {x}=>x } else {
|
||||
si = shuf{[4]u64, si, 0,1,0,1}
|
||||
assert{p2 <= 2}
|
||||
cr := make{[8]u32, tr_iota{0,2,1}}
|
||||
if (p2 > 1) cr = make{[8]u32, tr_iota{2,0,1}}
|
||||
shuf{[8]u32, ., cr}
|
||||
}
|
||||
# Run, writing every 1<<p2 steps
|
||||
plo := usz~~1<<p2 - 1
|
||||
loop{{r, i, write} => {
|
||||
def ra = add_res{r}
|
||||
if ((plo &~ i) == 0) write{cross{uz_lane{ra}}}
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
||||
# Select one element out of every l, element width 1<<el bytes
|
||||
# Maximum of n result values, return actual number written
|
||||
fn extract_column(x0:*void, r0:*void, n:usz, l:usz, el:u8) : usz = {
|
||||
n <<= el
|
||||
def vl = arch_defvw / 8
|
||||
def thr = min{vl+2, 20}
|
||||
if ((not has_simd) or n < vl or l > usz~~thr>>el or l<<el >= thr) return{0}
|
||||
nv := n / vl
|
||||
if (has_simd and (l & (l-1)) == 0) {
|
||||
def try_unzip{T, k} = if (k < thr and l == k) {
|
||||
extract_column_pow2{T, x0, r0, nv, k}
|
||||
goto{'ret'}
|
||||
}
|
||||
# 10 loops: i8 2,4,8,16; i16 2,4,8; i32 2,4; i64 2
|
||||
@unroll (ek to 4) if (el == ek) {
|
||||
def T = ty_s{8<<ek}
|
||||
@unroll (p from 1 to 5-ek) try_unzip{T, 1<<p}
|
||||
}
|
||||
return{0}
|
||||
setlabel{'ret'}
|
||||
} else if (hasarch{'SSE4.1'} or hasarch{'AARCH64'}) {
|
||||
extract_column_modperm{x0, r0, nv, l, el, vl}
|
||||
} else return{0}
|
||||
(usz~~vl>>el) * nv
|
||||
}
|
||||
|
||||
|
||||
@ -294,7 +356,7 @@ def fold_rows_bit_lt64{
|
||||
}
|
||||
}
|
||||
|
||||
fn select_rows_bit_lt64(xp:*u64, rp:*u64, n:usz, l:usz, o:usz) : void = {
|
||||
fn extract_column_bit_lt64(xp:*u64, rp:*u64, n:usz, l:usz, o:usz) : void = {
|
||||
assert{l < 64}; assert{o < l} # Row length, and offset within row
|
||||
def run_loop2{loop} = loop{{a,b} => a>>o}
|
||||
def run_loop4{m, t, loop} = loop{{x} => x<<(l-1-o)}
|
||||
@ -439,5 +501,5 @@ fn or_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, op_and:u1) : void = {
|
||||
}
|
||||
export{'si_xor_rows_bit', xor_rows_bit}
|
||||
export{'si_or_rows_bit', or_rows_bit}
|
||||
export{'si_select_cells_bit_lt64', select_rows_bit_lt64}
|
||||
export{'si_select_cells_byte', select_rows_byte}
|
||||
export{'si_select_cells_bit_lt64', extract_column_bit_lt64}
|
||||
export{'si_select_cells_byte', extract_column}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user