fix ⊑˘ on rank>2 inputs

This commit is contained in:
dzaima 2025-05-02 01:11:38 +03:00
parent e32d41eb61
commit 13906efe44
4 changed files with 48 additions and 30 deletions

View File

@ -234,25 +234,33 @@ NOINLINE B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr) { //
// fast special-case implementations
B select_cells_single(usz ind, B x, usz cam, usz l, usz csz, bool leaf); // from select.c
static NOINLINE B select_cells(usz ind, B x, usz cam, usz k, bool leaf) { // ind {leaf? <∘⊑; ⊏}⎉¯k x
ur xr = RNK(x);
assert(xr>1 && k<xr);
B select_cells_single(usz ind, B x, usz cam, usz l, usz csz); // from select.c
static NOINLINE B select_cells(usz ind, B x, ur xr, usz cam, usz k) { // ind ⊏⎉¯k x
assert(xr == RNK(x) && xr>1 && k<xr);
usz* xsh = SH(x);
usz csz = shProd(xsh, k+1, xr);
usz l = xsh[k];
assert(0<=ind && ind<l);
assert(cam*l*csz == IA(x));
B r = select_cells_single(ind, x, cam, l, csz, leaf);
Arr* ra = a(r);
usz* rsh = arr_shAlloc(ra, leaf? k : xr-1);
assert(ind < l && cam*l*csz == IA(x));
B r = select_cells_single(ind, x, cam, l, csz);
usz* rsh = arr_shAlloc(a(r), xr-1);
if (rsh) {
shcpy(rsh, xsh, k);
if (!leaf) shcpy(rsh+k, xsh+k+1, xr-1-k);
shcpy(rsh+k, xsh+k+1, xr-1-k);
}
decG(x);
return r;
}
static NOINLINE B pick_cells(usz ind, B x, ur xr, usz cam, usz k) { // ind <∘⊑⎉¯k x
assert(xr == RNK(x) && xr>1 && k<=xr);
usz* xsh = SH(x);
usz l = shProd(xsh, k, xr);
assert(ind < (k==xr? 1 : xsh[k]) && cam*l == IA(x));
B r = select_cells_single(ind, x, cam, l, 1);
usz* rsh = arr_shAlloc(a(r), k);
if (rsh) shcpy(rsh, xsh, k);
decG(x);
return r;
}
static void set_column_typed(void* rp, B v, u8 e, ux p, ux stride, ux n) { // may write to all elements 0 ≤ i < stride×n, and after that too for masked stores
assert(p < stride);
@ -441,11 +449,11 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x; array x,
case n_select:
if (IA(x)==0) goto noSpecial;
if (cr==0) goto base;
return select_cells(0, x, cam, k, false);
return select_cells(0, x, xr, cam, k);
case n_pick:
if (IA(x)==0) goto noSpecial;
if (cr==0 || !TI(x,arrD1)) goto base;
return select_cells(0, x, cam, k, true);
if (!TI(x,arrD1)) goto base;
return pick_cells(0, x, xr, cam, k);
case n_couple: {
Arr* r = cpyWithShape(x); xsh=PSH(r);
if (xr==UR_MAX) thrF("≍%U 𝕩: Result rank too large (%i≡=𝕩)", chr==U'˘'? "˘" : "⎉𝕘", xr);
@ -524,8 +532,8 @@ B for_cells_c1(B f, u32 xr, u32 cr, u32 k, B x, u32 chr) { // F⎉cr x; array x,
usz m = xsh[k];
if (m==0) return insert_cells_identity(x, fd->f, xsh, xr, k, rtid);
if (TI(x,elType)==el_B) break;
if (m==1 || frtid==n_ltack) return select_cells(0 , x, cam, k, false);
if ( frtid==n_rtack) return select_cells(m-1, x, cam, k, false);
if (m==1 || frtid==n_ltack) return select_cells(0 , x, xr, cam, k);
if ( frtid==n_rtack) return select_cells(m-1, x, xr, cam, k);
if (isPervasiveDyExt(fd->f) && 1==shProd(xsh, k+1, xr)) {
B r;
// special cases always return rank 1
@ -729,7 +737,7 @@ NOINLINE B for_cells_SA(B f, B w, B x, ur xcr, ur xr, u32 chr) { // w⊸F⎉xcr
if (wr == 0) {
usz ind = WRAP(o2i64(IGetU(w,0)), xsh[xk], break);
decG(w);
return select_cells(ind, x, cam, xk, false);
return select_cells(ind, x, xr, cam, xk);
}
ur rr = xk+wr;
ShArr* rsh = m_shArr(rr);
@ -740,7 +748,7 @@ NOINLINE B for_cells_SA(B f, B w, B x, ur xcr, ur xr, u32 chr) { // w⊸F⎉xcr
}
if (isF64(w) && xcr>=1) {
usz l = xsh[xk];
return select_cells(WRAP(o2i64(w), l, thrF("𝕨⊏𝕩: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, l)), x, cam, xk, false);
return select_cells(WRAP(o2i64(w), l, thrF("𝕨⊏𝕩: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, l)), x, xr, cam, xk);
}
break;
case n_couple: if (RNK(x)==1) {
@ -750,7 +758,7 @@ NOINLINE B for_cells_SA(B f, B w, B x, ur xcr, ur xr, u32 chr) { // w⊸F⎉xcr
} break;
case n_pick: if (isF64(w) && xcr==1 && TI(x,arrD1)) {
usz l = xsh[xk];
return select_cells(WRAP(o2i64(w), l, thrF("𝕨⊑𝕩: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, l)), x, cam, xk, true);
return pick_cells(WRAP(o2i64(w), l, thrF("𝕨⊑𝕩: Indexing out-of-bounds (𝕨≡%R, %s≡≠𝕩)", w, l)), x, xr, cam, xk);
} break;
case n_shifta: case n_shiftb: if (isAtm(w)) {
if (IA(x)==0) return x;

View File

@ -616,30 +616,29 @@ B select_cells_base(B inds, B x0, ux csz, ux cam);
extern void (*const si_select_cells_bit_lt64)(u64*,u64*,usz,usz,usz); // from fold.c (fold.singeli)
extern usz (*const si_select_cells_byte)(void*,void*,usz,usz,u8);
B select_cells_single(usz ind, B x, usz cam, usz l, usz csz, bool leaf) { // ⥊ ind {leaf? <∘⊑; ⊏}˘ cam‿l‿csz ⥊ x
usz take = leaf? 1 : csz;
B select_cells_single(usz ind, B x, usz cam, usz l, usz csz) { // ⥊ ind ⊏˘ cam‿l‿csz ⥊ x
Arr* ra;
if (l==1 && take==csz) {
if (l==1) {
ra = cpyWithShape(incG(x));
arr_shErase(ra, 1);
} else {
u8 xe = TI(x,elType);
u8 ewl= elwBitLog(xe);
u8 xl = leaf? ewl : multWidthLog(csz, ewl);
usz ria = cam*take;
u8 xl = multWidthLog(csz, ewl);
usz ria = cam*csz;
if (xl>=7 || (xl<3 && xl>0)) { // generic case
MAKE_MUT_INIT(rm, ria, TI(x,elType)); MUTG_INIT(rm);
usz jump = l * csz;
usz xi = take*ind;
usz xi = csz*ind;
usz ri = 0;
for (usz i = 0; i < cam; i++) {
mut_copyG(rm, ri, x, xi, take);
mut_copyG(rm, ri, x, xi, csz);
xi+= jump;
ri+= take;
ri+= csz;
}
ra = mut_fp(rm);
} else if (xe==el_B) {
assert(take == 1);
assert(csz == 1);
SGet(x)
HArr_p rp = m_harrUv(ria);
for (usz i = 0; i < cam; i++) rp.a[i] = Get(x, i*l+ind);
@ -948,7 +947,7 @@ B select_rows_B(B x, ux csz, ux cam, B inds) { // consumes inds,x; ⥊ inds⊸
if (in == 0) return taga(emptyArr(x, 1));
if (in == 1) {
B w = IGetU(inds,0); if (!isF64(w)) goto generic;
B r = select_cells_single(WRAP_SELECT_ONE(o2i64(w), csz, "%R", w), x, cam, csz, 1, false);
B r = select_cells_single(WRAP_SELECT_ONE(o2i64(w), csz, "%R", w), x, cam, csz, 1);
decG(x); decG(inds); return r;
}
u8 ie = TI(inds,elType);

View File

@ -605,7 +605,7 @@ NOINLINE B takedrop_highrank(bool take, B w, B x) {
if (ri!=pri) mut_fillG(rm, pri, xf, ri-pri);
pri = ri+cellWrite;
}
mut_copyG(rm, ri, x, xi, cellWrite);
mut_copyG(rm, ri, x, xi, cellWrite); // TODO could use cf_
usz cr = cellStart-1;
if (0 == --lcv[cr]) {
do {

View File

@ -41,7 +41,9 @@
!"𝕨⊏𝕩: 𝕩 cannot be a unit" % 0˘5<"a"
!"𝕨⊏𝕩: 𝕩 cannot be a unit" % (30)˘3
!"Expected integer, got 0.1" % 0.1˘3515
!"𝕨⊑𝕩: 𝕩 must be a list when 𝕨 is a number (3‿4 ≡ ≢𝕩)" % 5˘234
!"𝕨⊑𝕩: 𝕩 must be a list when 𝕨 is a number (3‿4 ≡ ≢𝕩)" % 5 ˘ 234
!"𝕨⊑𝕩: 𝕩 must be a list when 𝕨 is a number (3‿4 ≡ ≢𝕩)" % 5 ˘ 2342
!"𝕨⊑𝕩: 𝕩 must be a list when 𝕨 is a number (3‿4 ≡ ≢𝕩)" % 0 2 2342
!">𝕩: Result rank too large (80 ≡ =𝕩, 205 ≡ =⊑𝕩)" % >80 (2001)<(2051)1
!"𝔽⎉𝕘: Result rank too large (195 ≡ =𝕩, 210 ≡ =𝔽v)" % >5 (2001)<(2051)1
!"𝕨∾𝕩: Lengths not matchable (⟨6⟩ ≡ ≢𝕨, 1‿1 ≡ ≢𝕩)" % ("abc""def")˘(3/)"a"
@ -82,6 +84,15 @@
%USE tcc _tcc_ ¯1 428 1_tcc_ ¯1 428 _tcc_ ¯1 42 1_tcc_ ¯1 42
%USE tcc _tcc_ ¯1 428 1_tcc_ ¯1 428 _tcc_ ¯1 42 1_tcc_ ¯1 42
1˘ 02342 %% 0340 %!HEAPVERIFY
2˘ 02222 %% 0
¯1 23424 %% 23×4×2
¯2 23424 %% 234×6
¯3 23424 %% 23424
01 23424 %% 234×6
11 23424 %% 1+234×6
0 1010100 %% 1010100
%USE eqvar 0000 {𝕨˘𝕩}_eqvar ˘5 %% 544÷˜20
%USE eqvar 0¯10¯1 {𝕨˘𝕩}_eqvar ˘5 %% 544÷˜20
%USE eqvar 100¯3 {𝕨˘𝕩}_eqvar 102002000 %% 1021001973003975005977007979009971100119713001397150015971700179719001997