support arbitrary shape in rand.Range

This commit is contained in:
Josh Holland 2021-10-13 16:31:18 +01:00
parent 49622f59ce
commit a290e4dc0d

View File

@ -312,24 +312,49 @@ B rand_range_c1(B t, B x) {
return xv? m_f64(wy2u0k(rnd, xv)) : m_f64(wy2u01(rnd)); return xv? m_f64(wy2u0k(rnd, xv)) : m_f64(wy2u01(rnd));
} }
B rand_range_c2(B t, B w, B x) { B rand_range_c2(B t, B w, B x) {
usz am = o2s(w); usz am = 1;
i64 max = o2i64(x); i64 max = o2i64(x);
if (isArr(w)) {
if (rnk(w) != 1) thrM("(rand).Range: 𝕨 must be a valid shape");
SGetU(w);
for (u64 i = 0; i < a(w)->ia; i++) {
am *= o2s(GetU(w, i));
}
} else {
am = o2s(w);
}
RAND_START; RAND_START;
B r; Arr* r;
if (max<1) { if (max<1) {
if (max!=0) thrM("(rand).Range: 𝕩 cannot be negative"); if (max!=0) thrM("(rand).Range: 𝕩 cannot be negative");
f64* rp; r = m_f64arrv(&rp, am); f64* rp; r = m_f64arrp(&rp, am);
for (usz i = 0; i < am; i++) rp[i] = wy2u01(wyrand(&seed)); for (usz i = 0; i < am; i++) rp[i] = wy2u01(wyrand(&seed));
} else if (max > I32_MAX) { } else if (max > I32_MAX) {
if (max >= 1LL<<53) thrM("(rand).Range: 𝕩 must be less than 2⋆53"); if (max >= 1LL<<53) thrM("(rand).Range: 𝕩 must be less than 2⋆53");
f64* rp; r = m_f64arrv(&rp, am); f64* rp; r = m_f64arrp(&rp, am);
for (usz i = 0; i < am; i++) rp[i] = wy2u0k(wyrand(&seed), max); for (usz i = 0; i < am; i++) rp[i] = wy2u0k(wyrand(&seed), max);
} else { } else {
i32* rp; r = m_i32arrv(&rp, am); i32* rp; r = m_i32arrp(&rp, am);
for (usz i = 0; i < am; i++) rp[i] = wy2u0k(wyrand(&seed), max); for (usz i = 0; i < am; i++) rp[i] = wy2u0k(wyrand(&seed), max);
} }
RAND_END; RAND_END;
return r; if (isArr(w)) {
switch (a(w)->ia) {
case 0: { arr_shAlloc(r, 0); break; }
case 1: { arr_shVec(r); break; }
default: {
usz* sh = arr_shAlloc(r, a(w)->ia);
SGetU(w);
for (usz i = 0; i < a(w)->ia; i++) {
sh[i] = o2s(GetU(w, i));
}
}
}
} else {
arr_shVec(r);
}
dec(w);
return taga(r);
} }
B rand_deal_c1(B t, B x) { B rand_deal_c1(B t, B x) {