Merge pull request #22 from mlochbaum/rank2

Rank2
This commit is contained in:
dzaima 2022-05-30 01:28:39 +03:00 committed by GitHub
commit b2b0e4f92a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 156 additions and 26 deletions

View File

@ -210,40 +210,45 @@ static B m2c2(B t, B f, B g, B w, B x) { // consumes w,x
return r;
}
static f64 req_whole(f64 f) {
if (floor(f)!=f) thrM("⎉: 𝕘 was a fractional number");
return f;
}
static usz check_rank_vec(B g) {
if (!isArr(g)) thrM("⎉: Invalid 𝔾 result");
usz gia = a(g)->ia;
if (!(gia>=1 && gia<=3)) thrM("⎉: 𝔾 result must have 1 to 3 elements");
SGetU(g)
if (TI(g,elType)>=el_f64) for (i32 i = 0; i < gia; i++) req_whole(o2f(GetU(g,i)));
return gia;
}
static ur cell_rank(f64 r, f64 k) { // ⎉k over arg rank r
return k<0? (k+r<0? 0 : k+r) : (k>r? r : k);
}
B rank_c1(Md2D* d, B x) { B f = d->f; B g = d->g;
f64 kf;
bool gf = isFun(g);
if (RARE(gf)) g = c1(g, inc(x));
if (LIKELY(isNum(g))) {
kf = o2fu(g);
} else if (isArr(g)) {
usz gia = a(g)->ia;
if (!(gia>=1 && gia<=3)) thrM("⎉: 𝔾 result must have 1 to 3 elements");
SGetU(g)
if (!elNum(TI(g,elType))) for (i32 i = 0; i < gia; i++) o2f(GetU(g,i));
kf = GetU(g, gia==2).f;
} else thrM("⎉: Invalid 𝔾 result");
kf = req_whole(o2fu(g));
} else {
usz gia = check_rank_vec(g);
SGetU(g); kf = GetU(g, gia==2).f;
}
if (gf) dec(g);
i32 k = kf;
if (isAtm(x) || rnk(x)==0) {
if (floor(kf)!=kf) thrM("⎉: 𝕘 was a fractional number");
B r = c1(f, x);
return isAtm(r)? m_atomUnit(r) : r;
}
i32 xr = rnk(x);
usz* xsh = a(x)->sh;
if (k!=kf) {
if (floor(kf)!=kf) thrM("⎉: 𝕘 was a fractional number");
k = kf>0? 0 : xr;
} else {
k = k<0? (k+xr<0? xr : xr-(k+xr)) : (k>xr? 0 : xr-k);
}
ur cr = cell_rank(xr, kf);
i32 k = xr - cr;
if (Q_BI(f,lt) && a(x)->ia!=0 && rnk(x)>1) return toKCells(x, k);
usz* xsh = a(x)->sh;
usz cam = 1; for (usz i = 0; i < k; i++) cam*= xsh[i];
usz csz = 1; for (usz i = k; i < xr; i++) csz*= xsh[i];
ur cr = xr-k;
ShArr* csh;
if (cr>1) {
csh = m_shArr(cr);
@ -268,8 +273,133 @@ B rank_c1(Md2D* d, B x) { B f = d->f; B g = d->g;
return bqn_merge(HARR_O(r).b);
}
extern B rt_rank;
B rank_c2(Md2D* d, B w, B x) { B f = d->f; B g = d->g; // TODO
return m2c2(rt_rank, f, g, w, x);
B rank_c2(Md2D* d, B w, B x) { B f = d->f; B g = d->g;
f64 wf, xf;
bool gf = isFun(g);
if (RARE(gf)) g = c2(g, inc(w), inc(x));
if (LIKELY(isNum(g))) {
wf = xf = req_whole(o2fu(g));
} else {
usz gia = check_rank_vec(g);
SGetU(g);
wf = GetU(g, gia<2?0:gia-2).f;
xf = GetU(g, gia-1).f;
}
ur wr = isAtm(w) ? 0 : rnk(w); ur wc = cell_rank(wr, wf);
ur xr = isAtm(x) ? 0 : rnk(x); ur xc = cell_rank(xr, xf);
B r;
if (wr == wc) {
if (xr == xc) {
if (gf) dec(g);
r = c2(f, w, x);
return isAtm(r)? m_atomUnit(r) : r;
} else {
i32 k = xr - xc;
usz* xsh = a(x)->sh;
usz cam = 1; for (usz i = 0; i < k; i++) cam*= xsh[i];
usz csz = 1; for (usz i = k; i < xr; i++) csz*= xsh[i];
if (cam == 0) { return m2c2(rt_rank, f, g, w, x); } // TODO
ShArr* csh;
if (xc>1) { csh=m_shArr(xc); shcpy(csh->a, xsh+k, xc); }
BSS2A slice = TI(x,slice);
M_HARR(r, cam);
usz p = 0;
for (usz i = 0; i < cam; i++) {
Arr* s = slice(inc(x), p, csz); arr_shSetI(s, xc, csh);
HARR_ADD(r, i, c2(f, inc(w), taga(s)));
p+= csz;
}
if (xc>1) ptr_dec(csh);
usz* rsh = HARR_FA(r, k);
if (k>1) shcpy(rsh, xsh, k);
dec(w); decG(x); r = HARR_O(r).b;
}
} else if (xr == xc) {
i32 k = wr - wc;
usz* wsh = a(w)->sh;
usz cam = 1; for (usz i = 0; i < k; i++) cam*= wsh[i];
usz csz = 1; for (usz i = k; i < wr; i++) csz*= wsh[i];
if (cam == 0) { return m2c2(rt_rank, f, g, w, x); } // TODO
ShArr* csh;
if (wc>1) { csh=m_shArr(wc); shcpy(csh->a, wsh+k, wc); }
BSS2A slice = TI(w,slice);
M_HARR(r, cam);
usz p = 0;
for (usz i = 0; i < cam; i++) {
Arr* s = slice(inc(w), p, csz); arr_shSetI(s, wc, csh);
HARR_ADD(r, i, c2(f, taga(s), inc(x)));
p+= csz;
}
if (wc>1) ptr_dec(csh);
usz* rsh = HARR_FA(r, k);
if (k>1) shcpy(rsh, wsh, k);
decG(w); dec(x); r = HARR_O(r).b;
} else {
i32 wk = wr - wc; usz* wsh = a(w)->sh;
i32 xk = xr - xc; usz* xsh = a(x)->sh;
i32 k=wk, zk=xk; if (k>zk) { i32 t=k; k=zk; zk=t; }
usz* zsh = wk>xk? wsh : xsh;
usz cam = 1; for (usz i = 0; i < k; i++) {
usz wl = wsh[i], xl = xsh[i];
if (wl != xl) thrF("⎉: Argument frames don't agree (%H ≡ ≢𝕨, %H ≡ ≢𝕩, common frame of %s axes)", w, x, k);
cam*= wsh[i];
}
usz ext = 1; for (usz i = k; i < zk; i++) ext*= zsh[i];
usz wsz = 1; for (usz i = wk; i < wr; i++) wsz*= wsh[i];
usz xsz = 1; for (usz i = xk; i < xr; i++) xsz*= xsh[i];
cam *= ext;
if (cam == 0) { return m2c2(rt_rank, f, g, w, x); } // TODO
ShArr* wcs; if (wc>1) { wcs=m_shArr(wc); shcpy(wcs->a, wsh+wk, wc); }
ShArr* xcs; if (xc>1) { xcs=m_shArr(xc); shcpy(xcs->a, xsh+xk, xc); }
BSS2A wslice = TI(w,slice);
BSS2A xslice = TI(x,slice);
M_HARR(r, cam);
usz wp = 0, xp = 0;
#define CELL(wx) \
Arr* wx##s = wx##slice(inc(wx), wx##p, wx##sz); \
arr_shSetI(wx##s, wx##c, wx##cs); \
wx##p+= wx##sz
#define F(W,X) HARR_ADD(r, i, c2(f, W, X))
if (ext == 1) {
for (usz i = 0; i < cam; i++) {
CELL(w); CELL(x); F(taga(ws), taga(xs));
}
} else if (wk < xk) {
for (usz i = 0; i < cam; ) {
CELL(w); B wb=taga(ws);
for (usz e = i+ext; i < e; i++) { CELL(x); F(inc(wb), taga(xs)); }
dec(wb);
}
} else {
for (usz i = 0; i < cam; ) {
CELL(x); B xb=taga(xs);
for (usz e = i+ext; i < e; i++) { CELL(w); F(taga(ws), inc(xb)); }
dec(xb);
}
}
#undef CELL
#undef F
if (wc>1) ptr_dec(wcs);
if (xc>1) ptr_dec(xcs);
usz* rsh = HARR_FA(r, zk);
if (zk>1) shcpy(rsh, zsh, zk);
decG(w); decG(x); r = HARR_O(r).b;
}
if (gf) dec(g);
return bqn_merge(r);
}

View File

@ -8,8 +8,8 @@ B toCells(B x) {
usz cam = a(x)->sh[0];
usz csz = arr_csz(x);
BSS2A slice = TI(x,slice);
usz p = 0;
M_HARR(r, cam)
usz p = 0;
if (rnk(x)==2) {
for (usz i = 0; i < cam; i++) {
Arr* s = slice(inc(x), p, csz); arr_shVec(s);
@ -20,7 +20,7 @@ B toCells(B x) {
usz cr = rnk(x)-1;
ShArr* csh = m_shArr(cr);
usz* xsh = a(x)->sh;
for (u64 i = 0; i < cr; i++) csh->a[i] = xsh[i+1];
shcpy(csh->a, xsh+1, cr);
for (usz i = 0; i < cam; i++) {
Arr* s = slice(inc(x), p, csz); arr_shSetI(s, cr, csh);
HARR_ADD(r, i, taga(s));
@ -41,12 +41,12 @@ B toKCells(B x, ur k) {
ShArr* csh;
if (cr>1) {
csh = m_shArr(cr);
for (i32 i = 0; i < cr; i++) csh->a[i] = xsh[i+k];
shcpy(csh->a, xsh+k, cr);
}
usz p = 0;
M_HARR(r, cam)
BSS2A slice = TI(x,slice);
M_HARR(r, cam);
usz p = 0;
for (usz i = 0; i < cam; i++) {
Arr* s = slice(inc(x), p, csz); arr_shSetI(s, cr, csh);
HARR_ADD(r, i, taga(s));
@ -54,7 +54,7 @@ B toKCells(B x, ur k) {
}
if (cr>1) ptr_dec(csh);
usz* rsh = HARR_FA(r, k);
if (rsh) for (i32 i = 0; i < k; i++) rsh[i] = xsh[i];
if (rsh) shcpy(rsh, xsh, k);
decG(x);
return HARR_O(r).b;
}