use AVX-512 integer narrow if available

This commit is contained in:
dzaima 2025-04-03 03:00:49 +03:00
parent 528fea645a
commit 732f52630a
2 changed files with 27 additions and 17 deletions

View File

@ -6,13 +6,9 @@ local {
def suf{[_]T} = suf{T}
def pref{w} = merge{'_mm', if (w==128) '' else fmtnat{w}, '_'}
def pref{V=[_]_} = pref{width{V}}
def has512{V=[_]E} = if (width{V}==512) hasarch{'AVX512F'} else hasarch{'AVX512VL'}
def has512{V, post} = has512{V} and hasarch{merge{'AVX512', post}}
def has512e{V=[_]E if width{E}>=32} = has512{V}
def has512e{V=[_]E if width{E}<=16} = has512{V, 'BW'}
def has512e{V=[_]E, post} = has512e{V} and has512{V, post}
def has512 = x86_has512
def has512e = x86_has512e
def intrin = x86_intrin
}
local def re_mask{[l]_, sub} = {
@ -39,3 +35,7 @@ def mask_to_hom{V=[l]_, x:[l](u1)} = emit{V, merge{pref{V},'movm_',suf{V}}, x}
def sel{(ty_u{V}), x:V=[_]E, i:I==(ty_u{V}) if (if (width{E}>8) has512e{V} else has512{V, 'VBMI'})} = emit{V, merge{pref{V}, 'permutexvar_', suf{V}}, i, x}
def multishift{a:[k](u64), i:V=[(k*8)](u8) if has512{V, 'VBMI'}} = emit{V, merge{pref{V}, 'multishift_epi64_epi8'}, i, a}
def narrow{DE, x:[k]SE if isint{DE} and quality{DE}==quality{SE} and x86_has512e{[k]SE}} = {
emit{x86_vec_low{k,DE}, intrin{[k]SE, 'cvtepi', fmtwidth{SE}, '_epi', fmtwidth{DE}}, x}
}

View File

@ -5,23 +5,27 @@ include 'arch/iintrinsic/select'
def v2i{x:T=[_]E} = if(isint{E}) x else re_el{u8, x}
def v2f{x:T=[_]_} = re_el{f32, x}
def v2d{x:T=[_]_} = re_el{f64, x}
def x86_vec_low{n, E} = [__max{128/width{E},n}]E
include './sse2'
include './sse'
include './avx'
include './avx2'
include './avx512'
def x86_has512{V=[_]E} = if (width{V}==512) hasarch{'AVX512F'} else hasarch{'AVX512VL'}
def x86_has512{V, post} = has512{V} and hasarch{merge{'AVX512', post}}
local def has512 = x86_has512
local def fmtwidth{T} = fmtnat{width{T}}
def x86_has512e{V=[_]E if width{E}>=32} = has512{V}
def x86_has512e{V=[_]E if width{E}<=16} = has512{V, 'BW'}
def x86_has512e{V=[_]E, post} = x86_has512e{V} and has512{V, post}
def fmtwidth{T} = fmtnat{width{T}}
local def has_bw{V} = hasarch{match (width{V}) { {128}=>'SSE2'; {256}=>'AVX2'; {512}=>'AVX512BW' }}
local def intrin{V, ...rest} = merge{'_mm', if (width{V}==128) '' else fmtwidth{V}, '_', ...rest}
def x86_intrin = intrin
local def scal_q{q, E} = match (E) {
{(f32)} => 'ps'
{(f64)} => 'pd'
{_} => merge{'ep', q, fmtwidth{E}}
}
local def scal{E} = scal_q{quality{E}, E}
def x86_scal{E} = scal_q{quality{E}, E}
local def intrin_t{V=[_]E, ...rest} = intrin{V, ...rest, '_', scal_q{quality{E}, E}}
local def intrin_i{V=[_]E, ...rest} = intrin{V, ...rest, '_', scal_q{'i', E}}
@ -40,6 +44,12 @@ local def vec_x{V=[k]E} = { # e.g. i64x2 / f32x4 / f64x2
merge{if (isint{E}) 'i' else 'f', fmtwidth{E}, 'x', fmtnat{k}}
}
include './sse2'
include './sse'
include './avx'
include './avx2'
include './avx512'
local def x86_vec_cvt{name, W, D=[_]E, x:X=[_]E} = emit{D, intrin{W, name, vec_l{X}, '_', vec_l{D}}, x}
@ -75,12 +85,12 @@ def mul_sum{2, a:V=[k](i16), b:V if has_bw{V}} = {
emit{[k/2]i32, intrin{V, 'madd_epi16'}, a, b}
}
def low_elts{n, x:V=[k]E} = extract{[__max{128/width{E},n}]E, x, 0}
def x86_low_elts{n, x:V=[k]E} = extract{x86_vec_low{n,E}, x, 0}
def widen{D=[k]DE, x:S=[k0]SE if isint{DE} and quality{DE}==quality{SE} and DE>SE and k<=k0 and hasarch{match (width{D}) {
{128} => 'SSE4.1'
{256} => 'AVX2'
{512} => if (width{DE}<=16) 'AVX512BW' else 'AVX512F'
{512} => x86_has512e{re_el{DE,S}}
}}} = {
emit{D, intrin_i{D, 'cvtep', if (isunsigned{SE}) 'u' else 'i', fmtwidth{SE}}, low_elts{k, x}}
emit{D, intrin_i{D, 'cvtep', if (isunsigned{SE}) 'u' else 'i', fmtwidth{SE}}, x86_low_elts{k, x}}
}