fix float normalization from a •bit._cast result mutating the original

dzaima 2023-05-11 18:20:33 +03:00
parent ea4da381f2
commit 03a28e4e34
5 changed files with 98 additions and 29 deletions


@@ -375,6 +375,50 @@ B count_c2(B t, B w, B x) {
   return reduceI32Width(r, wia);
 }
 
+// if nanBad and input contains a NaN, doesn't consume and returns m_f64(0)
+// otherwise, consumes and returns an array with -0 (and NaNs if !nanBad) normalized
+B asNormalized(B x, usz n, bool nanBad) {
+  f64* fp = f64any_ptr(x);
+  ux i = 0;
+  #if SINGELI_SIMD
+    i = simd_search_normalizable(fp, n);
+    if (i!=n) goto some;
+  #else
+    for (; i < n; i++) if (r_f64u(fp[i])==r_f64u(-0.0) || fp[i]!=fp[i]) goto some;
+  #endif
+  return x;
+
+  some:;
+  f64* rp;
+  B r;
+  if (TY(x)==t_f64arr && reusable(x)) {
+    rp = fp;
+    r = x;
+  } else {
+    r = m_f64arrc(&rp, x);
+    COPY_TO(rp, el_f64, 0, x, 0, i);
+  }
+
+  if (nanBad) {
+    #if SINGELI_SIMD
+      if (RARE(simd_copy_ordered(rp+i, fp+i, n-i))) goto bad;
+    #else
+      for (; i < n; i++) {
+        if (RARE(fp[i]!=fp[i])) goto bad;
+        rp[i] = fp[i]+0.0;
+      }
+    #endif
+  } else {
+    for (; i < n; i++) rp[i] = normalizeFloat(fp[i]);
+  }
+
+  if (r.u!=x.u) decG(x);
+  return r;
+
+  bad:
+  if (r.u!=x.u) mm_free(v(r));
+  return m_f64(0);
+}
+
 void search_init(void) {
   { u64* p; Arr* a=m_bitarrp(&p, 1); arr_shAtm(a); *p= 0; gc_add(enclosed_0=taga(a)); }
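
The bug this fixes: •bit._cast can hand out an f64 view of storage that another live array still uses, and normalizing -0.0/NaN in place silently rewrote that other array too. asNormalized therefore only works in place when the array is exclusively owned (TY(x)==t_f64arr && reusable(x)) and otherwise normalizes into a fresh copy. Below is a minimal standalone sketch of the same copy-unless-owned idea; the names are hypothetical, and normalize_f64 only stands in for the assumed semantics of CBQN's normalizeFloat:

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <math.h>

// Assumed semantics: collapse every NaN to one canonical NaN, and turn -0.0
// into +0.0 (x + 0.0 does that under the default rounding mode while leaving
// all other values unchanged).
static double normalize_f64(double x) {
  if (x != x) return NAN;
  return x + 0.0;
}

// Copy-unless-owned: never scribble over a buffer someone else can see.
static double* normalized_copy_if_shared(double* src, size_t n, int shared) {
  double* dst = shared ? malloc(n * sizeof(double)) : src;
  for (size_t i = 0; i < n; i++) dst[i] = normalize_f64(src[i]);
  return dst;
}

int main(void) {
  double a[2] = {0.0, 1.0}, b[2] = {-0.0, 1.0};
  printf("raw bytes equal: %d\n", memcmp(a, b, sizeof a) == 0);          // 0
  double* bn = normalized_copy_if_shared(b, 2, /*shared=*/1);
  printf("normalized bytes equal: %d\n", memcmp(a, bn, sizeof a) == 0);  // 1
  printf("original b[0] still -0.0: %d\n", signbit(b[0]) != 0);          // 1
  free(bn);
  return 0;
}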


@@ -77,16 +77,15 @@ static inline B shift_ne(B x, usz n, u8 lw, bool r0) { // consumes x
   decG(x); return r;
 }
 
-static bool canCompare64_norm(B x, usz n) {
-  u8 e = TI(x,elType);
+B asNormalized(B x, usz n, bool nanBad); // from search.c
+SHOULD_INLINE bool canCompare64_norm(B* x, void** xp, usz n) {
+  u8 e = TI(*x,elType);
   if (e == el_B) return 0;
   if (e == el_f64) {
-    f64* pf = f64any_ptr(x);
-    for (usz i = 0; i < n; i++) {
-      f64 f = pf[i];
-      if (f!=f) return 0;
-      pf[i] = f==0? 0 : f;
-    }
+    B r = asNormalized(*x, n, true);
+    if (r.u == m_f64(0).u) return 0;
+    *x = r;
+    *xp = tyany_ptr(r);
   }
   return 1;
 }
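
Note the signature change: the old canCompare64_norm normalized in place, so callers could keep the void* xv they had computed up front. Since asNormalized may now return a different array, the function takes B* and void** and refreshes both. A toy model of the hazard, with hypothetical names:

#include <stdio.h>
#include <stdlib.h>

// If the callee swaps in a fresh buffer, a data pointer the caller cached
// beforehand would keep pointing at the old allocation; both out-parameters
// must be updated together.
static int prep64(double** x, void** xv, size_t n) {
  double* fresh = malloc(n * sizeof(double));  // pretend *x wasn't reusable
  for (size_t i = 0; i < n; i++) fresh[i] = (*x)[i] + 0.0;
  *x = fresh;   // the handle follows the replacement...
  *xv = fresh;  // ...and so must the cached raw pointer
  return 1;
}

int main(void) {
  double* x = malloc(sizeof(double)); x[0] = -0.0;
  double* old = x;
  void* xv = x;  // cached early, like `void* xv = tyany_ptr(x);` in the callers
  if (prep64(&x, &xv, 1)) printf("cache refreshed: %d\n", xv == (void*)x); // 1
  free(old); free(x);
  return 0;
}
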
@@ -252,14 +251,14 @@ B memberOf_c1(B t, B x) {
     for (usz i=0; i<n; i++) { u##T j=xp[i]; rp[i]=tab[j]; tab[j]=0; } \
     decG(x); TFREE(tab); \
     return taga(cpyBitArr(r))
-  if (lw == 3) { if (n<8) { BRUTE(8); } else { LOOKUP(8); } }
-  if (lw == 4) { if (n<8) { BRUTE(16); } else { LOOKUP(16); } }
+  if (lw==3) { if (n<8) { BRUTE(8); } else { LOOKUP(8); } }
+  if (lw==4) { if (n<8) { BRUTE(16); } else { LOOKUP(16); } }
   #undef LOOKUP
   #define HASHTAB(T, W, RAD, STOP, THRESH) T* xp = (T*)xv; SELFHASHTAB( \
     T, W, RAD, STOP, \
     1, taga(cpyBitArr(r)), hash[j]=h; rp[i]=k!=h;, \
     1, THRESH, 0,,,)
-  if (lw == 5) {
+  if (lw==5) {
     if (n<12) { BRUTE(32); }
     i8* rp; B r = m_i8arrv(&rp, n);
     HASHTAB(u32, 32, 1, n/2, sz==msz? 1 : sz>=(1<<15)? 3 : 5)
@@ -285,7 +284,7 @@ B memberOf_c1(B t, B x) {
     RADIX_LOOKUP_32(1, =0)
     return taga(cpyBitArr(r));
   }
-  if (lw == 6 && canCompare64_norm(x, n)) {
+  if (lw==6 && canCompare64_norm(&x, &xv, n)) {
     if (n<20) { BRUTE(64); }
     i8* rp; B r = m_i8arrv(&rp, n);
     HASHTAB(u64, 64, 0, n, sz==msz? 0 : sz>=(1<<18)? 0 : sz>=(1<<14)? 3 : 5)
@@ -397,7 +396,7 @@ B count_c1(B t, B x) {
     RADIX_LOOKUP_32(0, ++)
     return num_squeeze(r);
   }
-  if (lw == 6 && canCompare64_norm(x, n)) {
+  if (lw==6 && canCompare64_norm(&x, &xv, n)) {
     if (n<20) { BRUTE(64); }
     i32* rp; B r = m_i32arrv(&rp, n);
     HASHTAB(u64, 64, 0, n, sz==msz? 0 : sz>=(1<<18)? 0 : sz>=(1<<14)? 3 : 5)
@@ -435,7 +434,7 @@ B indexOf_c1(B t, B x) {
   if (csz==0) goto zeroRes;
   u8 lw = cellWidthLog(x);
   void* xv = tyany_ptr(x);
-  if (lw == 0) {
+  if (lw==0) {
     B r = 1&*(u64*)xv ? bit_negate(x) : x;
     return C1(shape, r);
   }
@@ -502,7 +501,7 @@ B indexOf_c1(B t, B x) {
     HASHTAB(u32, 32, sz==msz? 0 : sz>=(1<<18)? 1 : sz>=(1<<14)? 4 : 6)
     decG(r); // Fall through
   }
-  if (lw==6 && canCompare64_norm(x, n)) {
+  if (lw==6 && canCompare64_norm(&x, &xv, n)) {
     if (n<16) { BRUTE(64); }
     i32* rp; B r = m_i32arrv(&rp, n);
     u64* xp = tyany_ptr(x);


@@ -61,6 +61,8 @@ def __gt{a:T,b:T & T==[8]f32} = f32cmpAVX{a,b,30}; def __gt{a:T,b:T & T==[4]f64}
 def __ge{a:T,b:T & T==[8]f32} = f32cmpAVX{a,b,29}; def __ge{a:T,b:T & T==[4]f64} = f64cmpAVX{a,b,29}
 def __lt{a:T,b:T & T==[8]f32} = f32cmpAVX{a,b,17}; def __lt{a:T,b:T & T==[4]f64} = f64cmpAVX{a,b,17}
 def __le{a:T,b:T & T==[8]f32} = f32cmpAVX{a,b,18}; def __le{a:T,b:T & T==[4]f64} = f64cmpAVX{a,b,18}
+def unord{a:T,b:T & T==[8]f32} = f32cmpAVX{a,b,3}
+def unord{a:T,b:T & T==[4]f64} = f64cmpAVX{a,b,3}
 
 # f32 arith
 def __add{a:T,b:T & T==[8]f32} = emit{T, '_mm256_add_ps', a, b}
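
The new unord defs expose the unordered comparison predicate. The immediates here are the standard AVX compare predicates (30, 29, 17, 18 in the surrounding lines are the ordered-quiet greater, greater-equal, less, and less-equal forms), and 3 is _CMP_UNORD_Q: a lane becomes all-ones exactly when either operand in that lane is NaN, so unord{v,v} is a per-lane NaN test. A C intrinsics equivalent would look like this (compile with -mavx):

#include <immintrin.h>

// unord_f64(v, v) yields all-ones in precisely the lanes of v holding a NaN.
static __m256  unord_f32(__m256  a, __m256  b) { return _mm256_cmp_ps(a, b, _CMP_UNORD_Q); }
static __m256d unord_f64(__m256d a, __m256d b) { return _mm256_cmp_pd(a, b, _CMP_UNORD_Q); }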


@@ -25,8 +25,7 @@ def findFirst{C, M, F, ...v1} = {
   F{...args}
 }
 
-fn search{A, E}(l:*void, e0:A, n:u64) : u64 = {
-  def e = if (A==E) e0 else cast_i{E, e0}
+def search{E, x, n:u64, OP} = {
   def bulk = arch_defvw/width{E}
   def VT = [bulk]E
   def end = makeBranch{
@@ -35,7 +34,7 @@ fn search{A, E}(l:*void, e0:A, n:u64) : u64 = {
   }
   muLoop{bulk, tern{arch_defvw>=256, 1, 2}, n, {is, M} => {
-    eq:= eachx{==, loadBatch{*E~~l, is, VT}, VT**e}
+    eq:= each{OP, loadBatch{*E~~x, is, VT}}
     if (homAny{M{tree_fold{|, eq}}}) {
       findFirst{
         {i,c} => homAny{c},
@@ -48,7 +47,31 @@ fn search{A, E}(l:*void, e0:A, n:u64) : u64 = {
   n
 }
 
-export{'simd_search_u8', search{u64, u8}}
-export{'simd_search_u16', search{u64, u16}}
-export{'simd_search_u32', search{u64, u32}}
-export{'simd_search_f64', search{f64, f64}}
+fn searchOne{A, E}(x:*void, e0:A, len:u64) : u64 = {
+  def e = if (A==E) e0 else cast_i{E, e0}
+  search{E, x, len, {c:VT} => c == VT**e}
+}
+
+def isNegZero{x:T} = to_el{u64,x} == to_el{u64, T ** -f64~~0}
+fn searchNormalizable{}(x:*f64, len:u64) : u64 = {
+  search{f64, x, len, {c:VT} => isNegZero{c} | (c!=c)}
+}
+
+fn copyOrdered{}(r:*f64, x:*f64, len:u64) : u1 = {
+  def E = f64
+  def bulk = arch_defvw/width{E}
+  def VT = [bulk]E
+  maskedLoop{bulk, len, {i, M} => {
+    c:= loadBatch{x, i, VT}
+    if (homAny{M{c!=c}}) return{1}
+    storeBatch{r, i, c + VT**0, M}
+  }}
+  0
+}
+
+export{'simd_search_u8', searchOne{u64, u8}}
+export{'simd_search_u16', searchOne{u64, u16}}
+export{'simd_search_u32', searchOne{u64, u32}}
+export{'simd_search_f64', searchOne{f64, f64}}
+export{'simd_search_normalizable', searchNormalizable{}}
+export{'simd_copy_ordered', copyOrdered{}}
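
searchNormalizable reuses the generalized search loop with a predicate that flags -0.0 and NaN lanes; the -0.0 test has to be bitwise, since -0.0 == 0.0 under float comparison, and the returned index is the length of the already-normalized prefix that asNormalized can keep as-is. A scalar C model of the two new routines, processing one element per step instead of a vector batch:

#include <stdint.h>
#include <string.h>
#include <stdbool.h>
#include <stddef.h>

// Bitwise -0.0 test, mirroring to_el{u64,x} == to_el{u64, T ** -f64~~0}.
static bool is_neg_zero(double x) {
  uint64_t u; memcpy(&u, &x, sizeof u);
  return u == 0x8000000000000000ull;
}

// Index of the first element needing normalization (-0.0 or NaN), or n if none.
static size_t search_normalizable(const double* x, size_t n) {
  for (size_t i = 0; i < n; i++)
    if (is_neg_zero(x[i]) || x[i] != x[i]) return i;
  return n;
}

// Copy with -0.0 fixed up (x + 0.0 yields +0.0), bailing out on the first NaN;
// the SIMD version does the same per batch, so r may end up partially written.
static bool copy_ordered(double* r, const double* x, size_t n) {
  for (size_t i = 0; i < n; i++) {
    if (x[i] != x[i]) return true;
    r[i] = x[i] + 0.0;
  }
  return false;
}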


@@ -3,7 +3,7 @@
 #include "hash.h"
 #include "time.h"
-
+B asNormalized(B x, usz n, bool nanBad); // from search.c
 NOINLINE u64 bqn_hashObj(B x, const u64 secret[4]) { // TODO manual separation of atom & arr probably won't be worth it when there are actually sane typed array hashing things
   if (isArr(x)) {
     usz xia = IA(x);
@@ -18,12 +18,13 @@ NOINLINE u64 bqn_hashObj(B x, const u64 secret[4]) { // TODO manual separation o
     void* data;
     u64 bytes;
     switch(xe) { default: UD;
-      case el_bit: bcl(x,xia); data = bitarr_ptr(x); bytes = (xia+7)>>3; break;
-      case el_i8: case el_c8: data = tyany_ptr(x); bytes = xia*1; break;
-      case el_i16: case el_c16: data = tyany_ptr(x); bytes = xia*2; break;
-      case el_i32: case el_c32: data = tyany_ptr(x); bytes = xia*4; break;
-      case el_f64: data = f64any_ptr(x); bytes = xia*8;
-        for (ux i = 0; i < xia; i++) ((f64*)data)[i] = normalizeFloat(((f64*)data)[i]);
+      case el_bit: bcl(x,xia); bytes = (xia+7)>>3; data = bitarr_ptr(x); break;
+      case el_i8: case el_c8: bytes = xia*1; data = tyany_ptr(x); break;
+      case el_i16: case el_c16: bytes = xia*2; data = tyany_ptr(x); break;
+      case el_i32: case el_c32: bytes = xia*4; data = tyany_ptr(x); break;
+      case el_f64: bytes = xia*8;
+        x = asNormalized(x, xia, false);
+        data = f64any_ptr(x);
         break;
       case el_B:;
         data = TALLOCP(u64, xia);
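
The hashing change exists for the same reason: bqn_hashObj hashes the raw bytes of the array data, and 0.0 and -0.0 are equal values with distinct bit patterns, so they must be normalized first. Previously that happened in place, mutating a possibly shared array; now it goes through asNormalized, which copies when needed. A toy demonstration of why byte-level hashing needs this, with FNV-1a merely standing in for the real hash function:

#include <stdio.h>
#include <stdint.h>
#include <stddef.h>

// FNV-1a: any byte-level hash sees 0.0 and -0.0 as different inputs unless
// the floats are normalized beforehand.
static uint64_t fnv1a(const void* p, size_t n) {
  const unsigned char* b = p;
  uint64_t h = 0xcbf29ce484222325ull;
  for (size_t i = 0; i < n; i++) { h ^= b[i]; h *= 0x100000001b3ull; }
  return h;
}

int main(void) {
  double pos = 0.0, neg = -0.0;
  printf("numerically equal: %d\n", pos == neg);                       // 1
  printf("same byte hash:    %d\n", fnv1a(&pos, 8) == fnv1a(&neg, 8)); // 0
  return 0;
}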