diff --git a/src/singeli/src/slash.singeli b/src/singeli/src/slash.singeli index 54b9dfa6..d568a1a9 100644 --- a/src/singeli/src/slash.singeli +++ b/src/singeli/src/slash.singeli @@ -46,12 +46,8 @@ def maketab{l,w} = { def top = (fold{bind{flat_table,+}, l**iota{2}} - 1)%(1<>6} >> (n&63)) & 1 @for (i to l) { @@ -134,104 +130,10 @@ def topper{T, U, k, x} = { tup{top, inc} } -def thresh{c, T & hasarch{'X86_64'} & T<=(if (c) i8 else i32)} = { - if (fast_where) 1 else 4 -} -fn slash{c, T & hasarch{'X86_64'} & T<=(if (c) i8 else i32)}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { - def U = [16]u8 - k1 := U**1 - def X = getter{c, U, x} - def {top, inctop} = topper{T, U, 16, x} - @for_special_buffered{r,16} (w in *u16~~w over i to sum) { - x := X{} - bm := make{U, 1<<(iota{16}%8)} - rb := make{U, replicate{8,each{bind{cast_i,u8},tup{w,w>>8}}}} - bit := rb&bm == bm # Bits of w expanded to a byte each - x &= bit - dif := k1 + bit - # Prefix sum halves of dif - @unroll (k to 3) dif += U~~([2]i64~~dif << (8< 8 - (extract{[8]u16~~dif, j} >> 8), tup{3,7}} - dif = U~~([2]i64~~dif << 8) - # Shift each value in x down by the corresponding one in dif - b := k1 - @unroll (k to 3) { - m := (dif & b) == b # Mask of positions to shift - y := shr{U, x&m, 1< { gen{...p}; r+=c }, ...par, pc} - if (T==i8) { # 0==tuplen{top} - def st{ins} = emit{void, ins, *[8]u8~~r, x} - each_pc{st, tup{'_mm_storel_pi','_mm_storeh_pi'}} - } else { - def st{k, v:V} = store{*V~~r, k, v} - def st{v} = if (T==i16) st{0, v} - else each{st, iota{2}, unpack{v, tupsel{1,top}}} - each_pc{st, unpack{[16]i8~~x, tupsel{0,top}}} - } - inctop{i, top} - } -} - -def thresh{c, T==i8 & hasarch{'AVX2'}} = 32 -fn slash{c, T==i8 & hasarch{'AVX2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { - def I = [32]i8 - def S = [8]u32 - def s8 = bind{sel,[16]u8} - def mI{t} = make{I, merge{t,t}} - io := mI{iota{16}} - tr4x4 := mI{join{flip{split{4,iota{16}}}}} - - sumtab := mI{flat_table{{...a}=>fold{+,a}, ... 4**iota{2}} - 4} - def ind4{b} = shiftright{indices{reverse{b}}-iota{fold{+,b}},4**0} - def ind2x2{...b} = base{4, ind4{b}} - itab := mI{flat_table{ind2x2, ... 4**iota{2}}} - - def from_ind = if (c) { - i:u64 = 0 - {j} => { v:=load{*[32]T~~x, i}; ++i; s8{v, j} } - } else { - i := make{I, replicate{16,tup{0,16}}} - ii := I**32 - {j} => { v:=i+j; i+=ii; v } - } - @for_special_buffered{r,32} (w in *u32~~w over sum) { - def step{k==1} = { # Unused, ~10% slower - bit := I~~make{[32]u8, 1<<(iota{32}%8)} - sum := I~~(s8{I~~S**w, make{I,iota{32}>>3}}&bit != bit) - tup{sum + shl{[16]u8, sum, 1}, io - sum} - } - def step{k==2} = { - wv := I~~(S**w >> make{S,4*iota{8}}) & I**0xf - sum:= s8{sumtab, wv} - ws := s8{itab, s8{wv, mI{4*(iota{16}%4)}}} - w4 := io + s8{I~~(S~~ws >> make{S,2*(iota{8}%4)}) & I**3, tr4x4} - tup{shl{[16]u8, sum, 3}, w4} - } - def step{k & k>2} = { - def h = k-1 - {sum, res} := step{h} - ik := mI{zlow{k,iota{16}} + (1<>h & 1)} - ss := s8{sum, ik} - tup{sum+ss, max{res, s8{res & mh, io - ss}}} - } - {_,j16} := step{4} - r16 := from_ind{j16} - - store{*[16]T~~r, 0, half{r16, 0}} - store{*[16]T~~(r+popc{w&0xffff}), 0, half{r16, 1}} - r += popc{w} - } -} - itab_4_16:*u64 = maketab{4,16} -def thresh{c==0, T==i8 & use_table} = 32 -def thresh{c==0, T==i16} = 16 -fn slash{c==0, T & (if (T==i8) use_table else T==i16)}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { +def thresh{c==0, T==i8 } = 32 +def thresh{c==0, T==i16} = 16 +fn slash{c==0, T & T<=i16}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { def tw = width{T} def n = 64/tw def tab = if (tw==8) itab else itab_4_16 @@ -251,9 +153,9 @@ fn slash{c==0, T & (if (T==i8) use_table else T==i16)}(w:*u64, x:arg{c,T}, r:*T, } } -def thresh{c==0, T==i16 & hasarch{'X86_64'} & use_table} = 32 -def thresh{c==0, T==i32 & hasarch{'X86_64'} & use_table} = 16 -fn slash{c==0, T & hasarch{'X86_64'} & use_table & i16<=T & T<=i32}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { +def thresh{c==0, T==i16 & hasarch{'X86_64'}} = 32 +def thresh{c==0, T==i32 & hasarch{'X86_64'}} = 16 +fn slash{c==0, T & hasarch{'X86_64'} & i16<=T & T<=i32}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { def I = [16]i8 j := I**(if (T==i16) 0 else cast_i{i8,x}) def {top, inctop} = topper{T, I, 8, x} @@ -270,9 +172,9 @@ fn slash{c==0, T & hasarch{'X86_64'} & use_table & i16<=T & T<=i32}(w:*u64, x:ar } } -def thresh{c==1, T==i8 & hasarch{'SSSE3'} & use_table} = 64 -def thresh{c==1, T==i16 & hasarch{'SSSE3'} & use_table} = 32 -fn slash{c==1, T & T<=i16 & hasarch{'SSSE3'} & use_table}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { +def thresh{c==1, T==i8 & hasarch{'SSSE3'}} = 64 +def thresh{c==1, T==i16 & hasarch{'SSSE3'}} = 32 +fn slash{c==1, T & T<=i16 & hasarch{'SSSE3'}}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { def tw = width{T} def V = [16]i8 @for_special_buffered{r,8} (w in *u8~~wp over i to sum) { @@ -288,9 +190,9 @@ fn slash{c==1, T & T<=i16 & hasarch{'SSSE3'} & use_table}(wp:*u64, x:arg{c,T}, r } i64tab:*u32 = (maketab{4,8}*2)%(1<<32) -def thresh{c, T==i32 & hasarch{'AVX2'} & use_table} = 32 -def thresh{c, T==i64 & hasarch{'AVX2'} } = 8 -fn slash{c, T & hasarch{'AVX2'} & (if (T==i32) use_table else T==i64)}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { +def thresh{c, T==i32 & hasarch{'AVX2'}} = 32 +def thresh{c, T==i64 & hasarch{'AVX2'}} = 8 +fn slash{c, T & hasarch{'AVX2'} & T>=i32}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { def tw = width{T} def V = [8]u32 expander := make{[32]u8, merge{...each{{i}=>tup{i, ... 3**128}, iota{8}>>lb{tw/32}}}}