Extend odd k/bool SSSE3 algorithm to k<16, factor even k≤32
This commit is contained in:
parent
a439a0a430
commit
b4d84041bc
@ -293,9 +293,22 @@ exportT{'si_constrep', each{rep_const, dat_types}}
|
|||||||
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
||||||
def has_pdep = hasarch{'BMI2'}
|
def has_pdep = hasarch{'BMI2'}
|
||||||
if (wv > 32) return{0}
|
if (wv > 32) return{0}
|
||||||
if (hasarch{'SSSE3'} and wv<=8 and wv!=6) {
|
if (hasarch{'SSSE3'}) {
|
||||||
rep_const_bool_ssse3{wv, x, r, rlen}
|
if (wv&1 == 0) {
|
||||||
|
p := ctz{wv | 8} # Power of two for second replicate
|
||||||
|
if (wv>>p == 1) {
|
||||||
|
rep_const_bool_ssse3_div8{wv, x, r, rlen}
|
||||||
|
} else {
|
||||||
|
tlen := rlen>>p
|
||||||
|
t := r + cdiv{rlen, 64} - cdiv{tlen, 64}
|
||||||
|
rep_const_bool{}(wv >>p, x, t, tlen)
|
||||||
|
rep_const_bool{}(usz~~1<<p, t, r, rlen)
|
||||||
|
}
|
||||||
return{1}
|
return{1}
|
||||||
|
} else if (wv < 16) {
|
||||||
|
rep_const_bool_ssse3_odd{wv, x, r, rlen}
|
||||||
|
return{1}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (not has_pdep and wv <= 8) return{0}
|
if (not has_pdep and wv <= 8) return{0}
|
||||||
m:u64 = spaced_mask_of{wv}
|
m:u64 = spaced_mask_of{wv}
|
||||||
@ -338,15 +351,15 @@ fn rep_const_bool{}(wv:usz, x:*u64, r:*u64, rlen:usz) : u1 = {
|
|||||||
}
|
}
|
||||||
1
|
1
|
||||||
}
|
}
|
||||||
def rep_const_bool_ssse3{wv, x, r, rlen} = { # 2<=wv<=8, wv!=6
|
|
||||||
|
def rep_const_bool_ssse3_div8{wv, x, r, rlen} = { # wv in 2,4,8
|
||||||
oper // ({a,b}=>floor{a/b}) infix left 40
|
oper // ({a,b}=>floor{a/b}) infix left 40
|
||||||
def bv{bs} = fold{flat_table{|,...}, reverse{each{tup{0,.}, 1<<bs}}}
|
def bv{bs} = fold{flat_table{|,...}, reverse{each{tup{0,.}, 1<<bs}}}
|
||||||
def vl = 16; def V = [vl]u8
|
def vl = 16; def V = [vl]u8
|
||||||
def iV = iota{vl}
|
def iV = iota{vl}
|
||||||
def mkV = make{V, .}; def selV = sel{V, ., .}
|
def mkV = make{V, .}; def selV = sel{V, ., .}
|
||||||
def W = [2]u64
|
|
||||||
|
|
||||||
nv := cdiv{rlen, width{V}}
|
nv := cdiv{rlen, width{V}}
|
||||||
|
|
||||||
def run24{x, proc_xv, exh} = {
|
def run24{x, proc_xv, exh} = {
|
||||||
i:usz = 0; j:usz = 0
|
i:usz = 0; j:usz = 0
|
||||||
rv := *V~~r
|
rv := *V~~r
|
||||||
@ -359,45 +372,61 @@ def rep_const_bool_ssse3{wv, x, r, rlen} = { # 2<=wv<=8, wv!=6
|
|||||||
}
|
}
|
||||||
setlabel{end}
|
setlabel{end}
|
||||||
}
|
}
|
||||||
def run{2} = {
|
if (wv == 2) {
|
||||||
# Expander for half byte
|
# Expander for half byte
|
||||||
tabr := mkV{bv{2*iota{4}} * 2b11}
|
tabr := mkV{bv{2*iota{4}} * 2b11}
|
||||||
m4 := V**0xf
|
m4 := V**0xf
|
||||||
def exh{x} = selV{tabr, x & m4}
|
def exh{x} = selV{tabr, x & m4}
|
||||||
run24{*V~~x, {xv}=>xv, exh}
|
run24{*V~~x, {xv}=>xv, exh}
|
||||||
}
|
} else if (wv == 4) {
|
||||||
def run{4} = {
|
# Double each byte
|
||||||
dup := mkV{iV//2} # Double each byte
|
dup := mkV{iV//2}
|
||||||
# Expander for two bits in either bottom or next-to-bottom position
|
# Expander for two bits in either bottom or next-to-bottom position
|
||||||
tabr := mkV{bv{tup{0,4,0,4}} * 2b1111}
|
tabr := mkV{bv{tup{0,4,0,4}} * 2b1111}
|
||||||
m2 := mkV{2b11 << (2*(iV%2))}
|
m2 := mkV{2b11 << (2*(iV%2))}
|
||||||
def exh{x} = re_el{u16, V}~~selV{tabr, x & m2}
|
def exh{x} = re_el{u16, V}~~selV{tabr, x & m2}
|
||||||
run24{*u64~~x, selV{., dup}, exh}
|
run24{*u64~~x, selV{., dup}, exh}
|
||||||
}
|
} else { # wv == 8
|
||||||
|
|
||||||
def run{8} = {
|
|
||||||
@for (r in *V~~r over i to nv) {
|
@for (r in *V~~r over i to nv) {
|
||||||
xv := load{*V~~(*u16~~x + i)}
|
xv := load{*V~~(*u16~~x + i)}
|
||||||
xe := selV{xv, mkV{iV//8}}
|
xe := selV{xv, mkV{iV//8}}
|
||||||
r = (xe & mkV{1 << (iV % 8)}) > V**0
|
r = (xe & mkV{1 << (iV % 8)}) > V**0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
# For odd numbers:
|
||||||
|
# - permute each byte sending bit i to position k*i % 8
|
||||||
|
# - replicate each byte by k, making position k*i contain bit i
|
||||||
|
# - mask out those bits and spread over [ k*i, k*(i+1) )
|
||||||
|
# - ...except where it crosses words; handle this overhang separately
|
||||||
|
def rep_const_bool_ssse3_odd{wv, x, r, rlen} = { # wv odd, wv<=15
|
||||||
|
oper // ({a,b}=>floor{a/b}) infix left 40
|
||||||
|
def bv{bs} = fold{flat_table{|,...}, reverse{each{tup{0,.}, 1<<bs}}}
|
||||||
|
def vl = 16; def V = [vl]u8
|
||||||
|
def iV = iota{vl}
|
||||||
|
def mkV = make{V, .}; def selV = sel{V, ., .}
|
||||||
|
def W = [2]u64
|
||||||
|
|
||||||
# For odd numbers:
|
nv := cdiv{rlen, width{V}}
|
||||||
# - permute each byte sending bit i to position k*i % 8
|
|
||||||
# - replicate each byte by k, making position k*i contain bit i
|
# Within-byte transformation
|
||||||
# - mask out those bits and spread over [ k*i, k*(i+1) )
|
def get_ttab{k} = each{{is} => mkV{bv{is}}, split{4, k*iota{8} % 8}}
|
||||||
# - ...except where it crosses words; handle this overhang separately
|
ttab:*V = join{each{get_ttab, 2*iota{4} + 1}}
|
||||||
# 3 is handled separately and unrolled; 5 and 7 share a loop
|
{t0, t4} := each{load{ttab + (wv & 6), .}, iota{2}}
|
||||||
def run{3} = {
|
m4 := V**0xf
|
||||||
|
def get_perm_x{i} = {
|
||||||
|
xv := load{*V~~x, i}
|
||||||
|
selV{t0, xv & m4} | selV{t4, V~~([8]u16~~xv>>4) & m4}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Cases are 3; 5 7; 9 11 13 15
|
||||||
|
if (wv < 4) {
|
||||||
|
# 3: dedicated loop
|
||||||
i:usz = 0; j:usz = 0
|
i:usz = 0; j:usz = 0
|
||||||
rv := *V~~r
|
rv := *V~~r
|
||||||
def end = makelabel{}; while (j < nv) {
|
def end = makelabel{}; while (j < nv) {
|
||||||
xv := load{*V~~x, i}; ++i
|
|
||||||
# 01234567 to 05316427 on each byte
|
# 01234567 to 05316427 on each byte
|
||||||
{t0, t4} := each{{is} => mkV{bv{is}}, split{4, 3*iota{8} % 8}}
|
xv := get_perm_x{i}; ++i
|
||||||
m4 := V**0xf
|
|
||||||
xv = selV{t0, xv & m4} | selV{t4, V~~([8]u16~~xv>>4) & m4}
|
|
||||||
# Overhang from previous 64-bit elements
|
# Overhang from previous 64-bit elements
|
||||||
def ix = 64*slice{iota{3},1} // 3 # bits that overhang within a word
|
def ix = 64*slice{iota{3},1} // 3 # bits that overhang within a word
|
||||||
def ib = ix // 8 # byte index
|
def ib = ix // 8 # byte index
|
||||||
@ -420,13 +449,11 @@ def rep_const_bool_ssse3{wv, x, r, rlen} = { # 2<=wv<=8, wv!=6
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
setlabel{end}
|
setlabel{end}
|
||||||
}
|
} else if (wv < 8) {
|
||||||
def run{{5,7}} = {
|
# 5, 7: precompute constants, then shared loop
|
||||||
{t0, t4, xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{
|
{xom, xse, ind0, mask0, ind_up, ind_inc, mask_sh} := undef{tup{
|
||||||
V, V, W, V, V, V, V, V, usz }}
|
W, V, V, V, V, V, usz }}
|
||||||
def set_consts{k} = {
|
def set_consts{k} = {
|
||||||
# Within-byte transformation
|
|
||||||
tup{t0, t4} = each{{is} => mkV{bv{is}}, split{4, k*iota{8} % 8}}
|
|
||||||
# Overhang from previous 64-bit elements
|
# Overhang from previous 64-bit elements
|
||||||
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
|
def ix = 64*slice{iota{k},1} // k # bits that overhang within a word
|
||||||
def ib = ix // 8 # byte index
|
def ib = ix // 8 # byte index
|
||||||
@ -448,9 +475,7 @@ def rep_const_bool_ssse3{wv, x, r, rlen} = { # 2<=wv<=8, wv!=6
|
|||||||
@for (r in *V~~r over j to nv) {
|
@for (r in *V~~r over j to nv) {
|
||||||
--q; if (q == 0) { q = wv
|
--q; if (q == 0) { q = wv
|
||||||
# Load and permute bytes
|
# Load and permute bytes
|
||||||
xv = load{*V~~x, i}; ++i
|
xv = get_perm_x{i}; ++i
|
||||||
m4 := V**0xf
|
|
||||||
xv = selV{t0, xv & m4} | selV{t4, V~~([8]u16~~xv>>4) & m4}
|
|
||||||
# Bytes for overhang
|
# Bytes for overhang
|
||||||
xo = V~~((W~~xv & xom) >> (8 - wv))
|
xo = V~~((W~~xv & xom) >> (8 - wv))
|
||||||
xo += xo > V**0
|
xo += xo > V**0
|
||||||
@ -468,12 +493,36 @@ def rep_const_bool_ssse3{wv, x, r, rlen} = { # 2<=wv<=8, wv!=6
|
|||||||
r = V~~((b<<wv) - b)
|
r = V~~((b<<wv) - b)
|
||||||
r |= xo & mkV{255 * (iV%8 == 0)} # overhang
|
r |= xo & mkV{255 * (iV%8 == 0)} # overhang
|
||||||
}
|
}
|
||||||
|
} else { # wv < 16
|
||||||
|
# 9, 11, 13, 15: shared constant computation and loop
|
||||||
|
# Initializers, relying on vl%wv == vl-wv in various ways
|
||||||
|
wV:= V**cast_i{u8,wv}
|
||||||
|
ind0 := (iota{V} < wV) + V**1 # mkV{iV >= k}
|
||||||
|
{m, d} := unaligned_spaced_mask_mod{wv}
|
||||||
|
mask0:= V~~make{W, m, m>>d|m<<(wv-d)}
|
||||||
|
iu:= iota{V} + V**cast_i{u8,vl-wv}; ia:= iu>=V**vl
|
||||||
|
ind_up := iu - (wV&ia)
|
||||||
|
ind_inc:= V**(wv <= vl) - ia
|
||||||
|
mask_sh:= d+d; if (mask_sh >= wv) mask_sh-= wv
|
||||||
|
# State
|
||||||
|
xv:V = V**0; o:=W**0; ind:=xv; mask:=xv
|
||||||
|
i:usz = 0; q:usz = 1
|
||||||
|
@for (r in *V~~r over j to nv) {
|
||||||
|
--q; if (q == 0) { q = wv
|
||||||
|
xv = get_perm_x{i}; ++i
|
||||||
|
ind = ind0
|
||||||
|
mask = mask0
|
||||||
|
} else {
|
||||||
|
ind = selV{ind, ind_up} + ind_inc
|
||||||
|
mask = V~~((W~~mask << (wv - mask_sh)) | (W~~mask >> mask_sh))
|
||||||
|
}
|
||||||
|
# Handle overhang here; won't fit in a single vector
|
||||||
|
b := W~~(selV{xv, ind} & mask)
|
||||||
|
r = V~~((b<<wv) - b)
|
||||||
|
po:= o; o = b>>(64-wv)
|
||||||
|
ro:= [4]u32~~vshl{po, o, 1}
|
||||||
|
r |= V~~(ro + (ro > [4]u32**0))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (wv==2) run{2}
|
|
||||||
else if (wv==3) run{3}
|
|
||||||
else if (wv==4) run{4}
|
|
||||||
else if (wv==8) run{8}
|
|
||||||
else run{tup{5,7}}
|
|
||||||
}
|
}
|
||||||
export{'si_constrep_bool', rep_const_bool{}}
|
export{'si_constrep_bool', rep_const_bool{}}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user