full vector broadcasting via broadcast{[k*n]E, v:[k]E}

This commit is contained in:
dzaima 2025-05-01 16:15:41 +03:00
parent b1e561d7ed
commit 1d3413f6ea
5 changed files with 7 additions and 6 deletions

View File

@ -291,6 +291,7 @@ def fold_addw{x:[k]E if k <= (1<<width{E})} = fold_addw{w_d{E}, x}
def broadcast{T, v if primt{T}} = v
def broadcast{V=[_]T, v if any_num{v}} = vec_broadcast{V, if (knum{v}) v else promote{T,v}}
def broadcast{V=[k]E, v:V} = v
def make{V=[_]_, ...xs} = vec_make{V, ...xs}
def iota{V=[k]_} = make{V, ...range{k}}
def absu{a:[_]_} = ty_u{__abs{a}}

View File

@ -44,7 +44,7 @@ fn max_scan{T, up}(x:*T, len:u64) : void = {
def getsel{...x} = assert{'shuffling not supported', show{...x}}
if_inline (hasarch{'AVX2'}) {
def getsel{h:H=[16]T if width{T}==8} = {
shuf{H, pair{h,h}, .}
shuf{H, [32]T**h, .}
}
def getsel{v:V=[vl==32]T if width{T}==8} = {
def H = n_h{V}

View File

@ -1,6 +1,5 @@
def __shl{(u16)}{a:T, b} = T~~(re_el{u16,a}<<b) # for x86's lack of u8 shift
def __shr{(u16)}{a:T, b} = T~~(re_el{u16,a}>>b)
def broadcast{[(n*2)]E, x:[n]E} = pair{x, x}
def pow2_up{v, least} = __max{least, 1<<ceil_log2{v}} # least ⌈ ⌈⌾(2⊸⋆⁼) v
# make a LUT of at least nt elements in tab, to be indexed by [ni_real≥ni]u8
@ -191,8 +190,7 @@ def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and mode=='c' and E>=u16} = 0
def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and mode=='i' and E==u64 and nt>16} = 0
def lut16{tab:([16]u8), idxs:([16]u8)} = shuf{[16]u8, tab, idxs}
def lut16{tab:([16]u8), idxs:([32]u8) if hasarch{'X86_64'}} = shuf{[16]u8, pair{tab, tab}, idxs}
def lut16{tab:([16]u8), idxs:IV} = shuf{[16]u8, IV**tab, idxs}
def shuf_u8bits{inds:(*u8), ni} = 0
def shuf_u8bits{inds:(*u8), ni if has_sel} = {

View File

@ -431,8 +431,7 @@ def rep_const_bool_div8{wv, x, r, nw if has_simd} = {
}
def run8{rep_bytes} = {
i:usz = 0; while (1) { check_done{}
xh := load{*[16]u8~~(*ty_u{vl}~~x + i)}; ++i
xv := if (avx2) pair{xh, xh} else xh
xv := V**load{*[16]u8 ~~ (*ty_u{vl}~~x + i)}; ++i
xe := rep_bytes{xv}
output{(xe & mkV{1 << (iV % 8)}) > V**0}
}

View File

@ -80,6 +80,9 @@ def extract{D=[kd]E, x:X=[ks]E, i if kd<ks and int_idx{i, ks/kd}} = match (width
}
def half{x:[k]E, i} = extract{[k/2]E, x, i}
def broadcast{[dk]E, v:[k]E if k*2 == dk} = pair{v, v}
def broadcast{[dk]E, v:[k]E if width{[dk]E}==512 and width{[k]E}==128} = emit{[dk]E, intrin{[dk]E, merge{'broadcast_', vec_x{to_x{[k]E}}}}, v}
def x86_low_elts{n, x:V=[k]E} = extract{x86_vec_low{n,E}, x, 0}