Extend all Singeli Compress methods to do Where

This commit is contained in:
Marshall Lochbaum 2023-07-16 20:18:54 -04:00
parent 4415869496
commit 3bd8d1de68
2 changed files with 57 additions and 21 deletions

View File

@ -366,6 +366,7 @@ B grade_bool(B x, usz xia, bool up) {
B notx = bit_negate(incG(x));
u64* xp0 = bitarr_ptr(notx);
u64* xp1 = xp;
u64 q=xia%64; if (q) { usz e=xia/64; u64 m=((u64)1<<q)-1; xp0[e]&=m; xp1[e]&=m; }
if (!up) { u64* t=xp1; xp1=xp0; xp0=t; }
#define SI_GRADE(W) \
i##W* rp = m_tyarrv(&r, W/8, xia, t_i##W##arr); \

View File

@ -14,6 +14,18 @@ if (hasarch{'AVX2'}) {
include './avx'
include './avx2'
}
if (hasarch{'AVX512F'}) {
local def mti{s,T} = merge{'_mm512_',s,'_epi',fmtnat{elwidth{T}}}
def load{a:T, n & 512==width{eltype{T}}} = emit{eltype{T}, '_mm512_loadu_si512', a+n}
def make{T, xs & 512==width{T} & tuplen{xs}==vcount{T}} = {
def p = each{{c}=>promote{eltype{T},c},reverse{xs}}
emit{T, mti{'set',T}, ...p}
}
def broadcast{T, v & isvec{T} & 512==width{T}} = {
emit{T, mti{'set1',T}, promote{eltype{T},v}}
}
def __add{a:T,b:T & 512==width{T}} = emit{T, mti{'add',T}, a, b}
}
include './mask'
include 'util/tup'
@ -57,7 +69,6 @@ def storeu{p:T, i, v:eltype{T} & *u64==T} = emit{void, 'storeu_u64', p+i, v}
def loadu{p:T & *u64==T} = emit{eltype{T}, 'loadu_u64', p}
# Assumes w is trimmed, so the last 1 appears at index l-1
def thresh1{T} = 2
def thresh2{T} = 2
fn slash{c, T}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def bitp_get{arr, n} = (load{arr,n>>6} >> (n&63)) & 1
@ -67,15 +78,29 @@ fn slash{c, T}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
}
}
def getter{c, V, x} = {
if (c) {
i:u64 = 0
{} => { v:=load{*V~~x, i}; ++i; v }
} else {
def k = vcount{V}
i := make{V, iota{k}}
ii := V**k
{} => { v:=i; i+=ii; v }
}
}
def thresh2{T==i8 & hasarch{'X86_64'}} = 4
fn slash{c==1, T==i8 & hasarch{'X86_64'}}(w:*u64, x:*T, r:*T, l:u64, sum:u64) : void = {
fn slash{c, T==i8 & hasarch{'X86_64'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def U = [16]u8
k1 := U**1
@for_special_buffered{r,16} (w in *u16~~w, x0 in *U~~x over sum) {
def X = getter{c, U, x}
@for_special_buffered{r,16} (w in *u16~~w to sum) {
x := X{}
bm := make{U, 1<<(iota{16}%8)}
rb := make{U, replicate{8,each{bind{cast_i,u8},tup{w,w>>8}}}}
bit := rb&bm == bm # Bits of w expanded to a byte each
x := x0&bit
x &= bit
dif := k1 + bit
# Prefix sum halves of dif
@unroll (k to 3) dif += U~~([2]i64~~dif << (8<<k))
@ -85,7 +110,8 @@ fn slash{c==1, T==i8 & hasarch{'X86_64'}}(w:*u64, x:*T, r:*T, l:u64, sum:u64) :
b := k1
@unroll (k to 3) {
m := (dif & b) == b # Mask of positions to shift
x = shr{U, x&m, 1<<k} | (x&~m)
y := shr{U, x&m, 1<<k}
x = (if (c) (x&~m)|y else max{x, y})
dif = max{dif, shr{U, dif&m, 1<<k}}
b += b
}
@ -105,7 +131,6 @@ def tab{n,l} = if (n==0) tup{0} else {
c16lut:*u64 = tab{4,16}
def vgLoad{p:T, i & T == *u64} = emit{eltype{T}, 'vg_loadLUT64', p, i}
def thresh1{T==i16 & hasarch{'BMI2'}} = 8
def thresh2{T==i8 & hasarch{'BMI2'}} = 32
def thresh2{T==i16 & hasarch{'BMI2'}} = 16
fn slash{c, T & hasarch{'BMI2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
@ -136,7 +161,7 @@ fn slash{c, T & hasarch{'BMI2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : voi
}
def thresh2{T==i8 & hasarch{'AVX2'}} = 32
fn slash{c==1, T==i8 & hasarch{'AVX2'}}(w:*u64, x:*T, r:*T, l:u64, sum:u64) : void = {
fn slash{c, T==i8 & hasarch{'AVX2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def I = [32]i8
def S = [8]u32
def s8 = bind{sel,[16]u8}
@ -149,7 +174,15 @@ fn slash{c==1, T==i8 & hasarch{'AVX2'}}(w:*u64, x:*T, r:*T, l:u64, sum:u64) : vo
def ind2x2{...b} = base{4, ind4{b}}
itab := mI{flat_table{ind2x2, ... 4**iota{2}}}
@for_special_buffered{r,32} (w in *u32~~w, x in *[32]T~~x over sum) {
def from_ind = if (c) {
i:u64 = 0
{j} => { v:=load{*[32]T~~x, i}; ++i; s8{v, j} }
} else {
i := make{I, replicate{16,tup{0,16}}}
ii := I**32
{j} => { v:=i+j; i+=ii; v }
}
@for_special_buffered{r,32} (w in *u32~~w over sum) {
def step{k==1} = { # Unused, ~10% slower
bit := I~~make{[32]u8, 1<<(iota{32}%8)}
sum := I~~(s8{I~~S**w, make{I,iota{32}>>3}}&bit != bit)
@ -171,7 +204,7 @@ fn slash{c==1, T==i8 & hasarch{'AVX2'}}(w:*u64, x:*T, r:*T, l:u64, sum:u64) : vo
tup{sum+ss, max{res, s8{res & mh, io - ss}}}
}
{_,j16} := step{4}
r16 := s8{x, j16}
r16 := from_ind{j16}
store{*[16]T~~r, 0, half{r16, 0}}
store{*[16]T~~(r+popc{w&0xffff}), 0, half{r16, 1}}
@ -184,25 +217,26 @@ i64tab:*u32 = fold{{t,k} => join{each{tup,t,k+(t<<8)%(1<<32)}}, tup{0x80808080},
def thresh2{T==i32 & hasarch{'AVX2'}} = 32
def thresh2{T==i64 & hasarch{'AVX2'}} = 8
fn slash{c==1, T & hasarch{'AVX2'} & width{T}>=32}(wp:*u64, x:*T, r:*T, l:u64, sum:u64) : void = {
fn slash{c, T & hasarch{'AVX2'} & width{T}>=32}(wp:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def tw = width{T}
def V = [8]u32
def X = getter{c, V, x}
expander := make{[32]u8, merge{...each{{i}=>tup{i, ... 3**128}, iota{8}>>lb{tw/32}}}}
def tab = if (tw==32) itab else i64tab
def step{w,i} = {
def step{w} = {
pc := popc{w}
ind := load{tab, w}; def I = type{ind}
s := sel{[16]i8, V~~[width{V}/width{I}]I**ind, expander}
if (tw==64) s |= make{V, iota{8}%2}
store{*V~~r, 0, sel{V, load{*V~~x,i}, s}}
store{*V~~r, 0, sel{V, X{}, s}}
r+= pc
}
@for_special_buffered{r,8} (w in *u8~~wp over i to sum) {
@for_special_buffered{r,8} (w in *u8~~wp to sum) {
if (tw==32) {
step{w, i}
step{w}
} else {
step{w&0xf, 2*i}
step{w>>4, 2*i+1}
step{w&0xf}
step{w>>4}
}
}
}
@ -211,19 +245,20 @@ def thresh2{T==i8 & hasarch{'AVX512VBMI2'}} = 256
def thresh2{T==i16 & hasarch{'AVX512VBMI2'}} = 128
def thresh2{T==i32 & hasarch{'AVX512F'}} = 64
def thresh2{T==i64 & hasarch{'AVX512F'}} = 16
fn slash{c==1, T & hasarch{if (width{T}>=32) 'AVX512F' else 'AVX512VBMI2'}}(w:*u64, x:*T, r:*T, l:u64, sum:u64) : void = {
def f = match { {_==8}=>'8'; {_==16}=>'16'; {_==32}=>'32'; {_==64}=>'64' }
fn slash{c, T & hasarch{if (width{T}>=32) 'AVX512F' else 'AVX512VBMI2'}}(w:*u64, x:arg{c,T}, r:*T, l:u64, sum:u64) : void = {
def f = fmtnat
def wt = width{T}
def vl = 512/wt
def V = [vl]T
def X = getter{c, V, x}
def wu = max{32,vl}
def load {a:T, n & 512==width{eltype{T}}} = emit{eltype{T}, '_mm512_loadu_si512', a+n}
@for (w in *(ty_u{vl})~~w, x in *V~~x over cdiv{l,vl}) {
@for (w in *(ty_u{vl})~~w over cdiv{l,vl}) {
def I = ty_u{wu}
def emitT{O, name, ...a} = emit{O, merge{'_mm512_',name,'_epi',f{wt}}, ...a}
def to_mask{a} = emit{[vl]u1, merge{'_cvtu',f{wu},'_mask',f{vl}}, a}
m := to_mask{promote{I,w}}
c := popc{w}
x := X{}
# The compress-store instruction performs very poorly on Zen4,
# and is also a lot worse than the following on Tiger Lake
# emitT{void, 'mask_compressstoreu', r, m, x}
@ -236,7 +271,7 @@ fn slash{c==1, T & hasarch{if (width{T}>=32) 'AVX512F' else 'AVX512VBMI2'}}(w:*u
}
export{'si_1slash8' , slash{0, i8 }}
export{'si_1slash16', slash{0, i16}}; export{'si_thresh_1slash16', u64~~thresh1{i16}}
export{'si_1slash16', slash{0, i16}}; export{'si_thresh_1slash16', u64~~thresh2{i16}}
export{'si_2slash8' , slash{1, i8 }}; export{'si_thresh_2slash8' , u64~~thresh2{i8 }}
export{'si_2slash16', slash{1, i16}}; export{'si_thresh_2slash16', u64~~thresh2{i16}}
export{'si_2slash32', slash{1, i32}}; export{'si_thresh_2slash32', u64~~thresh2{i32}}