allow specifying result type of fold_addw, x86 impls for it

This commit is contained in:
dzaima 2025-03-24 17:38:45 +02:00
parent 7237ad1abb
commit 5df3712748
3 changed files with 31 additions and 19 deletions

View File

@ -173,7 +173,7 @@ def lvec = match { {[n]T, n, (width{T})} => 1; {T, n, w} => 0 }
# base cases
def {
absu,and_bit_none,andnz,load_bits,blend,blend_units,clmul,cvt,extract,fold_addw,half,
absu,and_bit_none,andnz,load_bits,blend,blend_units,clmul,cvt,extract,half,
all_bit,any_bit,blend_bit,
all_hom,any_hom,blend_hom,hom_to_int,store_masked_hom,store_blended_hom,
all_top,any_top,blend_top,top_to_int,store_masked_top,store_blended_top,
@ -287,6 +287,8 @@ def zero_promote{T, x:T} = x
def cvt{T, x:[_]T} = x
def extract{V=[k]E, x:V, 0} = x
def store_narrow_relaxed{p:*DE, x:[k]E} = store{p, narrow{DE, x}, k}
def fold_addw{E, x:V=[_]E} = vfold{+, x}
def fold_addw{x:[k]E if k <= (1<<width{E})} = fold_addw{w_d{E}, x}
def broadcast{T, v if primt{T}} = v
def broadcast{V=[_]T, v} = vec_broadcast{V, if (knum{v}) v else promote{T,v}}

View File

@ -59,7 +59,7 @@ def clz{x:T if nvecu{T} and elwidth{T}<=32} = emit{T, ntyp{'vclz', T}, x}
def cls{x:T if nveci{T} and elwidth{T}<=32} = ty_u{T}~~emit{ty_s{T}, ntyp{'vcls', T}, x}
def fold_add {a:T=[_]E if nvec{T}} = emit{E, ntyp{'vaddv', T}, a}
def fold_addw{a:T=[_]E if nveci{T}} = emit{w_d{E}, ntyp{'vaddlv', T}, a}
def fold_addw{D, a:T=[_]E if nveci{T} and D>E} = cast_i{D, emit{w_d{E}, ntyp{'vaddlv', T}, a}}
def fold_min {a:T=[_]E if nvec{T} and ~nveci{T,64}} = emit{E, ntyp{'vminv', T}, a}
def fold_max {a:T=[_]E if nvec{T} and ~nveci{T,64}} = emit{E, ntyp{'vmaxv', T}, a}
def vfold{(__min), x:T if nvec{T} and ~nveci{T,64}} = fold_min{x}

View File

@ -1,18 +1,28 @@
# Fold associative/commutative operation across a register
def vfold{F, x:V=[_]T if w128{V} and hasarch{'X86_64'}} = {
c:= x
def EW = width{T}
if (EW<=64) c = F{c, shuf{u64, c, 1,0}}
if (EW<=32) c = F{c, shuf{u32, c, 1,0}}
if (EW<=16) c = F{c, vec_shuffle16_lo{c, tup{1,0,3,2}}}
if (EW==8) { v:=extract{[8]i16~~c, 0}; F{cast_i{T, v}, cast_i{T, v>>8}} }
else extract{c, 0}
if_inline (hasarch{'X86_64'}) {
# Fold associative/commutative operation across a register
def vfold{F, x:V=[_]T if w128{V}} = {
c:= x
def EW = width{T}
if (EW<=64) c = F{c, shuf{u64, c, 1,0}}
if (EW<=32) c = F{c, shuf{u32, c, 1,0}}
if (EW<=16) c = F{c, vec_shuffle16_lo{c, tup{1,0,3,2}}}
if (EW==8) { v:=extract{[8]i16~~c, 0}; F{cast_i{T, v}, cast_i{T, v>>8}} }
else extract{c, 0}
}
def vfold{(__add), x:V=[16]E if width{E}==8} = {
c:= x + shuf{u64, x, 1,0}
cast_i{E, extract{absdiff_sum{8, ty_u{c}, [16]u8**0}, 0}}
}
def vfold{F, x:T if w256{T}} = vfold{F, F{half{x, 0}, half{x, 1}}}
# def fold_addw{DE, x:V=[k](i8) if DE>i8 and DE<=i64} = cast_i{DE, vfold{+, mul_sum_sat{2, x, [k]u8**1}}}
def fold_addw{DE, x:V=[k](u8) if DE> u8} = cast_i{DE, vfold{+, absdiff_sum{8, x, V**0}}}
def fold_addw{DE, x:V=[k](i16) if DE>i16} = cast_i{DE, vfold{+, mul_sum{2, x, V**1}}}
def fold_addw{DE, x:V=[k](i8) if DE> i8} = DE ~~ fold_addw{ty_u{DE}, ty_u{x} ^ [k]u8 ** (1<< 7)} - k*(1<<7)
def fold_addw{DE, x:V=[k](u16) if DE>u16} = DE ~~ fold_addw{ty_s{DE}, ty_s{x} ^ [k]i16** -(1<<15)} + k*(1<<15)
def fold_addw{DE, x:V=[k](i32) if DE>i32} = DE ~~ fold_addw{ty_u{DE}, ty_u{x} ^ [k]u32** (1<<31)} - k*(1<<31)
def fold_addw{DE, x:V=[k](u32) if DE>u32} = { vfold{+, el_m{blend_units{V**0,x,1,0}} + el_m{x}>>32} }
}
def vfold{(__add), x:V=[16]E if hasarch{'X86_64'} and width{E}==8} = {
c:= x + shuf{u64, x, 1,0}
cast_i{E, extract{absdiff_sum{8, ty_u{c}, [16]u8**0}, 0}}
}
def vfold{F, x:T if w256{T} and hasarch{'X86_64'}} = vfold{F, F{half{x, 0}, half{x, 1}}}