diff --git a/src/builtins/slash.c b/src/builtins/slash.c index 4c2020bd..f0e7f133 100644 --- a/src/builtins/slash.c +++ b/src/builtins/slash.c @@ -366,6 +366,7 @@ B grade_bool(B x, usz xia, bool up) { B notx = bit_negate(incG(x)); u64* xp0 = bitarr_ptr(notx); u64* xp1 = xp; + u64 q=xia%64; if (q) { usz e=xia/64; u64 m=((u64)1<promote{eltype{T},c},reverse{xs}} + emit{T, mti{'set',T}, ...p} + } + def broadcast{T, v & isvec{T} & 512==width{T}} = { + emit{T, mti{'set1',T}, promote{eltype{T},v}} + } + def __add{a:T,b:T & 512==width{T}} = emit{T, mti{'add',T}, a, b} +} include './mask' include 'util/tup' @@ -57,7 +69,6 @@ def storeu{p:T, i, v:eltype{T} & *u64==T} = emit{void, 'storeu_u64', p+i, v} def loadu{p:T & *u64==T} = emit{eltype{T}, 'loadu_u64', p} # Assumes w is trimmed, so the last 1 appears at index l-1 -def thresh1{T} = 2 def thresh2{T} = 2 fn slash{c, T}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { def bitp_get{arr, n} = (load{arr,n>>6} >> (n&63)) & 1 @@ -67,15 +78,29 @@ fn slash{c, T}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { } } +def getter{c, V, x} = { + if (c) { + i:u64 = 0 + {} => { v:=load{*V~~x, i}; ++i; v } + } else { + def k = vcount{V} + i := make{V, iota{k}} + ii := V**k + {} => { v:=i; i+=ii; v } + } +} + def thresh2{T==i8 & hasarch{'X86_64'}} = 4 -fn slash{c==1, T==i8 & hasarch{'X86_64'}}(w:*u64, x:*T, r:*T, l:u64, sum:u64) : void = { +fn slash{c, T==i8 & hasarch{'X86_64'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { def U = [16]u8 k1 := U**1 - @for_special_buffered{r,16} (w in *u16~~w, x0 in *U~~x over sum) { + def X = getter{c, U, x} + @for_special_buffered{r,16} (w in *u16~~w to sum) { + x := X{} bm := make{U, 1<<(iota{16}%8)} rb := make{U, replicate{8,each{bind{cast_i,u8},tup{w,w>>8}}}} bit := rb&bm == bm # Bits of w expanded to a byte each - x := x0&bit + x &= bit dif := k1 + bit # Prefix sum halves of dif @unroll (k to 3) dif += U~~([2]i64~~dif << (8< { v:=load{*[32]T~~x, i}; ++i; s8{v, j} } + } else { + i := make{I, replicate{16,tup{0,16}}} + ii := I**32 + {j} => { v:=i+j; i+=ii; v } + } + @for_special_buffered{r,32} (w in *u32~~w over sum) { def step{k==1} = { # Unused, ~10% slower bit := I~~make{[32]u8, 1<<(iota{32}%8)} sum := I~~(s8{I~~S**w, make{I,iota{32}>>3}}&bit != bit) @@ -171,7 +204,7 @@ fn slash{c==1, T==i8 & hasarch{'AVX2'}}(w:*u64, x:*T, r:*T, l:u64, sum:u64) : vo tup{sum+ss, max{res, s8{res & mh, io - ss}}} } {_,j16} := step{4} - r16 := s8{x, j16} + r16 := from_ind{j16} store{*[16]T~~r, 0, half{r16, 0}} store{*[16]T~~(r+popc{w&0xffff}), 0, half{r16, 1}} @@ -184,25 +217,26 @@ i64tab:*u32 = fold{{t,k} => join{each{tup,t,k+(t<<8)%(1<<32)}}, tup{0x80808080}, def thresh2{T==i32 & hasarch{'AVX2'}} = 32 def thresh2{T==i64 & hasarch{'AVX2'}} = 8 -fn slash{c==1, T & hasarch{'AVX2'} & width{T}>=32}(wp:*u64, x:*T, r:*T, l:u64, sum:u64) : void = { +fn slash{c, T & hasarch{'AVX2'} & width{T}>=32}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { def tw = width{T} def V = [8]u32 + def X = getter{c, V, x} expander := make{[32]u8, merge{...each{{i}=>tup{i, ... 3**128}, iota{8}>>lb{tw/32}}}} def tab = if (tw==32) itab else i64tab - def step{w,i} = { + def step{w} = { pc := popc{w} ind := load{tab, w}; def I = type{ind} s := sel{[16]i8, V~~[width{V}/width{I}]I**ind, expander} if (tw==64) s |= make{V, iota{8}%2} - store{*V~~r, 0, sel{V, load{*V~~x,i}, s}} + store{*V~~r, 0, sel{V, X{}, s}} r+= pc } - @for_special_buffered{r,8} (w in *u8~~wp over i to sum) { + @for_special_buffered{r,8} (w in *u8~~wp to sum) { if (tw==32) { - step{w, i} + step{w} } else { - step{w&0xf, 2*i} - step{w>>4, 2*i+1} + step{w&0xf} + step{w>>4} } } } @@ -211,19 +245,20 @@ def thresh2{T==i8 & hasarch{'AVX512VBMI2'}} = 256 def thresh2{T==i16 & hasarch{'AVX512VBMI2'}} = 128 def thresh2{T==i32 & hasarch{'AVX512F'}} = 64 def thresh2{T==i64 & hasarch{'AVX512F'}} = 16 -fn slash{c==1, T & hasarch{if (width{T}>=32) 'AVX512F' else 'AVX512VBMI2'}}(w:*u64, x:*T, r:*T, l:u64, sum:u64) : void = { - def f = match { {_==8}=>'8'; {_==16}=>'16'; {_==32}=>'32'; {_==64}=>'64' } +fn slash{c, T & hasarch{if (width{T}>=32) 'AVX512F' else 'AVX512VBMI2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = { + def f = fmtnat def wt = width{T} def vl = 512/wt def V = [vl]T + def X = getter{c, V, x} def wu = max{32,vl} - def load {a:T, n & 512==width{eltype{T}}} = emit{eltype{T}, '_mm512_loadu_si512', a+n} - @for (w in *(ty_u{vl})~~w, x in *V~~x over cdiv{l,vl}) { + @for (w in *(ty_u{vl})~~w over cdiv{l,vl}) { def I = ty_u{wu} def emitT{O, name, ...a} = emit{O, merge{'_mm512_',name,'_epi',f{wt}}, ...a} def to_mask{a} = emit{[vl]u1, merge{'_cvtu',f{wu},'_mask',f{vl}}, a} m := to_mask{promote{I,w}} c := popc{w} + x := X{} # The compress-store instruction performs very poorly on Zen4, # and is also a lot worse than the following on Tiger Lake # emitT{void, 'mask_compressstoreu', r, m, x} @@ -236,7 +271,7 @@ fn slash{c==1, T & hasarch{if (width{T}>=32) 'AVX512F' else 'AVX512VBMI2'}}(w:*u } export{'si_1slash8' , slash{0, i8 }} -export{'si_1slash16', slash{0, i16}}; export{'si_thresh_1slash16', u64~~thresh1{i16}} +export{'si_1slash16', slash{0, i16}}; export{'si_thresh_1slash16', u64~~thresh2{i16}} export{'si_2slash8' , slash{1, i8 }}; export{'si_thresh_2slash8' , u64~~thresh2{i8 }} export{'si_2slash16', slash{1, i16}}; export{'si_thresh_2slash16', u64~~thresh2{i16}} export{'si_2slash32', slash{1, i32}}; export{'si_thresh_2slash32', u64~~thresh2{i32}}