Implement scalar extension for •_bit

This commit is contained in:
Marshall Lochbaum 2022-10-11 21:36:07 -04:00
parent 022ef64f6b
commit 698c902564

View File

@ -1267,22 +1267,33 @@ B bitop2(B f, B w, B x, enum BitOp2 op, char* name) {
ow = t[0]; rw = t[1]; xw = t[2]; ww = t[3];
}
ur xr;
if (!isArr(x) || (xr=RNK(x))<1) thrF("•bit._%U: 𝕩 must have rank at least 1", name);
if (!isArr(w) || RNK(w) != xr ) thrF("•bit._%U: 𝕨 must have rank equal to 𝕩", name);
usz* sh = SH(x);
usz* wsh = SH(w);
for (usz i=0; i<xr-1; i++) if (sh[i]!=wsh[i]) thrF("•bit._%U: 𝕨 and 𝕩 leading shapes must match", name);
if (isAtm(x)) x = m_atomUnit(x);
if (isAtm(w)) w = m_atomUnit(w);
ur wr=RNK(w); usz* wsh = SH(w); u64 s = wr==0? ww : ww*(u64)wsh[wr-1];
ur xr=RNK(x); usz* sh = SH(x); u64 t = xr==0? xw : xw*(u64) sh[xr-1];
bool negw = 0; // Negate 𝕨 to subtract from 𝕩
bool noextend = wr == xr && s == t;
if (wr==xr && xr==0) thrF("•bit._%U: some argument must have rank at least 1", name);
if (noextend) {
for (usz i=0; i<xr-1; i++) if (sh[i]!=wsh[i]) thrF("•bit._%U: 𝕨 and 𝕩 leading shapes must match", name);
} else {
if (wr>1 || s!=ow || xr==0) { // Need to extend 𝕩
if (xr>1 || t!=ow || wr==0) {
if (wr!=xr && wr>1 && xr>1) thrF("•bit._%U: 𝕨 and 𝕩 must have equal ranks if more than 1", name);
thrF("•bit._%U: 𝕨 or 𝕩 1-cell width must equal operation width if extended", name);
}
{ B t=w; w=x; x=t; }
{ usz t=ww; ww=xw; xw=t; }
negw=op==op_sub; if (negw) op=op_add;
t = s; xr = wr; sh = wsh;
}
}
usz rws = CTZ(rw);
usz xws = CTZ(xw);
u64 s = (u64)sh[xr-1] << xws;
if (s != ww*(u64)wsh[xr-1]) thrF("•bit._%U: 𝕨 and 𝕩 1-cell widths must match", name);
u64 n = IA(x) << xws;
u64 rl = s >> rws;
if ((s & (ow-1)) || (rl<<rws != s)) thrF("•bit._%U: incompatible lengths", name);
u64 n = IA(x) << CTZ(xw);
u64 rl = t >> rws;
if ((t & (ow-1)) || (rl<<rws != t)) thrF("•bit._%U: incompatible lengths", name);
if (rl>=USZ_MAX) thrF("•bit._%U: output too large", name);
w = convert((CastType){ ww, 0 }, w);
x = convert((CastType){ xw, 0 }, x);
u8 rt = typeOfCast((CastType){ rw, 0 });
@ -1292,25 +1303,48 @@ B bitop2(B f, B w, B x, enum BitOp2 op, char* name) {
u64* wp = tyany_ptr(w);
u64* xp = tyany_ptr(x);
u64* rp = tyany_ptr(r);
switch (op) { default: UD;
#define OP(O,P) case op_##O: { \
#define CASES(O,Q,P) case op_##O: \
switch(ow) { default: thrF("•bit._%U: unhandled width %s", name, ow); \
CASE(8,Q,P) CASE(16,Q,P) CASE(32,Q,P) CASE(64,Q,P) \
} break;
#define SWITCH \
switch (op) { default: UD; \
BINOP(and,&) BINOP(or,|) BINOP(xor,^) \
CASES(add,u,+) CASES(sub,u,-) CASES(mul,i,*) \
}
if (noextend) {
#define BINOP(O,P) case op_##O: { \
usz l = n/64; NOUNROLL for (usz i=0; i<l; i++) rp[i] = wp[i] P xp[i]; \
usz q = (-n)%64; if (q) rp[l] ^= (~(u64)0 >> q) & (rp[l]^(wp[l] P xp[l])); \
} break;
OP(and,&) OP(or,|) OP(xor,^)
#undef OP
#define CASE(W, Q, P) case W: \
NOUNROLL for (usz i=0; i<n/W; i++) \
((Q##W*)rp)[i] = ((Q##W*)wp)[i] P ((Q##W*)xp)[i]; \
break;
#define OP(O,Q,P) case op_##O: \
switch(ow) { default: thrF("•bit._%U: unhandled width %s", name, ow); \
CASE(8,Q,P) CASE(16,Q,P) CASE(32,Q,P) CASE(64,Q,P) \
SWITCH
#undef BINOP
#undef CASE
} else {
u64 wn; if (negw) { wn=-*wp; wp=&wn; }
#define BINOP(O,P) case op_##O: { \
if (ow>64) thrF("•bit._%U: scalar extension with width over 64 unhandled", name); \
u64 wv = *wp & (~(u64)0>>(64-ow)); \
for (usz tw=ow; tw<64; tw*=2) wv|=wv<<tw; \
usz l = n/64; NOUNROLL for (usz i=0; i<l; i++) rp[i] = wv P xp[i]; \
usz q = (-n)%64; if (q) rp[l] ^= (~(u64)0 >> q) & (rp[l]^(wv P xp[l])); \
} break;
OP(add,u,+) OP(sub,u,-) OP(mul,i,*)
#undef OP
#define CASE(W, Q, P) case W: { \
Q##W wv = *(Q##W*)wp; \
NOUNROLL for (usz i=0; i<n/W; i++) \
((Q##W*)rp)[i] = wv P ((Q##W*)xp)[i]; \
} break;
SWITCH
#undef BINOP
#undef CASE
}
#undef CASES
#undef SWITCH
set_bit_result(r, rt, xr, rl, sh);
decG(w); decG(x);
return r;