Implement scalar extension for •_bit

This commit is contained in:
Marshall Lochbaum 2022-10-11 21:36:07 -04:00
parent 022ef64f6b
commit 698c902564

View File

@ -1267,20 +1267,31 @@ B bitop2(B f, B w, B x, enum BitOp2 op, char* name) {
ow = t[0]; rw = t[1]; xw = t[2]; ww = t[3]; ow = t[0]; rw = t[1]; xw = t[2]; ww = t[3];
} }
ur xr; if (isAtm(x)) x = m_atomUnit(x);
if (!isArr(x) || (xr=RNK(x))<1) thrF("•bit._%U: 𝕩 must have rank at least 1", name); if (isAtm(w)) w = m_atomUnit(w);
if (!isArr(w) || RNK(w) != xr ) thrF("•bit._%U: 𝕨 must have rank equal to 𝕩", name); ur wr=RNK(w); usz* wsh = SH(w); u64 s = wr==0? ww : ww*(u64)wsh[wr-1];
usz* sh = SH(x); ur xr=RNK(x); usz* sh = SH(x); u64 t = xr==0? xw : xw*(u64) sh[xr-1];
usz* wsh = SH(w); bool negw = 0; // Negate 𝕨 to subtract from 𝕩
for (usz i=0; i<xr-1; i++) if (sh[i]!=wsh[i]) thrF("•bit._%U: 𝕨 and 𝕩 leading shapes must match", name); bool noextend = wr == xr && s == t;
if (wr==xr && xr==0) thrF("•bit._%U: some argument must have rank at least 1", name);
if (noextend) {
for (usz i=0; i<xr-1; i++) if (sh[i]!=wsh[i]) thrF("•bit._%U: 𝕨 and 𝕩 leading shapes must match", name);
} else {
if (wr>1 || s!=ow || xr==0) { // Need to extend 𝕩
if (xr>1 || t!=ow || wr==0) {
if (wr!=xr && wr>1 && xr>1) thrF("•bit._%U: 𝕨 and 𝕩 must have equal ranks if more than 1", name);
thrF("•bit._%U: 𝕨 or 𝕩 1-cell width must equal operation width if extended", name);
}
{ B t=w; w=x; x=t; }
{ usz t=ww; ww=xw; xw=t; }
negw=op==op_sub; if (negw) op=op_add;
t = s; xr = wr; sh = wsh;
}
}
usz rws = CTZ(rw); usz rws = CTZ(rw);
usz xws = CTZ(xw); u64 n = IA(x) << CTZ(xw);
u64 s = (u64)sh[xr-1] << xws; u64 rl = t >> rws;
if (s != ww*(u64)wsh[xr-1]) thrF("•bit._%U: 𝕨 and 𝕩 1-cell widths must match", name); if ((t & (ow-1)) || (rl<<rws != t)) thrF("•bit._%U: incompatible lengths", name);
u64 n = IA(x) << xws;
u64 rl = s >> rws;
if ((s & (ow-1)) || (rl<<rws != s)) thrF("•bit._%U: incompatible lengths", name);
if (rl>=USZ_MAX) thrF("•bit._%U: output too large", name); if (rl>=USZ_MAX) thrF("•bit._%U: output too large", name);
w = convert((CastType){ ww, 0 }, w); w = convert((CastType){ ww, 0 }, w);
@ -1292,25 +1303,48 @@ B bitop2(B f, B w, B x, enum BitOp2 op, char* name) {
u64* wp = tyany_ptr(w); u64* wp = tyany_ptr(w);
u64* xp = tyany_ptr(x); u64* xp = tyany_ptr(x);
u64* rp = tyany_ptr(r); u64* rp = tyany_ptr(r);
switch (op) { default: UD;
#define OP(O,P) case op_##O: { \ #define CASES(O,Q,P) case op_##O: \
switch(ow) { default: thrF("•bit._%U: unhandled width %s", name, ow); \
CASE(8,Q,P) CASE(16,Q,P) CASE(32,Q,P) CASE(64,Q,P) \
} break;
#define SWITCH \
switch (op) { default: UD; \
BINOP(and,&) BINOP(or,|) BINOP(xor,^) \
CASES(add,u,+) CASES(sub,u,-) CASES(mul,i,*) \
}
if (noextend) {
#define BINOP(O,P) case op_##O: { \
usz l = n/64; NOUNROLL for (usz i=0; i<l; i++) rp[i] = wp[i] P xp[i]; \ usz l = n/64; NOUNROLL for (usz i=0; i<l; i++) rp[i] = wp[i] P xp[i]; \
usz q = (-n)%64; if (q) rp[l] ^= (~(u64)0 >> q) & (rp[l]^(wp[l] P xp[l])); \ usz q = (-n)%64; if (q) rp[l] ^= (~(u64)0 >> q) & (rp[l]^(wp[l] P xp[l])); \
} break; } break;
OP(and,&) OP(or,|) OP(xor,^)
#undef OP
#define CASE(W, Q, P) case W: \ #define CASE(W, Q, P) case W: \
NOUNROLL for (usz i=0; i<n/W; i++) \ NOUNROLL for (usz i=0; i<n/W; i++) \
((Q##W*)rp)[i] = ((Q##W*)wp)[i] P ((Q##W*)xp)[i]; \ ((Q##W*)rp)[i] = ((Q##W*)wp)[i] P ((Q##W*)xp)[i]; \
break; break;
#define OP(O,Q,P) case op_##O: \ SWITCH
switch(ow) { default: thrF("•bit._%U: unhandled width %s", name, ow); \ #undef BINOP
CASE(8,Q,P) CASE(16,Q,P) CASE(32,Q,P) CASE(64,Q,P) \ #undef CASE
} else {
u64 wn; if (negw) { wn=-*wp; wp=&wn; }
#define BINOP(O,P) case op_##O: { \
if (ow>64) thrF("•bit._%U: scalar extension with width over 64 unhandled", name); \
u64 wv = *wp & (~(u64)0>>(64-ow)); \
for (usz tw=ow; tw<64; tw*=2) wv|=wv<<tw; \
usz l = n/64; NOUNROLL for (usz i=0; i<l; i++) rp[i] = wv P xp[i]; \
usz q = (-n)%64; if (q) rp[l] ^= (~(u64)0 >> q) & (rp[l]^(wv P xp[l])); \
} break; } break;
OP(add,u,+) OP(sub,u,-) OP(mul,i,*) #define CASE(W, Q, P) case W: { \
#undef OP Q##W wv = *(Q##W*)wp; \
NOUNROLL for (usz i=0; i<n/W; i++) \
((Q##W*)rp)[i] = wv P ((Q##W*)xp)[i]; \
} break;
SWITCH
#undef BINOP
#undef CASE #undef CASE
} }
#undef CASES
#undef SWITCH
set_bit_result(r, rt, xr, rl, sh); set_bit_result(r, rt, xr, rl, sh);
decG(w); decG(x); decG(w); decG(x);
return r; return r;