This commit is contained in:
dzaima 2021-04-27 14:38:32 +03:00
parent 271479c511
commit a6d15846f6
4 changed files with 30 additions and 28 deletions

24
src/h.h
View File

@ -278,21 +278,17 @@ void arr_shVec(B x, usz ia) {
srnk(x, 1); srnk(x, 1);
a(x)->sh = &a(x)->ia; a(x)->sh = &a(x)->ia;
} }
bool gotShape[t_COUNT]; usz* arr_shAllocR(B x, ur r) { // allocates shape, sets rank
usz* arr_shAllocI(B x, usz ia, ur r) { srnk(x,r);
if (r>1) return a(x)->sh = ((ShArr*)mm_allocN(fsizeof(ShArr, a, usz, r), t_shape))->a;
a(x)->sh = &a(x)->ia;
return 0;
}
usz* arr_shAllocI(B x, usz ia, ur r) { // allocates shape, sets ia,rank
a(x)->ia = ia; a(x)->ia = ia;
srnk(x,r); return arr_shAllocR(x, r);
if (r>1) return a(x)->sh = ((ShArr*)mm_allocN(fsizeof(ShArr, a, usz, r), t_shape))->a;
a(x)->sh = &a(x)->ia;
return 0;
} }
usz* arr_shAllocR(B x, ur r) { // allocates shape, sets rank, leaves ia unchanged void arr_shCopy(B n, B o) { // copy shape,rank,ia from o to n
srnk(x,r);
if (r>1) return a(x)->sh = ((ShArr*)mm_allocN(fsizeof(ShArr, a, usz, r), t_shape))->a;
a(x)->sh = &a(x)->ia;
return 0;
}
void arr_shCopy(B n, B o) { // copy shape from o to n
assert(isArr(o)); assert(isArr(o));
a(n)->ia = a(o)->ia; a(n)->ia = a(o)->ia;
ur r = srnk(n,rnk(o)); ur r = srnk(n,rnk(o));
@ -352,7 +348,7 @@ typedef struct TypeInfo {
BS2B getU; // like get, but doesn't increment result (mostly equivalent to `B t=get(…); dec(t); t`) BS2B getU; // like get, but doesn't increment result (mostly equivalent to `B t=get(…); dec(t); t`)
BB2B m1_d; // consume all args; (m, f) BB2B m1_d; // consume all args; (m, f)
BBB2B m2_d; // consume all args; (m, f, g) BBB2B m2_d; // consume all args; (m, f, g)
BS2B slice; // consumes; create slice from given starting position; add ia, rank, shape yourself BS2B slice; // consumes; create slice from given starting position; add ia, rank, shape yourself; may not actually be a Slice object
B2b canStore; // doesn't consume B2b canStore; // doesn't consume
B2B identity; // return identity element of this function; doesn't consume B2B identity; // return identity element of this function; doesn't consume

View File

@ -21,19 +21,23 @@ HArr_p m_harrv(usz ia) {
arr_shVec(r, ia); arr_shVec(r, ia);
return harr_parts(r); return harr_parts(r);
} }
HArr_p m_harrc(B x) { assert(isArr(x)); HArr_p m_harrc(B x) { assert(isArr(x));
B r = m_arr(fsizeof(HArr,a,B,a(x)->ia), t_harr); B r = m_arr(fsizeof(HArr,a,B,a(x)->ia), t_harr);
arr_shCopy(r, x); arr_shCopy(r, x);
return harr_parts(r); return harr_parts(r);
} }
HArr_p m_harrp(usz ia) { // doesn't write shape/rank HArr_p m_harrp(usz ia) { // doesn't write shape/rank
B r = m_arr(fsizeof(HArr,a,B,ia), t_harr); B r = m_arr(fsizeof(HArr,a,B,ia), t_harr);
a(r)->ia = ia; a(r)->ia = ia;
return harr_parts(r); return harr_parts(r);
} }
B m_hunit(B x) {
HArr_p r = m_harrp(1);
arr_shAllocR(r.b, 0);
r.a[0] = x;
return r.b;
}
B* harr_ptr(B x) { VT(x,t_harr); return c(HArr,x)->a; } B* harr_ptr(B x) { VT(x,t_harr); return c(HArr,x)->a; }

View File

@ -4,10 +4,8 @@ B tbl_c1(B d, B x) { B f = c(Md1D,d)->f;
return eachm(f, x); return eachm(f, x);
} }
B tbl_c2(B d, B w, B x) { B f = c(Md1D,d)->f; B tbl_c2(B d, B w, B x) { B f = c(Md1D,d)->f;
if (isAtm(w) | isAtm(x)) { if (isAtm(w)) w = m_hunit(w);
if (isAtm(w)) w = m_unit(w); if (isAtm(x)) x = m_hunit(x);
if (isAtm(x)) x = m_unit(x);
}
usz wia = a(w)->ia; ur wr = rnk(w); usz wia = a(w)->ia; ur wr = rnk(w);
usz xia = a(x)->ia; ur xr = rnk(x); usz xia = a(x)->ia; ur xr = rnk(x);
usz ria = wia*xia; ur rr = wr+xr; usz ria = wia*xia; ur rr = wr+xr;

View File

@ -6,15 +6,19 @@ typedef struct BFn {
} BFn; } BFn;
B eachd_fn(BBB2B f, B fo, B w, B x) { // consumes w,x; assumes at least one is array B eachd_fn(BBB2B f, B fo, B w, B x) { // consumes w,x; assumes at least one is array
usz wia; ur wr; BS2B wget; if (!isArr(w)) w = m_hunit(w);
usz xia; ur xr; BS2B xget; if (!isArr(x)) x = m_hunit(x);
if (isArr(w)) { wia = a(w)->ia; wr = rnk(w); wget = TI(w).get; } else { wia=1; wr=0; wget=def_get; } ur wr = rnk(w); BS2B wget = TI(w).get;
if (isArr(x)) { xia = a(x)->ia; xr = rnk(x); xget = TI(x).get; } else { xia=1; xr=0; xget=def_get; } ur xr = rnk(x); BS2B xget = TI(x).get;
bool wg = wr>xr; bool wg = wr>xr;
ur rM = wg? wr : xr; ur rM = wg? wr : xr;
ur rm = wg? xr : wr; ur rm = wg? xr : wr;
if (rM==0) { B r = f(fo, wget(w,0), xget(x,0)); dec(w); dec(x); return m_unit(r); } if (rM==0) {
if (isArr(w) & isArr(x) && !eqShPrefix(a(w)->sh, a(x)->sh, rm)) thrM("Mapping: Expected equal shape prefix"); B r = f(fo, wget(w,0), xget(x,0));
dec(w); dec(x);
return m_hunit(r);
}
if (rm && !eqShPrefix(a(w)->sh, a(x)->sh, rm)) thrM("Mapping: Expected equal shape prefix");
bool rw = rM==wr && ((v(w)->type==t_harr) & reusable(w)); // v(…) is safe as rank>0 bool rw = rM==wr && ((v(w)->type==t_harr) & reusable(w)); // v(…) is safe as rank>0
bool rx = rM==xr && ((v(x)->type==t_harr) & reusable(x)); bool rx = rM==xr && ((v(x)->type==t_harr) & reusable(x));
if (rw|rx && (wr==xr | rm==0)) { if (rw|rx && (wr==xr | rm==0)) {
@ -111,7 +115,7 @@ B eachm_fn(BB2B f, B fo, B x) { // consumes x; x must be array
return rH.b; return rH.b;
} }
B eachm(B f, B x) { // complete F¨ x B eachm(B f, B x) { // complete F¨ x
if (!isArr(x)) return m_unit(c1(f, x)); if (!isArr(x)) return m_hunit(c1(f, x));
if (isFun(f)) return eachm_fn(c(Fun,f)->c1, f, x); if (isFun(f)) return eachm_fn(c(Fun,f)->c1, f, x);
if (isMd(f)) if (!isArr(x) || a(x)->ia) { decR(x); thrM("Calling a modifier"); } if (isMd(f)) if (!isArr(x) || a(x)->ia) { decR(x); thrM("Calling a modifier"); }
@ -122,7 +126,7 @@ B eachm(B f, B x) { // complete F¨ x
} }
B eachd(B f, B w, B x) { // complete w F¨ x B eachd(B f, B w, B x) { // complete w F¨ x
if (!isArr(w) & !isArr(x)) return m_unit(c2(f, w, x)); if (!isArr(w) & !isArr(x)) return m_hunit(c2(f, w, x));
if (isFun(f)) return eachd_fn(c(Fun,f)->c2, f, w, x); if (isFun(f)) return eachd_fn(c(Fun,f)->c2, f, w, x);
if (isArr(w) && isArr(x)) { if (isArr(w) && isArr(x)) {
ur mr = rnk(w); if(rnk(w)<mr) mr = rnk(w); ur mr = rnk(w); if(rnk(w)<mr) mr = rnk(w);