uCBQN/src/singeli/src/squeeze.singeli
dzaima 92f40ddbe2 mask.singeli def renames
more bits of renames
2025-02-03 02:13:39 +02:00

161 lines
5.4 KiB
Plaintext

include './base'
include './mask'
include './cbqnDefs'
include './vecfold'
# Signaling-NaN test on raw 64-bit float patterns (generic path).
# Yields a pair {reducer, mask}: any_hom is the reduction any_sNaN should apply to mask.
# Shifting left by one drops the sign bit; a signaling NaN (exponent all ones,
# quiet bit clear, nonzero payload) then falls within
# 0xFFE0_0000_0000_0002 ..= 0xFFEF_FFFF_FFFF_FFFE.
# NOTE(review): assumes in_range_len{v, start, len} tests start <= v < start+len - confirm in its definition.
def is_sNaN{x:[_](u64)} = {
  def first = (0xFFE<<52)+2
  def count = (1<<52)-2
  def nosign = x<<1
  tup{any_hom, in_range_len{nosign, first, count}}
}
# SSE2-only variant of is_sNaN for [2]u64 (x86-64 without SSE4.2), avoiding i64 comparisons.
# Yields {any_top, mask}; any_top is named as the reducer because only the upper
# 32-bit half of each 64-bit element of mask carries a meaningful comparison result.
#   nan  - float self-comparison, set for elements that are NaN of either kind
#   qnan - 32-bit test of the quiet bit (bit 51 of the f64, i.e. bit 19 of its upper half)
# Only the quiet bit is tested: comparing the whole upper half against 0xFFF8_0000 would
# classify quiet NaNs carrying a payload in mantissa bits 32..50 as signaling, which
# disagrees with the range check used by the generic [_]u64 definition of is_sNaN.
def is_sNaN{x:[2](u64) if hasarch{'X86_64'} and not hasarch{'SSE4.2'}} = {
  def quiet = [4]u32**0x0008_0000
  def nan = unord{[2]f64~~x, [2]f64~~x}
  def qnan = re_el{u64, quiet == ([4]u32~~x & quiet)}
  tup{any_top, nan &~ qnan}
}
# Tests whether any element of the vectors xs, under mask generator M, is flagged by is_sNaN.
# Each is_sNaN call yields {reducer, mask}; the pairs are regrouped into a list of
# reducers and a list of masks, the masks are OR-ed together, masked with M and reduced.
# NOTE(review): one_val presumably requires all reducers to be identical and yields that one - confirm.
def any_sNaN{M, ...xs} = {
  def {tests, hits} = flip{each{is_sNaN, xs}}
  def test = one_val{tests}
  def merged = tree_fold{|, hits}
  test{M{merged}}
}
# Generic check: is any 64-bit element of x, under mask generator M, outside the
# range of length 1<<48 that starts at cbqn_c32Tag{}<<48?
def any_nonC32{M, x:[_](u64)} = {
  def first = cbqn_c32Tag{}<<48
  def is_c32 = in_range_len{x, first, 1<<48}
  any_hom{M{~is_c32}}
}
# x86-64 variant: compares x in 32-bit pieces against cbqn_c32Tag{}<<16 and hands the
# result, reinterpreted back to T and masked by M, to any_top.
# NOTE(review): relies on any_top reading only the top bit of each T element, so that the
# comparison of the lower 32-bit piece is ignored - confirm against any_top's definition.
def any_nonC32{M, x:T=[_]_ if hasarch{'X86_64'}} = {
  def V32 = re_el{u32, T}
  def want = V32**cast_i{u32, cbqn_c32Tag{}<<16}
  def differs = V32~~x != want
  any_top{M{T~~differs}}
}
# Unmasked two-vector variant: gathers the odd-indexed 32-bit pieces of x and y into one
# vector and compares them all against cbqn_c32Tag{}<<16 with a single anyne.
# Only k==2 and k==4 have an index table.
def any_nonC32{(mask_none), x:[k](u64), y:[k](u64)} = {
  def W = [k*2]u32
  def odd = match(k) {
    {2} => tup{1,3,5,7}
    {4} => tup{1,3,9,11,5,7,13,15} # all odd indices, in the order that vshufps can handle
  }
  def tops = shuf_ind{W~~x, W~~y, odd}
  def want = W**cast_i{u32, cbqn_c32Tag{}<<16}
  anyne{tops, want, mask_none}
}
# squeeze{vw, X, CHR, B}(x0, len): scans len elements of type X starting at x0 with
# vw-bit vectors and returns a u32 code summarising the values found.
#   vw  - vector width in bits
#   X   - element type read from x0
#   CHR - nonzero selects the character path
#   B   - nonzero adds the any_nonC32 (character path) / any_sNaN (f64 path) checks
# Return values visible in this function:
#   CHR path: 3 if any_nonC32 fires (B only); lb{width{X}}-3 if an element has bits above the
#             half-width range (non-B only); otherwise 2 / 1 / 0 for a maximum >=65536 / >=256 / smaller
#   i16/i32:  0xffff_ffff on an early exit, otherwise acc{}
#   i8:       2 if any element has a bit other than the lowest set, otherwise 0
#   f64:      0xffff_ffff ("float"), 0xffff_fffe ("not even a number", B only), otherwise acc{}
# How callers interpret acc{} is not visible here.
fn squeeze{vw, X, CHR, B}(x0:*void, len:ux) : u32 = {
  def bulk = vw / width{X} # number of X elements per vw-bit vector
  def XV = [bulk]X
  def xb = tup{XV,*X~~x0} # {vector type, element pointer} pair used as the muLoop data source
  # fold with either Max or Bitwise Or, truncating/zero-extending elements to TE
  def fold_total{TE, x:[_]T} = cast_i{TE, vfold{|, x}}
  def fold_total{TE, x:[_]T if hasarch{'AARCH64'}} = {
    if (width{T}!=64) cast_i{TE, vfold{max, x}}
    else if (width{TE}==64 and bulk==2) cast_i{TE, half{x,0} | half{x,1}}
    else vfold{max, narrow{TE, x}}
  }
  # mix_pos: clears the lowest bit of every element
  def mix_pos{x:T=[_]E} = x & (T ** ~E~~1)
  # mix_neg: mix_pos, then XOR with the reinterpreted x!=0 comparison result
  def mix_neg{x:T=[_]E} = mix_pos{x} ^ T~~(x!=T**0)
  # mix: mix_pos, then XOR with x shifted right by (element width - 1).
  # The shift amount is taken from the element type E, not from the whole vector type T.
  def mix{x:T=[_]E} = mix_pos{x} ^ (x >> (width{E}-1))
  # int_acc{T}: creates accumulator state and evaluates to the generator `me`, used as
  #   me{M, minc, maxc} - merge precomputed per-step minima/maxima
  #   me{M, vs}         - merge a tuple of vectors
  #   me{}              - collapse the accumulated state into the u32 result
  #   me{'minmax'}      - 1 if the caller should use the min/max form, 0 otherwise
  def int_acc{T} = {
    minv:= T**0
    maxv:= T**0
    # M applies the loop's mask to each candidate before it is merged
    def me{M, minc, maxc} = {
      minv = min{minv, M{minc}}
      maxv = max{maxv, M{maxc}}
    }
    def me{M, vs} = {
      minc:= zero_promote{T, tree_fold{min, vs}} # could pack pairs in v to low & high halves, but an extra min costs the same or less than an insert
      maxc:= zero_promote{T, tree_fold{max, vs}}
      me{M, minc, maxc}
    }
    # AArch64: reduce min and max to scalars first, then combine them in scalar code
    def me{if hasarch{'AARCH64'}} = {
      mint:= ty_u{vfold{min, minv}}
      maxt:= ty_u{vfold{max, maxv}} &~ 1
      cast_i{u32, tern{mint==0, maxt, max{maxt, -mint-1} | 2}}
    }
    # NOTE(review): this unconditional me{} is defined after the AARCH64-only one above; if the
    # most recently defined matching case is tried first (the generic-then-specialised order of
    # fold_total and int_acc in this function suggests so), the AARCH64 variant is never
    # selected - confirm the intended order.
    def me{} = fold_total{u32, mix_neg{minv} | mix_pos{maxv}}
    def me{'minmax'} = 1
  }
  # x86-64 without SSE4.1, i32 elements: keep a single OR-accumulator of mix-ed values
  def int_acc{T=[_]E if E==i32 and hasarch{'X86_64'} and not hasarch{'SSE4.1'}} = {
    acc:= T**0
    # returns this step's masked OR so the caller can test it for an early exit
    def me{M, vs} = {
      def curr = M{tree_fold{|, each{mix, vs}}}
      acc|= curr
      curr
    }
    def me{} = fold_total{u32, acc}
    def me{'minmax'} = 0
  }
  if (CHR) { # c16/c32/B → char
    mt:= XV**0 # unused for c16
    @muLoop{bulk, 2}(xs in xb, M in 'm' over len) {
      def orx = M{tree_fold{|, xs}} # masked OR of this step's vectors
      if (B) {
        if (any_nonC32{M, ...xs}) return{3}
      } else {
        # bad: some element has a bit set above the range of the half-width type w_h{X}
        def bad = if (hasarch{'AARCH64'}) any_bit{if (length{xs}==2) pack{...xs,1} else pack{orx,orx,1}}
        else ~and_bit_none{orx, ~XV**maxvalue{w_h{X}}}
        if (bad) return{lb{width{X}}-3}
      }
      mt|= orx
    }
    # NOTE(review): 32w2b1 presumably denotes the low-32-bit mask that strips the tag from B values - confirm
    def tot_max = fold_total{u32, if (B) mt & XV**32w2b1 else mt}
    if (X>u32 and tot_max>=65536) return{2}
    if (X>u16 and tot_max>=256) return{1}
    0
  } else if (X==i32 or X==i16) {
    def EH = w_h{X} # half-width element type; its range triggers the early exits below
    def acc = int_acc{XV}
    if (acc{'minmax'}) {
      @muLoop{bulk, 4}(xs in xb, M in 'm' over len) {
        minc:= tree_fold{min, xs}
        maxc:= tree_fold{max, xs}
        # the early exit is only evaluated when M{0}==0 (presumably the unmasked steps - confirm)
        if (M{0}==0 and any_hom{M{(minc < XV**minvalue{EH}) | (maxc > XV**maxvalue{EH})}}) return{0xffff_ffff}
        acc{M, minc, maxc}
      }
    } else {
      @muLoop{bulk, 2}(xs in xb, M in 'm' over len) {
        def mixed = acc{M, xs}
        if (M{0}==0 and any_hom{mixed > XV**maxvalue{EH}}) return{0xffff_ffff}
      }
    }
    acc{}
  } else if (X==i8) {
    @muLoop{bulk, 2}(xs in xb, M in 'm' over len) {
      # XV ** -2 has every bit except the lowest set
      if (~and_bit_none{M{tree_fold{|, xs}}, XV ** -2}) return{2}
    }
    0
  } else if (X==f64) {
    # case_B: taken once an element fails the i32 round-trip; receives the first index
    # of the current loop step as bulkCont and, when B is set, scans from element
    # bulkCont*bulk to the end for signaling NaNs
    def case_B = make_branch{tup{ux}, {bulkCont} => {
      def i0 = bulkCont*bulk
      x:= i0 + *u64~~x0
      if (B) @muLoop{bulk, 2}(xs in tup{[bulk]u64, x}, M in 'm' over len-i0) {
        if (any_sNaN{M, ...xs}) return{0xffff_fffe} # not even a number
      }
      return{0xffff_ffff} # float
    }}
    def acc = int_acc{re_el{i32, XV}}
    @muLoop{bulk, 2}(xs in xb, M in 'm' over is to len) {
      if (hasarch{'X86_64'}) {
        # narrow to i32, widen back, and compare with the original values
        def ns = each{narrow{i32,.}, xs}
        if (any_hom{M{tree_fold{|, each{{ns,x} => widen{XV,ns}!=x, ns, xs}}}}) case_B{select{is, 0}}
        acc{M, ns}
      } else {
        assert{hasarch{'AARCH64'}}
        # convert to i64, pack into one i32 vector, convert back to f64 and compare
        def k = length{xs}
        def int = each{cvt{i64,.}, xs}
        def int32 = if (k==2) pack{...int, 0} else pack{...int, [bulk]i64**0, 0}
        def wd = each{{G} => cvt{f64,G{int32}}, slice{tup{widen{[bulk]i64,.}, widen_upper}, 0, k}}
        if (~all_hom{M, tree_fold{&, each{==, wd, xs}}}) case_B{select{is, 0}}
        acc{M, tup{int32}}
      }
    }
    acc{}
  } else assert{0}
}
# Exported entry points: each symbol is squeeze instantiated at the default vector width
# (arch_defvw) for one element type X, with the CHR and B flags as listed.
def export_squeeze{name, X, CHR, B} = export{name, squeeze{arch_defvw, X, CHR, B}}
# numeric variants (CHR=0)
export_squeeze{'si_squeeze_i8', i8, 0, 0}
export_squeeze{'si_squeeze_i16', i16, 0, 0}
export_squeeze{'si_squeeze_i32', i32, 0, 0}
export_squeeze{'si_squeeze_f64', f64, 0, 0}
export_squeeze{'si_squeeze_numB', f64, 0, 1}
# character variants (CHR=1)
export_squeeze{'si_squeeze_c16', u16, 1, 0}
export_squeeze{'si_squeeze_c32', u32, 1, 0}
export_squeeze{'si_squeeze_chrB', u64, 1, 1}