diff --git a/src/builtins/compare.c b/src/builtins/compare.c index d7aadbdc..2078c943 100644 --- a/src/builtins/compare.c +++ b/src/builtins/compare.c @@ -133,6 +133,15 @@ u8 const matchFnData[] = { // for the main diagonal, amount to shift length by; DEF_EQ_U1(16, i16) DEF_EQ_U1(32, i32) DEF_EQ_U1(f64, f64) + bool equal_f64_f64_reflexive(void* wp, void* xp, ux l, u64 data) { + bool r = true; + for (ux i = 0; i < l; i++) { + f64 w = ((f64*)wp)[i]; + f64 x = ((f64*)xp)[i]; + r&= (w==x) | (w!=w & x!=x); + } + return r; + } #undef DEF_EQ_U1 #define DEF_EQ_I(NAME, S, T, INIT) \ @@ -154,15 +163,6 @@ u8 const matchFnData[] = { // for the main diagonal, amount to shift length by; #undef DEF_EQ #endif static NOINLINE bool notEq(void* a, void* b, ux l, u64 data) { assert(l>0); return false; } -static NOINLINE bool eequalFloat(void* wp, void* xp, ux l, u64 data) { - bool r = true; - for (ux i = 0; i < l; i++) { - f64 w = ((f64*)wp)[i]; - f64 x = ((f64*)xp)[i]; - r&= (w==x) | (w!=w & x!=x); - } - return r; -} #define MAKE_TABLE(NAME, F64_F64) \ INIT_GLOBAL MatchFn NAME[] = { \ @@ -170,13 +170,13 @@ INIT_GLOBAL MatchFn NAME[] = { \ F(1_8), F(8_8), F(s8_16), F(s8_32), F(s8_f64), notEq, notEq, notEq, \ F(1_16), F(s8_16), F(8_8), F(s16_32), F(s16_f64), notEq, notEq, notEq, \ F(1_32), F(s8_32), F(s16_32), F(8_8), F(s32_f64), notEq, notEq, notEq, \ - F(1_f64), F(s8_f64), F(s16_f64), F(s32_f64), F64_F64, notEq, notEq, notEq, \ + F(1_f64), F(s8_f64), F(s16_f64), F(s32_f64), F(F64_F64), notEq, notEq, notEq, \ notEq, notEq, notEq, notEq, notEq, F(8_8), F(u8_16), F(u8_32), \ notEq, notEq, notEq, notEq, notEq, F(u8_16), F(8_8), F(u16_32), \ notEq, notEq, notEq, notEq, notEq, F(u8_32), F(u16_32), F(8_8), \ }; -MAKE_TABLE(matchFns, F(f64_f64)); -MAKE_TABLE(matchFnsR, eequalFloat); +MAKE_TABLE(matchFns, f64_f64); +MAKE_TABLE(matchFnsR, f64_f64_reflexive); #undef MAKE_TABLE #undef F diff --git a/src/singeli/src/equal.singeli b/src/singeli/src/equal.singeli index 1ce0adb6..9c9e4e85 100644 --- a/src/singeli/src/equal.singeli +++ b/src/singeli/src/equal.singeli @@ -66,6 +66,38 @@ fn equal{W, X}(w:*void, x:*void, l:ux, d:u64) : u1 = { 1 } +def eq_reflexive{a:V=[k](f64), b:V} = (a==b) | ((a!=a) & (b!=b)) + +def any_ne_reflexive_qnan{M, a:V=[k](f64), b:V} = { # (a==b) | (isQNaN{a} & isQNaN{b}); assumes at least one arg isn't sNaN + def U = ty_u{V} + def t1 = U~~a & U~~b + + def ne = a != b + def t2 = ne & U**0x7FF8_0000_0000_0000 + + def andn_bit_none{M, x:T, y:T} = ~any_bit{M{x&~y}} + def andn_bit_none{M, x:T, y:T if M{0}==0 and hasarch{'X86_64'}} = andn_bit_none{x,y} + ~andn_bit_none{M, t2, t1} +} + +fn equal_reflexive{}(w:*void, x:*void, l:ux, d:u64) : u1 = { + def w = *f64~~w + def x = *f64~~x + def bulk = arch_defvw / width{f64} + def unr = if (hasarch{'AARCH64'}) 2 else 1 + def V = [bulk]f64 + @for_mu{bulk, unr}(w in tup{V,w}, x in tup{V,x}, M in 'm' over i to l) { + if (hasarch{'SSE4.1'}) { + each{{w,x} => { + if (any_ne_reflexive_qnan{M,w,x}) return{0} + }, w, x} + } else { + if (any_hom{M, ...each{{w,x} => ~eq_reflexive{w,x}, w, x}}) return{0} + } + } + 1 +} + export{'simd_equal_1_1', equal{u1, u1}} export{'simd_equal_1_8', equal{u1, u8}} export{'simd_equal_1_16', equal{u1, u16}} @@ -82,6 +114,7 @@ export{'simd_equal_s8_f64', equal{i8, f64}} export{'simd_equal_s16_f64', equal{i16, f64}} export{'simd_equal_s32_f64', equal{i32, f64}} export{'simd_equal_f64_f64', equal{f64, f64}} +export{'simd_equal_f64_f64_reflexive', equal_reflexive{}} export{'simd_equal_u8_16', equal{u8, u16}} export{'simd_equal_u8_32', equal{u8, u32}}