Extend odd k/bool SSSE3 algorithm to k<16, factor even k≤32

This commit is contained in:
Marshall Lochbaum 2024-08-06 22:16:20 -04:00
parent a439a0a430
commit b4d84041bc

View File

@ -293,9 +293,22 @@ exportT{'si_constrep', each{rep_const, dat_types}}
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
def has_pdep = hasarch{'BMI2'}
if (wv > 32) return{0}
if (hasarch{'SSSE3'} and wv<=8 and wv!=6) {
rep_const_bool_ssse3{wv, x, r, rlen}
return{1}
if (hasarch{'SSSE3'}) {
if (wv&1 == 0) {
p := ctz{wv | 8} # Power of two for second replicate
if (wv>>p == 1) {
rep_const_bool_ssse3_div8{wv, x, r, rlen}
} else {
tlen := rlen>>p
t := r + cdiv{rlen, 64} - cdiv{tlen, 64}
rep_const_bool{}(wv >>p, x, t, tlen)
rep_const_bool{}(usz~~1<<p, t, r, rlen)
}
return{1}
} else if (wv < 16) {
rep_const_bool_ssse3_odd{wv, x, r, rlen}
return{1}
}
}
if (not has_pdep and wv <= 8) return{0}
m:u64 = spaced_mask_of{wv}
@ -338,15 +351,15 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
}
1
}
def rep_const_bool_ssse3{wv, x, r, rlen} = { # 2<=wv<=8, wv!=6
def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
oper // ({a,b}=>floor{a/b}) infix left 40
def bv{bs} = fold{flat_table{|,...}, reverse{each{tup{0,.}, 1<<bs}}}
def vl = 16; def V = [vl]u8
def iV = iota{vl}
def mkV = make{V, .}; def selV = sel{V, ., .}
def W = [2]u64
nv := cdiv{rlen, width{V}}
def run24{x, proc_xv, exh} = {
i:usz = 0; j:usz = 0
rv := *V~~r
@ -359,45 +372,61 @@ def rep_const_bool_ssse3{wv, x, r, rlen} = { # 2<=wv<=8, wv!=6
}
setlabel{end}
}
def run{2} = {
if (wv == 2) {
# Expander for half byte
tabr := mkV{bv{2*iota{4}} * 2b11}
m4 := V**0xf
def exh{x} = selV{tabr, x & m4}
run24{*V~~x, {xv}=>xv, exh}
}
def run{4} = {
dup := mkV{iV//2} # Double each byte
} else if (wv == 4) {
# Double each byte
dup := mkV{iV//2}
# Expander for two bits in either bottom or next-to-bottom position
tabr := mkV{bv{tup{0,4,0,4}} * 2b1111}
m2 := mkV{2b11 << (2*(iV%2))}
def exh{x} = re_el{u16, V}~~selV{tabr, x & m2}
run24{*u64~~x, selV{., dup}, exh}
}
def run{8} = {
} else { # wv == 8
@for (r in *V~~r over i to nv) {
xv := load{*V~~(*u16~~x + i)}
xe := selV{xv, mkV{iV//8}}
r = (xe & mkV{1 << (iV % 8)}) > V**0
}
}
}
# For odd numbers:
# - permute each byte sending bit i to position k*i % 8
# - replicate each byte by k, making position k*i contain bit i
# - mask out those bits and spread over [ k*i, k*(i+1) )
# - ...except where it crosses words; handle this overhang separately
def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
oper // ({a,b}=>floor{a/b}) infix left 40
def bv{bs} = fold{flat_table{|,...}, reverse{each{tup{0,.}, 1<<bs}}}
def vl = 16; def V = [vl]u8
def iV = iota{vl}
def mkV = make{V, .}; def selV = sel{V, ., .}
def W = [2]u64
# For odd numbers:
# - permute each byte sending bit i to position k*i % 8
# - replicate each byte by k, making position k*i contain bit i
# - mask out those bits and spread over [ k*i, k*(i+1) )
# - ...except where it crosses words; handle this overhang separately
# 3 is handled separately and unrolled; 5 and 7 share a loop
def run{3} = {
nv := cdiv{rlen, width{V}}
# Within-byte transformation
def get_ttab{k} = each{{is} => mkV{bv{is}}, split{4, k*iota{8} % 8}}
ttab:*V = join{each{get_ttab, 2*iota{4} + 1}}
{t0, t4} := each{load{ttab + (wv & 6), .}, iota{2}}
m4 := V**0xf
def get_perm_x{i} = {
xv := load{*V~~x, i}
selV{t0, xv & m4} | selV{t4, V~~([8]u16~~xv>>4) & m4}
}
# Cases are 3; 5 7; 9 11 13 15
if (wv < 4) {
# 3: dedicated loop
i:usz = 0; j:usz = 0
rv := *V~~r
def end = makelabel{}; while (j < nv) {
xv := load{*V~~x, i}; ++i
# 01234567 to 05316427 on each byte
{t0, t4} := each{{is} => mkV{bv{is}}, split{4, 3*iota{8} % 8}}
m4 := V**0xf
xv = selV{t0, xv & m4} | selV{t4, V~~([8]u16~~xv>>4) & m4}
xv := get_perm_x{i}; ++i
# Overhang from previous 64-bit elements
def ix = 64*slice{iota{3},1} // 3 # bits that overhang within a word
def ib = ix // 8 # byte index
@ -420,13 +449,11 @@ def rep_const_bool_ssse3{wv, x, r, rlen} = { # 2<=wv<=8, wv!=6
}
}
setlabel{end}
}
def run{{5,7}} = {
{t0, t4, xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{
V, V, W, V, V, V, V, V, usz }}
} else if (wv < 8) {
# 5, 7: precompute constants, then shared loop
{xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{
W, V, V, V, V, V, usz }}
def set_consts{k} = {
# Within-byte transformation
tup{t0, t4} = each{{is} => mkV{bv{is}}, split{4, k*iota{8} % 8}}
# Overhang from previous 64-bit elements
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
def ib = ix // 8 # byte index
@ -448,9 +475,7 @@ def rep_const_bool_ssse3{wv, x, r, rlen} = { # 2<=wv<=8, wv!=6
@for (r in *V~~r over j to nv) {
--q; if (q == 0) { q = wv
# Load and permute bytes
xv = load{*V~~x, i}; ++i
m4 := V**0xf
xv = selV{t0, xv & m4} | selV{t4, V~~([8]u16~~xv>>4) & m4}
xv = get_perm_x{i}; ++i
# Bytes for overhang
xo = V~~((W~~xv & xom) >> (8 - wv))
xo += xo > V**0
@ -468,12 +493,36 @@ def rep_const_bool_ssse3{wv, x, r, rlen} = { # 2<=wv<=8, wv!=6
r = V~~((b<<wv) - b)
r |= xo & mkV{255 * (iV%8 == 0)} # overhang
}
} else { # wv < 16
# 9, 11, 13, 15: shared constant computation and loop
# Initializers, relying on vl%wv == vl-wv in various ways
wV:= V**cast_i{u8,wv}
ind0 := (iota{V} < wV) + V**1 # mkV{iV >= k}
{m, d} := unaligned_spaced_mask_mod{wv}
mask0:= V~~make{W, m, m>>d|m<<(wv-d)}
iu:= iota{V} + V**cast_i{u8,vl-wv}; ia:= iu>=V**vl
ind_up := iu - (wV&ia)
ind_inc:= V**(wv <= vl) - ia
mask_sh:= d+d; if (mask_sh >= wv) mask_sh-= wv
# State
xv:V = V**0; o:=W**0; ind:=xv; mask:=xv
i:usz = 0; q:usz = 1
@for (r in *V~~r over j to nv) {
--q; if (q == 0) { q = wv
xv = get_perm_x{i}; ++i
ind = ind0
mask = mask0
} else {
ind = selV{ind, ind_up} + ind_inc
mask = V~~((W~~mask << (wv - mask_sh)) | (W~~mask >> mask_sh))
}
# Handle overhang here; won't fit in a single vector
b := W~~(selV{xv, ind} & mask)
r = V~~((b<<wv) - b)
po:= o; o = b>>(64-wv)
ro:= [4]u32~~vshl{po, o, 1}
r |= V~~(ro + (ro > [4]u32**0))
}
}
if (wv==2) run{2}
else if (wv==3) run{3}
else if (wv==4) run{4}
else if (wv==8) run{8}
else run{tup{5,7}}
}
export{'si_constrep_bool', rep_const_bool{}}