singeli cast changes

This commit is contained in:
dzaima 2022-02-25 18:22:26 +02:00
parent de4914991a
commit 7af80e9e3c

View File

@ -1,10 +1,4 @@
# ugh the intrinsics have their own broken type system
def r_d2i{R, a:T} = emit{R, '_mm256_castpd_si256', a}
def r_f2i{R, a:T} = emit{R, '_mm256_castps_si256', a}
def r_i2d{a:T} = emit{[4]f64, '_mm256_castsi256_pd', a}
def r_f2d{a:T} = emit{[4]f64, '_mm256_castps_pd', a}
def r_d2f{a:T} = emit{[8]f32, '_mm256_castpd_ps', a}
def r_i2f{a:T} = emit{[8]f32, '_mm256_castsi256_ps', a}
# various utilities
def isunsigned{T} = isint{T} & ~issigned{T}
@ -18,13 +12,10 @@ def w256{T} = width{T}==256
def isintv{T,w} = isintv{T} & (width{eltype{T}}==w)
def cast_vp{T, x & w256{T}} = emit{*T, '(void*)', x}
def cast_v{R, x:S & w256{R} & w256{S} & isintv{S} & isintv{R}} = emit{R, '', x}
def cast_v{R, x:S & w256{R} & w256{S} & isf64v{S} & isintv{R}} = r_d2i{R, x}
def cast_v{R, x:S & w256{R} & w256{S} & isf32v{S} & isintv{R}} = r_f2i{R, x}
def cast_v{R, x:S & w256{R} & w256{S} & isf64v{S} & isf32v{R}} = r_d2f{x}
def cast_v{R, x:S & w256{R} & w256{S} & isintv{S} & isf32v{R}} = r_i2f{x}
def cast_v{R, x:S & w256{R} & w256{S} & isf32v{S} & isf64v{R}} = r_f2d{x}
def cast_v{R, x:S & w256{R} & w256{S} & isintv{S} & isf64v{R}} = r_i2d{x}
def cast_v{R, x:S & w256{R} & isintv{R} & w256{S}} = emit{R, '(__m256i)', x}
def cast_v{R, x:S & w256{R} & isf32v{R} & w256{S}} = emit{R, '(__m256)', x}
def cast_v{R, x:S & w256{R} & isf64v{R} & w256{S}} = emit{R, '(__m256d)', x}
def ty_vu{T & w256{T} & issignedv{T}} = [vcount{T}](ty_iu{eltype{T}})
def ty_vs{T & w256{T} & isunsignedv{T}} = [vcount{T}](ty_is{eltype{T}})
def forv{T & w256{T}} = forc{{v}=>cast_vp{T,v}}
@ -54,14 +45,14 @@ def make{T==[8]i32,a,b,c,d,e,f,g,h} = emit{T,'_mm256_set_epi32',ext{i32,h},ext{i
def make{T==[16]i16,a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p} = emit{T,'_mm256_set_epi16', ext{i16,p},ext{i16,o},ext{i16,n},ext{i16,m},ext{i16,l},ext{i16,k},ext{i16,j},ext{i16,i},ext{i16,h},ext{i16,g},ext{i16,f},ext{i16,e},ext{i16,d},ext{i16,c},ext{i16,b},ext{i16,a}}
def make{T==[32]i8,a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,A,B,C,D,E,F,G,H,I,J,K,L,M,N,O,P} = emit{T,'_mm256_set_epi8', ext{i8,P},ext{i8,O},ext{i8,N},ext{i8,M},ext{i8,L},ext{i8,K},ext{i8,J},ext{i8,I},ext{i8,H},ext{i8,G},ext{i8,F},ext{i8,E},ext{i8,D},ext{i8,C},ext{i8,B},ext{i8,A}, ext{i8,p},ext{i8,o},ext{i8,n},ext{i8,m},ext{i8,l},ext{i8,k},ext{i8,j},ext{i8,i},ext{i8,h},ext{i8,g},ext{i8,f},ext{i8,e},ext{i8,d},ext{i8,c},ext{i8,b},ext{i8,a}}
def __xor{a:T, b:T & w256{T} & isintv{T}} = r_f2i{T, emit{[8]f32, '_mm256_xor_ps', r_i2f{a}, r_i2f{b}}}
def __and{a:T, b:T & w256{T} & isintv{T}} = r_f2i{T, emit{[8]f32, '_mm256_and_ps', r_i2f{a}, r_i2f{b}}}
def __or {a:T, b:T & w256{T} & isintv{T}} = r_f2i{T, emit{[8]f32, '_mm256_or_ps', r_i2f{a}, r_i2f{b}}}
def __xor{a:T, b:T & w256{T} & isintv{T}} = cast_v{T, emit{[8]f32, '_mm256_xor_ps', cast_v{[8]f32, a}, cast_v{[8]f32, b}}}
def __and{a:T, b:T & w256{T} & isintv{T}} = cast_v{T, emit{[8]f32, '_mm256_and_ps', cast_v{[8]f32, a}, cast_v{[8]f32, b}}}
def __or {a:T, b:T & w256{T} & isintv{T}} = cast_v{T, emit{[8]f32, '_mm256_or_ps', cast_v{[8]f32, a}, cast_v{[8]f32, b}}}
def __not{a:T & w256{T} & isunsignedv{T}} = a ^ broadcast{T, ~cast{eltype{T},0}}
# f64 comparison
def f64cmpAVX{a,b,n} = r_d2i{[4]u64, emit{[4]f64, '_mm256_cmp_pd', a, b, n}}
def f64cmpAVX{a,b,n} = cast_v{[4]u64, emit{[4]f64, '_mm256_cmp_pd', a, b, n}}
def __eq{a:T,b:T & T==[4]f64} = f64cmpAVX{a,b, 0}
def __ne{a:T,b:T & T==[4]f64} = f64cmpAVX{a,b, 4}
def __gt{a:T,b:T & T==[4]f64} = f64cmpAVX{a,b,30}
@ -77,7 +68,7 @@ def __div{a:T,b:T & T==[8]f32} = emit{T, '_mm256_div_ps', a, b}
def max{a:T,b:T & T==[8]f32} = emit{T, '_mm256_max_ps', a, b}
def min{a:T,b:T & T==[8]f32} = emit{T, '_mm256_min_ps', a, b}
def sqrt{a:T,b:T & T==[8]f32} = emit{T, '_mm256_sqrt_ps', a, b}
def abs{a:[8]f32} = emit{[8]f32, '_mm256_and_ps', a, r_i2f{broadcast{[8]u32, 0x7FFFFFFF}}}
def abs{a:[8]f32} = emit{[8]f32, '_mm256_and_ps', a, cast_v{[8]f32, broadcast{[8]u32, 0x7FFFFFFF}}}
def floor{a:[8]f32} = emit{[8]f32, '_mm256_floor_ps', a}
def ceil{a:[8]f32} = emit{[8]f32, '_mm256_ceil_ps', a}
@ -89,12 +80,12 @@ def __div{a:T,b:T & T==[4]f64} = emit{T, '_mm256_div_pd', a, b}
def max{a:T,b:T & T==[4]f64} = emit{T, '_mm256_max_pd', a, b}
def min{a:T,b:T & T==[4]f64} = emit{T, '_mm256_min_pd', a, b}
def sqrt{a:T,b:T & T==[4]f64} = emit{T, '_mm256_sqrt_pd', a, b}
def abs{a:[4]f64} = emit{[4]f64, '_mm256_and_pd', a, r_i2d{broadcast{[4]u64, (cast{u64,1}<<63)-1}}}
def abs{a:[4]f64} = emit{[4]f64, '_mm256_and_pd', a, cast_v{[4]f64, broadcast{[4]u64, (cast{u64,1}<<63)-1}}}
def floor{a:[4]f64} = emit{[4]f64, '_mm256_floor_pd', a}
def ceil{a:[4]f64} = emit{[4]f64, '_mm256_ceil_pd', a}
def getmask{x:T & w256{T} & 32==width{eltype{T}}} = emit{u8, '_mm256_movemask_ps', r_i2f{x}}
def getmask{x:T & w256{T} & 64==width{eltype{T}}} = emit{u8, '_mm256_movemask_pd', r_i2d{x}}
def getmask{x:T & w256{T} & 32==width{eltype{T}}} = emit{u8, '_mm256_movemask_ps', cast_v{[8]f32, x}}
def getmask{x:T & w256{T} & 64==width{eltype{T}}} = emit{u8, '_mm256_movemask_pd', cast_v{[4]f64, x}}
def any{x:T & w256{T} & isintv{T}} = getmask{x}!=0 # assumes elements of x all have equal bits (avx2 utilizes this for 16 bits)
def anyneg{x:T & w256{T} & issignedv{T}} = getmask{x}!=0