Merge pull request #101 from mlochbaum/avx512scan

AVX-512 xor-scan
This commit is contained in:
dzaima 2023-12-30 22:07:23 +02:00 committed by GitHub
commit 9931c1756c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 70 additions and 11 deletions

@ -1 +1 @@
Subproject commit 528faaf9e2a7f4f3434365bcd91d6c18c87c4f08
Subproject commit 5f9cbd46c265491ff167a5d9377d1462539dbdd8

View File

@ -0,0 +1,33 @@
local {
def ismask{M} = if (isvec{M}) u1==eltype{M} else 0
def suf{T} = {
if (isfloat{T}) (if (width{T}==32) 'ps' else 'pd')
else merge{'epi', fmtnat{width{T}}}
}
def suf{V & isvec{V}} = suf{eltype{V}}
def pref{w} = merge{'_mm', if (w==128) '' else fmtnat{w}, '_'}
def pref{V & isvec{V}} = pref{width{V}}
}
local def re_mask{M, sub} = {
def l = vcount{M}; def w = max{32,l}
sub{fmtnat{l}, fmtnat{w}, ty_u{w}}
}
def reinterpret{M, a:T & ismask{M} & width{T}==width{M}} = {
re_mask{M, {l,w,W} => emit{M, merge{'_cvtu',w,'_mask',l}, promote{W, a}}}
}
def reinterpret{T, a:M & ismask{M} & width{T}==width{M}} = {
re_mask{M, {l,w,W} => cast_i{T, emit{W, merge{'_cvtmask',l,'_u',w}, a}}}
}
def maskStore{p:*V, m:M, v:V & ismask{M} & isvec{V} & vcount{M}==vcount{V}} = {
emit{void, merge{pref{V}, 'mask_storeu_', suf{V}}, p, m, v}
}
def topMaskReg{x:V} = emit{[vcount{V}]u1, merge{pref{V},'mov',suf{V},'_mask'}, x}
def topMask{x:T & isvec{T} & 512==width{T}} = ty_u{vcount{T}}~~topMaskReg{x}
def homMask{x:T & isvec{T} & 512==width{T}} = topMask{x}
def maskToHom{T, x:M & ismask{M} & isvec{T} & vcount{M}==vcount{T}} = {
emit{T, merge{pref{T},'movm_',suf{T}}, x}
}

View File

@ -166,7 +166,7 @@ def lvec{T, n, w & isvec{T} & vcount{T}==n & elwidth{T}==w} = 1
def {
absu,andAllZero,andnz,b_getBatch,clmul,cvt,extract,fold_addw,half,
homAll,homAny,homBlend,homMask,homMaskStore,homMaskStoreF,loadBatchBit,
loadLow,make,mulw,mulh,narrow,narrowPair,packHi,packLo,packQ,pair,pdep,
loadLow,make,maskStore,maskToHom,mulw,mulh,narrow,narrowPair,packHi,packLo,packQ,pair,pdep,
pext,popcRand,sel,shl,shr,shuf,shuf16Hi,shuf16Lo,shufHalves,storeLow,
topBlend,topMask,topMaskStore,topMaskStoreF,unord,unpackQ,vfold,widen,
zip,zipHi,zipLo

View File

@ -1,5 +1,8 @@
include './base'
include './clmul'
if (hasarch{'X86_64'}) {
if (hasarch{'PCLMUL'}) include './clmul'
if (hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}) include './avx512'
}
include './mask'
include './f64'
include './scan_common'
@ -77,7 +80,7 @@ fn scan_neq{}(p:u64, x:*u64, r:*u64, nw:u64) : void = {
p = -(r>>63) # repeat sign bit
}
}
fn clmul_scan_ne_any{..._ & hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = {
fn clmul_scan_ne_any{& hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words:u64, mark:u64) : void = {
def V = [2]u64
m := V**mark
def xor64{a, i, carry} = { # carry is 64-bit broadcasted current total
@ -98,9 +101,33 @@ fn clmul_scan_ne_any{..._ & hasarch{'PCLMUL'}}(x:*void, r:*void, init:u64, words
storeLow{rv+e, 64, clmul{loadLow{xv+e, 64}, m, 0} ^ c}
}
}
fn scan_neq{..._ & hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = {
fn scan_neq{& hasarch{'PCLMUL'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = {
clmul_scan_ne_any{}(*void~~x, *void~~r, init, nw, -(u64~~1))
}
fn scan_neq{& hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u64, nw:u64) : void = {
def V = [8]u64
def sse{a} = make{[2]u64, a, 0}
carry := sse{init}
# xor-scan on bytes
xmat := V**base{256, 1<<(8-iota{8}) - 1}
def xor8 = emit{V, '_mm512_gf2p8affine_epi64_epi8', ., xmat, 0}
# Exclusive xor-scan on one word
def exor64 = clmul{., sse{1<<64 - 2}, 0}
@for (xv in *V~~x, rv in *V~~r over i to cdiv{nw,vcount{V}}) {
x8 := xor8{xv}
hb := sse{topMask{[64]u8~~x8}}
xh := exor64{hb} # Exclusive xor of high bits (xh ^ hb for inclusive)
xc := xh ^ carry
v := x8 ^ V~~maskToHom{[64]u8, [64]u1~~extract{xc,0}}
carry = (xc ^ hb) ^ shuf{[4]u32, xh, 4b3232}
rem:= nw - 8*i
if (rem < 8) {
maskStore{*V~~r+i, [8]u1~~(~(u8~~0xff<<rem)), v}
return{}
}
rv = v
}
}
export{'si_scan_ne', scan_neq{}}
# Boolean cumulative sum

View File

@ -2,6 +2,7 @@ include './base'
if (hasarch{'X86_64'}) {
if (hasarch{'PCLMUL'}) include './clmul'
if (hasarch{'BMI2'}) include './bmi2'
if (hasarch{'AVX512F'}) include './avx512'
}
include './mask'
include 'util/tup'
@ -244,21 +245,19 @@ fn slash{c, T & hasarch{if (width{T}>=32) 'AVX512F' else 'AVX512VBMI2'}}(w:*u64,
def vl = 512/wt
def V = [vl]T
def X = getter{c, V, x}
def wu = max{32,vl}
def I = ty_u{vl}
@for (w in *(ty_u{vl})~~w over cdiv{l,vl}) {
def I = ty_u{wu}
def emitT{O, name, ...a} = emit{O, merge{'_mm512_',name,'_epi',f{wt}}, ...a}
def to_mask{a} = emit{[vl]u1, merge{'_cvtu',f{wu},'_mask',f{vl}}, a}
m := to_mask{promote{I,w}}
m := [vl]u1~~w
c := popc{w}
x := X{}
# The compress-store instruction performs very poorly on Zen4,
# and is also a lot worse than the following on Tiger Lake
# emitT{void, 'mask_compressstoreu', r, m, x}
cs := cast_i{I,promote{i64,1}<<(c%64) - 1}
if (wu==64) cs -= cast_i{I,c}>>6
if (vl==64) cs -= cast_i{I,c}>>6
v := emitT{V, 'mask_compress', x, m, x}
emitT{void, 'mask_storeu', r, to_mask{cs}, v}
maskStore{*V~~r, [vl]u1~~cs, v}
r += c
}
}