move x86 int widen into x86.singeli

+ AVX-512 defs for it & (undef|zero)_promote; + add extract{V, x:V0, k}
This commit is contained in:
dzaima 2025-04-03 01:28:16 +03:00
parent dbe97cf0e0
commit 78359afc71
6 changed files with 55 additions and 26 deletions

View File

@ -1,6 +1,3 @@
def undef_promote{T=[_]E, x:X=[_]E if w128{X} and w256{T}} = T~~emit{[32]u8, '_mm256_castsi128_si256', v2i{x}}
def zero_promote{T=[_]E, x:X=[_]E if w128{X} and w256{T}} = T~~emit{[32]u8, '_mm256_zextsi128_si256', v2i{x}}
def load{V=[_]E, ptr:*E, vl if w256{V} and vl*width{E}<=128} = undef_promote{V, load{n_h{V}, ptr, vl}}
def store{ptr:*E, x:V=[k]E, vl if w256{V} and vl*width{E}<=128} = store{ptr, half{x, 0}, vl}
@ -9,8 +6,6 @@ def rsqrtE{a:T==[8]f32} = emit{T, '_mm256_rsqrt_ps', a}
def rcpE{a:T==[8]f32} = emit{T, '_mm256_rcp_ps', a}
# conversion
def half{x:T, i if w256{T} and knum{i}} = n_h{T} ~~ emit{[8]i16, '_mm256_extracti128_si256', v2i{x}, i}
def half{x:T, 0 if w256{T}} = n_h{T} ~~ emit{[8]i16, '_mm256_castsi256_si128', v2i{x}}
def pair{a:T,b:T if w128{T}} = n_d{T} ~~ emit{[8]i32, '_mm256_setr_m128i', a, b}
def widen{T==[4]f64, x:([4]i32)} = emit{T, '_mm256_cvtepi32_pd', x}

View File

@ -50,13 +50,6 @@ def all_top{x:T if w256i{T,16}} = all_hom{[16]i16~~x < [16]i16**0}
# conversion
def widen{T==[16]u16, x:X==[16]u8} = emit{T, '_mm256_cvtepu8_epi16', x}; def widen{T==[16]i16, x:X==[16]i8} = emit{T, '_mm256_cvtepi8_epi16', x}
def widen{T==[ 8]u32, x:X==[16]u8} = emit{T, '_mm256_cvtepu8_epi32', x}; def widen{T==[ 8]i32, x:X==[16]i8} = emit{T, '_mm256_cvtepi8_epi32', x}
def widen{T==[ 8]u32, x:X==[8]u16} = emit{T, '_mm256_cvtepu16_epi32', x}; def widen{T==[ 8]i32, x:X==[8]i16} = emit{T, '_mm256_cvtepi16_epi32', x}
def widen{T==[ 4]u64, x:X==[16]u8} = emit{T, '_mm256_cvtepu8_epi64', x}; def widen{T==[ 4]i64, x:X==[16]i8} = emit{T, '_mm256_cvtepi8_epi64', x}
def widen{T==[ 4]u64, x:X==[8]u16} = emit{T, '_mm256_cvtepu16_epi64', x}; def widen{T==[ 4]i64, x:X==[8]i16} = emit{T, '_mm256_cvtepi16_epi64', x}
def widen{T==[ 4]u64, x:X==[4]u32} = emit{T, '_mm256_cvtepu32_epi64', x}; def widen{T==[ 4]i64, x:X==[4]i32} = emit{T, '_mm256_cvtepi32_epi64', x}
def narrow{T, x:X if w256i{X,32} and width{T}==8} = {
a:= packQ{x, x}
b:= packQ{a, a}

View File

@ -25,6 +25,8 @@ def isprim{T} = istype{T} and same{typekind{T}, 'primitive'}
def isptr {T} = istype{T} and same{typekind{T}, 'pointer'}
def any_num = match { {x:T}=>isprim{T}; {x} => knum{x} }
def any_int = match { {x:T}=>isint{T}; {x} => knum{x} and (x>>0) == x }
def int_idx{_, _} = 0
def int_idx{k if knum{k}, l} = (k>>0)==k and k>=0 and k<l
def elwidth{T} = width{eltype{T}}
def reinterpret{T, x:T} = x
@ -286,6 +288,7 @@ def narrow{T, x:[_]T} = x
def undef_promote{T, x:T} = x
def zero_promote{T, x:T} = x
def cvt{T, x:[_]T} = x
def extract{V=[k]E, x:V, 0} = x
def broadcast{V=[_]T, v} = vec_broadcast{V, if (knum{v}) v else promote{T,v}}
def make{V=[_]_, ...xs} = vec_make{V, ...xs}

View File

@ -22,8 +22,8 @@ def bqn_or{a, b} = (a+b)-(a*b)
# def arithChk1{(__sub), M, w:T, x:T, r:T} = tup{'any_top', M{(w^x) & (w^r)}}
def arithChk1{(__add), M, w:T=[_]E, x:T, r:T} = tup{'anyne', adds{w,x}, r}
def arithChk1{(__sub), M, w:T=[_]E, x:T, r:T} = tup{'anyne', subs{w,x}, r}
def arithChk1{(__add), M, w:T, x:T, r:T if hasarch{'X86_64'}} = tup{'any_top', M{ty_s{w>r} ^ x}}
def arithChk1{(__sub), M, w:T, x:T, r:T if hasarch{'X86_64'}} = tup{'any_top', M{ty_s{x>w} ^ r}}
def arithChk1{(__add), M, w:T, x:T, r:T if hasarch{'X86_64'} and width{T}<=256} = tup{'any_top', M{ty_s{w>r} ^ x}}
def arithChk1{(__sub), M, w:T, x:T, r:T if hasarch{'X86_64'} and width{T}<=256} = tup{'any_top', M{ty_s{x>w} ^ r}}

View File

@ -7,14 +7,6 @@ def packs{a:T,b:T if hasarch{'SSE4.1'} and T==[4]u32} = emit{[ 8]u16, '_mm_packu
def and_bit_none{x:T, y:T if hasarch{'SSE4.1'} and w128i{T}} = emit{u1, '_mm_testz_si128', x, y}
# conversion
def widen{T==[8]u16, x:X==[16]u8 if hasarch{'SSE4.1'}} = emit{T, '_mm_cvtepu8_epi16', x}; def widen{T==[8]i16, x:X if hasarch{'SSE4.1'} and X==[16]i8} = emit{T, '_mm_cvtepi8_epi16', x}
def widen{T==[4]u32, x:X==[16]u8 if hasarch{'SSE4.1'}} = emit{T, '_mm_cvtepu8_epi32', x}; def widen{T==[4]i32, x:X if hasarch{'SSE4.1'} and X==[16]i8} = emit{T, '_mm_cvtepi8_epi32', x}
def widen{T==[4]u32, x:X==[8]u16 if hasarch{'SSE4.1'}} = emit{T, '_mm_cvtepu16_epi32', x}; def widen{T==[4]i32, x:X if hasarch{'SSE4.1'} and X==[8]i16} = emit{T, '_mm_cvtepi16_epi32', x}
def widen{T==[2]u64, x:X==[16]u8 if hasarch{'SSE4.1'}} = emit{T, '_mm_cvtepu8_epi64', x}; def widen{T==[2]i64, x:X if hasarch{'SSE4.1'} and X==[16]i8} = emit{T, '_mm_cvtepi8_epi64', x}
def widen{T==[2]u64, x:X==[8]u16 if hasarch{'SSE4.1'}} = emit{T, '_mm_cvtepu16_epi64', x}; def widen{T==[2]i64, x:X if hasarch{'SSE4.1'} and X==[8]i16} = emit{T, '_mm_cvtepi16_epi64', x}
def widen{T==[2]u64, x:X==[4]u32 if hasarch{'SSE4.1'}} = emit{T, '_mm_cvtepu32_epi64', x}; def widen{T==[2]i64, x:X if hasarch{'SSE4.1'} and X==[4]i32} = emit{T, '_mm_cvtepi32_epi64', x}
def widen{T==[2]f64, x:X=[_]E if hasarch{'SSE4.1'} and w128i{X} and width{E}<32} = widen{T, widen{[4]i32, x}}
def narrow{(i8 ), x:X if hasarch{'SSE4.1'} and w128i{X,32}} = sel{[16]u8, [16]i8~~x, make{[16]i8, 0,4,8,12, 0,0,0,0, 0,0,0,0, 0,0,0,0}}
def narrow{(i16), x:X if hasarch{'SSE4.1'} and w128i{X,32}} = sel{[16]u8, [8]i16~~x, make{[16]i8, 0,1,4,5, 8,9,12,13, 0,0,0,0, 0,0,0,0}}

View File

@ -12,13 +12,49 @@ include './avx'
include './avx2'
include './avx512'
local def fmtwidth{T} = fmtnat{width{T}}
local def has_bw{V} = hasarch{match (width{V}) { {128}=>'SSE2'; {256}=>'AVX2'; {512}=>'AVX512BW' }}
local def intrin{V, ...rest} = merge{'_mm', if (width{V}==128) '' else fmtnat{width{V}}, '_', ...rest}
local def intrin_t{V=[_]E, ...rest} = intrin{V, ...rest, '_', match (E) {
local def intrin{V, ...rest} = merge{'_mm', if (width{V}==128) '' else fmtwidth{V}, '_', ...rest}
local def scal_q{q, E} = match (E) {
{(f32)} => 'ps'
{(f64)} => 'pd'
{T} => merge{'ep', quality{T}, fmtnat{width{T}}}
}}
{_} => merge{'ep', q, fmtwidth{E}}
}
local def scal{E} = scal_q{quality{E}, E}
local def intrin_t{V=[_]E, ...rest} = intrin{V, ...rest, '_', scal_q{quality{E}, E}}
local def intrin_i{V=[_]E, ...rest} = intrin{V, ...rest, '_', scal_q{'i', E}}
local def vec_s{V=[_]E} = match (E) { # e.g. ps / pd / si128
{(f32)} => 'ps'
{(f64)} => 'pd'
{_} => merge{'si', fmtwidth{V}}
}
local def vec_l{V=[_]E} = merge{match (E) { # e.g. ps128 / pd128 / si128
{(f32)} => 'ps'
{(f64)} => 'pd'
{_} => 'si'
}, fmtwidth{V}}
local def vec_x{V=[k]E} = { # e.g. i64x2 / f32x4 / f64x2
merge{if (isint{E}) 'i' else 'f', fmtwidth{E}, 'x', fmtnat{k}}
}
local def x86_vec_cvt{name, W, D=[_]E, x:X=[_]E} = emit{D, intrin{W, name, vec_l{X}, '_', vec_l{D}}, x}
def undef_promote{D=[kd]E, x:X=[ks]E if kd>ks} = x86_vec_cvt{'cast', D, D, x}
def zero_promote{D=[kd]E, x:X=[ks]E if kd>ks} = x86_vec_cvt{'zext', D, D, x}
def extract{D=[kd]E, x:X=[ks]E, i if kd<ks and int_idx{i, ks/kd}} = match (width{X}, i) {
{_, 0} => x86_vec_cvt{'cast', X, D, x}
{256, _} => emit{D, intrin{X, 'extract', if (hasarch{'AVX2'} and isint{E}) 'i' else 'f', '128_', vec_s{X}}, x, i}
{512, _} => {
def Z = if (width{E}<32) re_el{i32, D} else D
emit{D, intrin_i{re_el{eltype{Z},X}, 'extract', vec_x{Z}}, x, i}
}
}
def half{x:[k]E, i} = extract{[k/2]E, x, i}
@ -38,3 +74,13 @@ def absdiff_sum{8, a:V=[k](u8), b:V if has_bw{V}} = {
def mul_sum{2, a:V=[k](i16), b:V if has_bw{V}} = {
emit{[k/2]i32, intrin{V, 'madd_epi16'}, a, b}
}
def low_elts{n, x:V=[k]E} = extract{[__max{128/width{E},n}]E, x, 0}
def widen{D=[k]DE, x:S=[k0]SE if isint{DE} and quality{DE}==quality{SE} and DE>SE and k<=k0 and hasarch{match (width{D}) {
{128} => 'SSE4.1'
{256} => 'AVX2'
{512} => if (width{DE}<=16) 'AVX512BW' else 'AVX512F'
}}} = {
emit{D, intrin_i{D, 'cvtep', if (isunsigned{SE}) 'u' else 'i', fmtwidth{SE}}, low_elts{k, x}}
}