use shProd more

This commit is contained in:
dzaima 2022-05-30 02:07:25 +03:00
parent cc44b3e57b
commit 79bd6fc689
6 changed files with 19 additions and 23 deletions

View File

@ -388,18 +388,18 @@ static B m1c2(B t, B f, B w, B x) { // consumes w,x
return r; return r;
} }
#define S_SLICES(X) \ #define S_SLICES(X) \
BSS2A X##_slc = TI(X,slice); \ BSS2A X##_slc = TI(X,slice); \
usz X##_csz = 1; \ usz X##_csz = 1; \
usz X##_cr = rnk(X)-1; \ usz X##_cr = rnk(X)-1; \
ShArr* X##_csh; \ ShArr* X##_csh; \
if (X##_cr>1) { \ if (X##_cr>1) { \
X##_csh = m_shArr(X##_cr); \ X##_csh = m_shArr(X##_cr); \
for (usz i = 0; i < X##_cr; i++) { \ NOUNROLL for (usz i = 0; i < X##_cr; i++) { \
usz v = a(X)->sh[i+1]; \ usz v = a(X)->sh[i+1]; \
X##_csz*= v; \ X##_csz*= v; \
X##_csh->a[i] = v; \ X##_csh->a[i] = v; \
} \ } \
} else if (X##_cr!=0) X##_csz*= a(X)->sh[1]; } else if (X##_cr!=0) X##_csz*= a(X)->sh[1];
#define SLICE(X, S) ({ Arr* r_ = X##_slc(inc(X), S, X##_csz); arr_shSetI(r_, X##_cr, X##_csh); taga(r_); }) #define SLICE(X, S) ({ Arr* r_ = X##_slc(inc(X), S, X##_csz); arr_shSetI(r_, X##_cr, X##_csh); taga(r_); })

View File

@ -16,9 +16,8 @@ B select_c1(B t, B x) {
ur xr = rnk(x); ur xr = rnk(x);
if (xr==0) thrM("⊏: Argument cannot be rank 0"); if (xr==0) thrM("⊏: Argument cannot be rank 0");
if (a(x)->sh[0]==0) thrF("⊏: Argument shape cannot start with 0 (%H ≡ ≢𝕩)", x); if (a(x)->sh[0]==0) thrF("⊏: Argument shape cannot start with 0 (%H ≡ ≢𝕩)", x);
usz ia = 1; usz ia = shProd(a(x)->sh, 1, xr);
for (i32 i = 1; i < xr; i++) ia*= a(x)->sh[i]; Arr* r = TI(x,slice)(inc(x), 0, ia);
Arr* r = TI(x,slice)(inc(x),0, ia);
usz* sh = arr_shAlloc(r, xr-1); usz* sh = arr_shAlloc(r, xr-1);
if (sh) for (i32 i = 1; i < xr; i++) sh[i-1] = a(x)->sh[i]; if (sh) for (i32 i = 1; i < xr; i++) sh[i-1] = a(x)->sh[i];
decG(x); decG(x);

View File

@ -1132,8 +1132,7 @@ B transp_c1(B t, B x) {
usz ia = a(x)->ia; usz ia = a(x)->ia;
usz* xsh = a(x)->sh; usz* xsh = a(x)->sh;
usz h = xsh[0]; usz h = xsh[0];
usz w = xsh[1]; usz w = xsh[1] * shProd(a(x)->sh, 2, xr);
for (usz i = 2; RARE(i < xr); i++) w*= a(x)->sh[i];
Arr* r; Arr* r;
usz xi = 0; usz xi = 0;

View File

@ -1063,7 +1063,7 @@ B bitcast_impl(B el0, B el1, B x) {
sh = zsh; sh = zsh;
} }
sh[xr-1]=zl; sh[xr-1]=zl;
usz ia=zl; for (usz i=0;i<xr-1;i++)ia*=sh[i]; a(r)->ia=ia; a(r)->ia = zl*shProd(sh, 0, xr-1);
} }
return r; return r;
} }

View File

@ -35,8 +35,8 @@ B toKCells(B x, ur k) {
assert(isArr(x) && k<=rnk(x) && k>=0); assert(isArr(x) && k<=rnk(x) && k>=0);
ur xr = rnk(x); usz* xsh = a(x)->sh; ur xr = rnk(x); usz* xsh = a(x)->sh;
ur cr = xr-k; ur cr = xr-k;
usz cam = 1; for (i32 i = 0; i < k ; i++) cam*= xsh[i]; usz cam = shProd(xsh, 0, k);
usz csz = 1; for (i32 i = k; i < xr; i++) csz*= xsh[i]; usz csz = shProd(xsh, k, xr);
ShArr* csh; ShArr* csh;
if (cr>1) { if (cr>1) {

View File

@ -846,9 +846,7 @@ void g_pst(void) { vm_pstLive(); fflush(stdout); fflush(stderr); }
Arr* a = (Arr*)x; Arr* a = (Arr*)x;
if (prnk(x)<=1) assert(a->sh == &a->ia); if (prnk(x)<=1) assert(a->sh == &a->ia);
else { else {
u64 shProduct = 1; assert(shProd(a->sh, 0, prnk(x)) == a->ia);
for (usz i=0; i < prnk(x); i++) shProduct*= a->sh[i];
assert(shProduct == a->ia);
VALIDATE(tag(shObjP(x),OBJ_TAG)); VALIDATE(tag(shObjP(x),OBJ_TAG));
} }
} }