use shProd more

This commit is contained in:
dzaima 2022-05-30 02:07:25 +03:00
parent cc44b3e57b
commit 79bd6fc689
6 changed files with 19 additions and 23 deletions

View File

@ -388,18 +388,18 @@ static B m1c2(B t, B f, B w, B x) { // consumes w,x
return r;
}
#define S_SLICES(X) \
BSS2A X##_slc = TI(X,slice); \
usz X##_csz = 1; \
usz X##_cr = rnk(X)-1; \
ShArr* X##_csh; \
if (X##_cr>1) { \
X##_csh = m_shArr(X##_cr); \
for (usz i = 0; i < X##_cr; i++) { \
usz v = a(X)->sh[i+1]; \
X##_csz*= v; \
X##_csh->a[i] = v; \
} \
#define S_SLICES(X) \
BSS2A X##_slc = TI(X,slice); \
usz X##_csz = 1; \
usz X##_cr = rnk(X)-1; \
ShArr* X##_csh; \
if (X##_cr>1) { \
X##_csh = m_shArr(X##_cr); \
NOUNROLL for (usz i = 0; i < X##_cr; i++) { \
usz v = a(X)->sh[i+1]; \
X##_csz*= v; \
X##_csh->a[i] = v; \
} \
} else if (X##_cr!=0) X##_csz*= a(X)->sh[1];
#define SLICE(X, S) ({ Arr* r_ = X##_slc(inc(X), S, X##_csz); arr_shSetI(r_, X##_cr, X##_csh); taga(r_); })

View File

@ -16,9 +16,8 @@ B select_c1(B t, B x) {
ur xr = rnk(x);
if (xr==0) thrM("⊏: Argument cannot be rank 0");
if (a(x)->sh[0]==0) thrF("⊏: Argument shape cannot start with 0 (%H ≡ ≢𝕩)", x);
usz ia = 1;
for (i32 i = 1; i < xr; i++) ia*= a(x)->sh[i];
Arr* r = TI(x,slice)(inc(x),0, ia);
usz ia = shProd(a(x)->sh, 1, xr);
Arr* r = TI(x,slice)(inc(x), 0, ia);
usz* sh = arr_shAlloc(r, xr-1);
if (sh) for (i32 i = 1; i < xr; i++) sh[i-1] = a(x)->sh[i];
decG(x);

View File

@ -1132,8 +1132,7 @@ B transp_c1(B t, B x) {
usz ia = a(x)->ia;
usz* xsh = a(x)->sh;
usz h = xsh[0];
usz w = xsh[1];
for (usz i = 2; RARE(i < xr); i++) w*= a(x)->sh[i];
usz w = xsh[1] * shProd(a(x)->sh, 2, xr);
Arr* r;
usz xi = 0;

View File

@ -1063,7 +1063,7 @@ B bitcast_impl(B el0, B el1, B x) {
sh = zsh;
}
sh[xr-1]=zl;
usz ia=zl; for (usz i=0;i<xr-1;i++)ia*=sh[i]; a(r)->ia=ia;
a(r)->ia = zl*shProd(sh, 0, xr-1);
}
return r;
}

View File

@ -35,8 +35,8 @@ B toKCells(B x, ur k) {
assert(isArr(x) && k<=rnk(x) && k>=0);
ur xr = rnk(x); usz* xsh = a(x)->sh;
ur cr = xr-k;
usz cam = 1; for (i32 i = 0; i < k ; i++) cam*= xsh[i];
usz csz = 1; for (i32 i = k; i < xr; i++) csz*= xsh[i];
usz cam = shProd(xsh, 0, k);
usz csz = shProd(xsh, k, xr);
ShArr* csh;
if (cr>1) {

View File

@ -846,9 +846,7 @@ void g_pst(void) { vm_pstLive(); fflush(stdout); fflush(stderr); }
Arr* a = (Arr*)x;
if (prnk(x)<=1) assert(a->sh == &a->ia);
else {
u64 shProduct = 1;
for (usz i=0; i < prnk(x); i++) shProduct*= a->sh[i];
assert(shProduct == a->ia);
assert(shProd(a->sh, 0, prnk(x)) == a->ia);
VALIDATE(tag(shObjP(x),OBJ_TAG));
}
}