Extend modular permutation ⊣˝˘ to multi-byte element types

This commit is contained in:
Marshall Lochbaum 2024-10-24 22:21:35 -04:00
parent 9997c52e4c
commit 4669418f1c
2 changed files with 14 additions and 8 deletions

View File

@ -220,7 +220,7 @@ NOINLINE B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr) { //
// fast special-case implementations
extern void (*const si_select_cells_bit_lt64)(u64*,u64*,usz,usz,usz); // from fold.c (fold.singeli)
extern usz (*const si_select_cells_byte)(void*,void*,usz,usz);
extern usz (*const si_select_cells_byte)(void*,void*,usz,usz,u8);
static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind {leaf? <∘⊑; ⊏}⎉¯k x; TODO probably can share some parts with takedrop_highrank and/or call ⊏?
ur xr = RNK(x);
assert(xr>1 && k<xr);
@ -268,7 +268,7 @@ static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind
} else {
usz i0 = 0;
#if SINGELI
if (xl==3) i0 = si_select_cells_byte((u8*)xp + (ind<<(xl-3)), rp, cam, l);
i0 = si_select_cells_byte((u8*)xp + (ind<<(xl-3)), rp, cam, l, xl-3);
#endif
switch(xl) { default: UD;
case 3: PLAINLOOP for (usz i=i0; i<cam; i++) ((u8* )rp)[i] = ((u8* )xp)[i*l+ind]; break;

View File

@ -74,23 +74,29 @@ fn fold_assoc_0{T==f64, op if has_simd}(x:*T, len:u64) : T = {
export{'si_sum_f64', fold_assoc_0{f64,+}}
fn select_rows_byte(x0:*void, r0:*void, n:usz, l:usz) : usz = {
if (hasarch{'AVX2'} and n>=32 and l<20 and (l&1)==1) {
fn select_rows_byte(x0:*void, r0:*void, n:usz, l:usz, e:u8) : usz = {
n <<= e
if (hasarch{'AVX2'} and n>=32 and l < usz~~20>>e and (l&1)!=0) {
def V = [32]u8
def swap_by{x, m} = homBlend{x, shuf{[4]u64, x, 2,3,0,1}, m}
l8 := cast_i{u8, l}
li := cast_i{u8, l + 2 * ((l-1) + (l&2))} # Inverse mod 32
ie := make{[16]u16, 2*iota{16}}
elo:= V**(u8~~1<<e - 1)
ie := iota{V} & elo
kmul := make{[16]u16, 2*iota{16}} &~ [16]u16~~elo
def mu16{k} = {
k16 := [16]u16 ** k
(shuf{V~~(ie * k16), 0,0} + V~~(k16 << 8)) & V**31
prd := shuf{V~~(kmul * k16), 0,0}
if (e == 0) prd += V~~(k16 << 8)
(prd & V**31) + ie
}
sii := mu16{li}
ms := (V**16 & sii) == (V**16 &~ iota{V})
si := mu16{l8}
# Blend masks
def mg = { # Iteration i should select where mg == V**i
vs := V**0xff - scan_assoc_id0{+}{si < V**l8}
ss := (si < V**(l8<<e)) & (ie == V**0)
vs := V**0xff - scan_assoc_id0{+}{ss}
swap_by{shuf{[16]u8, vs, sii}, ms}
}
mgo := mg - V**(l8 & 3)
@ -116,7 +122,7 @@ fn select_rows_byte(x0:*void, r0:*void, n:usz, l:usz) : usz = {
}
r = shuf{[16]u8, swap_by{r, ms}, si}
}
return{32 * nv}
return{(usz~~32>>e) * nv}
}
0
}