Remove table-less where/compress methods, as they're not competitive

This commit is contained in:
Marshall Lochbaum 2023-07-19 07:22:55 -04:00
parent 64d65ae837
commit ba837ba01b

View File

@ -46,12 +46,8 @@ def maketab{l,w} = {
def top = (fold{bind{flat_table,+}, l**iota{2}} - 1)%(1<<w)
top<<(l*w-w) | bot # Overlaps for all-1 value only
}
if (1) {
def use_table = 1
itab:*u64 = maketab{8,8}
} else {
def use_table = 0
}
# 2KB table shared by many methods
itab:*u64 = maketab{8,8}
# Recover popcount, for when POPCNT isn't there
def has_popc = hasarch{'POPCNT'}
@ -95,8 +91,8 @@ def for_special_buffered{r, write_len}{vars,begin,sum,iter} = {
}
# Assumes w is trimmed, so the last 1 appears at index l-1
def fast_where = hasarch{'X86_64'} & use_table
def thresh{c, T} = if (fast_where) 1 else 2
# Unused because an index buffer and select is faster
def thresh{c, T} = 1
fn slash{c, T}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def bitp_get{arr, n} = (load{arr,n>>6} >> (n&63)) & 1
@for (i to l) {
@ -134,104 +130,10 @@ def topper{T, U, k, x} = {
tup{top, inc}
}
def thresh{c, T & hasarch{'X86_64'} & T<=(if (c) i8 else i32)} = {
if (fast_where) 1 else 4
}
fn slash{c, T & hasarch{'X86_64'} & T<=(if (c) i8 else i32)}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def U = [16]u8
k1 := U**1
def X = getter{c, U, x}
def {top, inctop} = topper{T, U, 16, x}
@for_special_buffered{r,16} (w in *u16~~w over i to sum) {
x := X{}
bm := make{U, 1<<(iota{16}%8)}
rb := make{U, replicate{8,each{bind{cast_i,u8},tup{w,w>>8}}}}
bit := rb&bm == bm # Bits of w expanded to a byte each
x &= bit
dif := k1 + bit
# Prefix sum halves of dif
@unroll (k to 3) dif += U~~([2]i64~~dif << (8<<k))
pc := each{{j} => 8 - (extract{[8]u16~~dif, j} >> 8), tup{3,7}}
dif = U~~([2]i64~~dif << 8)
# Shift each value in x down by the corresponding one in dif
b := k1
@unroll (k to 3) {
m := (dif & b) == b # Mask of positions to shift
y := shr{U, x&m, 1<<k}
x = (if (c) (x&~m)|y else max{x, y})
dif = max{dif, shr{U, dif&m, 1<<k}}
b += b
}
def each_pc{gen, ...par} = each{{...p,c} => { gen{...p}; r+=c }, ...par, pc}
if (T==i8) { # 0==tuplen{top}
def st{ins} = emit{void, ins, *[8]u8~~r, x}
each_pc{st, tup{'_mm_storel_pi','_mm_storeh_pi'}}
} else {
def st{k, v:V} = store{*V~~r, k, v}
def st{v} = if (T==i16) st{0, v}
else each{st, iota{2}, unpack{v, tupsel{1,top}}}
each_pc{st, unpack{[16]i8~~x, tupsel{0,top}}}
}
inctop{i, top}
}
}
def thresh{c, T==i8 & hasarch{'AVX2'}} = 32
fn slash{c, T==i8 & hasarch{'AVX2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def I = [32]i8
def S = [8]u32
def s8 = bind{sel,[16]u8}
def mI{t} = make{I, merge{t,t}}
io := mI{iota{16}}
tr4x4 := mI{join{flip{split{4,iota{16}}}}}
sumtab := mI{flat_table{{...a}=>fold{+,a}, ... 4**iota{2}} - 4}
def ind4{b} = shiftright{indices{reverse{b}}-iota{fold{+,b}},4**0}
def ind2x2{...b} = base{4, ind4{b}}
itab := mI{flat_table{ind2x2, ... 4**iota{2}}}
def from_ind = if (c) {
i:u64 = 0
{j} => { v:=load{*[32]T~~x, i}; ++i; s8{v, j} }
} else {
i := make{I, replicate{16,tup{0,16}}}
ii := I**32
{j} => { v:=i+j; i+=ii; v }
}
@for_special_buffered{r,32} (w in *u32~~w over sum) {
def step{k==1} = { # Unused, ~10% slower
bit := I~~make{[32]u8, 1<<(iota{32}%8)}
sum := I~~(s8{I~~S**w, make{I,iota{32}>>3}}&bit != bit)
tup{sum + shl{[16]u8, sum, 1}, io - sum}
}
def step{k==2} = {
wv := I~~(S**w >> make{S,4*iota{8}}) & I**0xf
sum:= s8{sumtab, wv}
ws := s8{itab, s8{wv, mI{4*(iota{16}%4)}}}
w4 := io + s8{I~~(S~~ws >> make{S,2*(iota{8}%4)}) & I**3, tr4x4}
tup{shl{[16]u8, sum, 3}, w4}
}
def step{k & k>2} = {
def h = k-1
{sum, res} := step{h}
ik := mI{zlow{k,iota{16}} + (1<<h - 1)}
mh := mI{-(iota{16}>>h & 1)}
ss := s8{sum, ik}
tup{sum+ss, max{res, s8{res & mh, io - ss}}}
}
{_,j16} := step{4}
r16 := from_ind{j16}
store{*[16]T~~r, 0, half{r16, 0}}
store{*[16]T~~(r+popc{w&0xffff}), 0, half{r16, 1}}
r += popc{w}
}
}
itab_4_16:*u64 = maketab{4,16}
def thresh{c==0, T==i8 & use_table} = 32
def thresh{c==0, T==i16} = 16
fn slash{c==0, T & (if (T==i8) use_table else T==i16)}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def thresh{c==0, T==i8 } = 32
def thresh{c==0, T==i16} = 16
fn slash{c==0, T & T<=i16}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def tw = width{T}
def n = 64/tw
def tab = if (tw==8) itab else itab_4_16
@ -251,9 +153,9 @@ fn slash{c==0, T & (if (T==i8) use_table else T==i16)}(w:*u64, x:arg{c,T}, r:*T,
}
}
def thresh{c==0, T==i16 & hasarch{'X86_64'} & use_table} = 32
def thresh{c==0, T==i32 & hasarch{'X86_64'} & use_table} = 16
fn slash{c==0, T & hasarch{'X86_64'} & use_table & i16<=T & T<=i32}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def thresh{c==0, T==i16 & hasarch{'X86_64'}} = 32
def thresh{c==0, T==i32 & hasarch{'X86_64'}} = 16
fn slash{c==0, T & hasarch{'X86_64'} & i16<=T & T<=i32}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def I = [16]i8
j := I**(if (T==i16) 0 else cast_i{i8,x})
def {top, inctop} = topper{T, I, 8, x}
@ -270,9 +172,9 @@ fn slash{c==0, T & hasarch{'X86_64'} & use_table & i16<=T & T<=i32}(w:*u64, x:ar
}
}
def thresh{c==1, T==i8 & hasarch{'SSSE3'} & use_table} = 64
def thresh{c==1, T==i16 & hasarch{'SSSE3'} & use_table} = 32
fn slash{c==1, T & T<=i16 & hasarch{'SSSE3'} & use_table}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def thresh{c==1, T==i8 & hasarch{'SSSE3'}} = 64
def thresh{c==1, T==i16 & hasarch{'SSSE3'}} = 32
fn slash{c==1, T & T<=i16 & hasarch{'SSSE3'}}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def tw = width{T}
def V = [16]i8
@for_special_buffered{r,8} (w in *u8~~wp over i to sum) {
@ -288,9 +190,9 @@ fn slash{c==1, T & T<=i16 & hasarch{'SSSE3'} & use_table}(wp:*u64, x:arg{c,T}, r
}
i64tab:*u32 = (maketab{4,8}*2)%(1<<32)
def thresh{c, T==i32 & hasarch{'AVX2'} & use_table} = 32
def thresh{c, T==i64 & hasarch{'AVX2'} } = 8
fn slash{c, T & hasarch{'AVX2'} & (if (T==i32) use_table else T==i64)}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def thresh{c, T==i32 & hasarch{'AVX2'}} = 32
def thresh{c, T==i64 & hasarch{'AVX2'}} = 8
fn slash{c, T & hasarch{'AVX2'} & T>=i32}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def tw = width{T}
def V = [8]u32
expander := make{[32]u8, merge{...each{{i}=>tup{i, ... 3**128}, iota{8}>>lb{tw/32}}}}