# uCBQN/src/singeli/src/equal.singeli
# 2022-04-25 03:06:11 +03:00
#
# 88 lines
# 2.4 KiB
# Plaintext

include './base'
include './cbqnDefs'
include './f64'
include './sse3'
include './avx'
include './avx2'
include './mask'
# swap{w,x}: exchange the values of w and x through a temporary.
# Both arguments are assigned to, so they must be assignable at the call site.
def swap{w,x} = {
  held:= x
  x = w
  w = held
}
# equal{W, X}: test whether the data at w (element type W) matches the data at
# x (element type X). Returns 1 when no mismatch is found and 0 as soon as any
# loop iteration detects a difference.
#   w, x : untyped (*u8) pointers; each branch reinterprets them as needed.
#   l    : length handed to maskedLoop (presumably the element count - confirm
#          against './mask').
#   d    : when W!=X, a nonzero d swaps w and x before comparing; when
#          W==X==i8, l is shifted left by d (presumably so wider same-typed
#          data can be compared bytewise - verify against the caller).
equal{W, X}(w:*u8, x:*u8, l:u64, d:u64) : u1 = {
# Number of X elements that fit in 256 bits.
def bulk = 256 / width{X}
# After this line the code treats w as the W-typed side and x as the X-typed side.
if (W!=X) if (d!=0) swap{w,x}
if (W==u1) {
if (X==u1) { # bitarr ≡ bitarr
# Both sides are read as [32]u8 (256 bits) per step and compared with anyneBit.
maskedLoop{256, l, {i, M} => {
cw:= load{*[32]u8 ~~ w, i}
cx:= load{*[32]u8 ~~ x, i}
if (anyneBit{cw,cx,M}) return{0}
}}
} else if (X==f64) { # bitarr ≡ f64arr
def T = [4]f64
# Local bulk shadows the outer one; 4 equals 256/width{f64}.
def bulk = 4
f0:= broadcast{T, 0.0}
f1:= broadcast{T, 1.0}
maskedLoop{bulk, l, {i, M} => {
# Byte i>>1 of w shifted right by 0 or 4: the 4 bits for this step end up in
# the low nibble (assumes i counts 4-element steps - TODO confirm).
cw:= load{*u8 ~~ w, i>>1} >> cast_i{u8, 4*(i&1)}
cx:= load{*T ~~ x, i}
# Shifting the broadcast byte left by 63,62,61,60 moves bits 0..3 into the
# top bit of lanes 0..3; blend then chooses between f0 and f1 per lane
# (presumably keyed on that top bit - verify blend's definition).
wu:= blend{f0, f1, broadcast{[4]u64, cw} << make{[4]u64,63,62,61,60}}
if (anyne{wu, cx, M}) return{0}
}}
} else { # bitarr ≡ i8/i16/i32arr
# T: vector of X elements spanning 256 bits.
def T = [256/width{X}]X
# sh moves bit 0 of every element to that element's top bit.
def sh{c} = c << (width{X}-1)
# For u8 the shift is done on u16 lanes by 7 and the result reinterpreted as T
# (presumably because no 8-bit lane shift is available - confirm).
def sh{c & X==u8} = T ~~ (to_el{u16,c}<<7)
# TODO compare with doing the comparison in vector registers
# badBits has every bit except bit 0 set in each element.
badBits:= broadcast{T, ~cast{X,1}}
maskedLoop{bulk, l, {i, M} => {
# ty_u{bulk}: presumably the unsigned integer type of `bulk` bits, giving
# one bit of w per element of cx - TODO confirm.
cw:= load{*ty_u{bulk} ~~ w, i}
cx:= load{*T ~~ x, i}
# Any bit other than bit 0 set in a (masked) element means it is neither 0
# nor 1, so it cannot equal a bit.
if (~andIsZero{M{cx}, badBits}) return{0}
# getmask presumably gathers the top bit of each element into an integer;
# that integer is compared against the bits loaded from w.
if (anyne{promote{u64,getmask{sh{cx}}}, promote{u64,cw}, M}) return{0}
}}
# NOTE(review): this bare `1` looks like it has no effect, since the function
# result is the `1` at the end of the body - confirm before removing.
1
}
} else { # everything not involving bitarrs
if (W==i8 and X==i8) l<<= d
# ww{gw, E}: vector type of E elements spanning gw bits.
def ww{gw, E} = [gw/width{E}]E
# fac: how many times wider an X element is than a W element.
def fac = width{X}/width{W}
maskedLoop{bulk, l, {i, M} => {
# TODO update this to modern mask stuff
# Loads from w at offset i*32/fac (w is *u8, so presumably a byte offset):
# a 256-bit vector when fac==1, otherwise a 128-bit vector.
cw:= load{*ww{tern{fac==1, 256, 128}, W} ~~ (w + i*32/fac)}
cx:= load{*ww{256, X} ~~ x, i}
# Convert the W-typed vector to the 256-bit X vector type before comparing.
cwc:= cvt{W, ww{256, X}, cw}
if (anyne{cwc,cx,M}) return{0}
}}
}
1
}
# Named instantiations of equal{W, X}. Each quoted name is bound to one
# (W, X) specialization (presumably the symbol name visible to the C side -
# confirm against the build).
# Bit array on the left:
'avx2_equal_1_1' = equal{u1, u1}
'avx2_equal_1_8' = equal{u1, u8}
'avx2_equal_1_16' = equal{u1, u16}
'avx2_equal_1_32' = equal{u1, u32}
'avx2_equal_1_f64' = equal{u1, f64}
# Same-type i8 comparison:
'avx2_equal_8_8' = equal{i8, i8}
# Signed integer / f64 combinations:
'avx2_equal_s8_16' = equal{i8, i16}
'avx2_equal_s8_32' = equal{i8, i32}
'avx2_equal_s16_32' = equal{i16, i32}
'avx2_equal_s8_f64' = equal{i8, f64}
'avx2_equal_s16_f64'= equal{i16, f64}
'avx2_equal_s32_f64'= equal{i32, f64}
'avx2_equal_f64_f64'= equal{f64, f64}
# Unsigned integer combinations:
'avx2_equal_u8_16' = equal{u8, u16}
'avx2_equal_u8_32' = equal{u8, u32}
'avx2_equal_u16_32' = equal{u16, u32}