diff --git a/src/builtins/sfns.c b/src/builtins/sfns.c index 67586d6c..6ce2461a 100644 --- a/src/builtins/sfns.c +++ b/src/builtins/sfns.c @@ -1126,21 +1126,109 @@ B reverse_c1(B t, B x) { } return withFill(mut_fcd(r, x), xf); } + + +B reverse_c2(B t, B w, B x); + +#define WRAP_ROT(V, L) ({ i64 v_ = (V); usz l_ = (L); if ((u64)v_ >= (u64)l_) { v_%= (i64)l_; if(v_<0) v_+= l_; } v_; }) +static NOINLINE B rotate_highrank(bool inv, B w, B x) { + #define INV (inv? "⁼" : "") + if (RNK(w)>1) thrF("⌽%U: 𝕨 must have rank at most 1 (%H ≡ ≢𝕨)", INV, w); + B r; + usz wia = IA(w); + if (isAtm(x) || RNK(x)==0) { + if (wia!=0) goto badlen; + r = isAtm(x)? m_unit(x) : x; + goto decW_ret; + } + + ur xr = RNK(x); + if (wia==1) { + lastaxis:; + f64 wf = o2f(IGetU(w, 0)); + r = C2(reverse, m_f64(inv? -wf : wf), x); + goto decW_ret; + } + if (wia>xr) goto badlen; + if (wia==0 || IA(x)==0) { r=x; goto decW_ret; } + + usz* xsh = SH(x); + SGetU(w) + ur cr = wia-1; + usz rot0, l0; + usz csz = 1; + while (1) { + usz xshc = xsh[cr]; + if (cr==0) goto lastaxis; + i64 wv = WRAP_ROT(o2i64(GetU(w, cr)), xshc); + if (wv!=0) { rot0 = inv? xshc-wv : wv; l0 = xshc; break; } + csz*= xshc; + cr--; + } + NOUNROLL for (usz i = xr; i-->wia; ) csz*= xsh[i]; + + TALLOC(usz, tmp, cr*3); + usz* pos = tmp+cr*0; // current position + usz* rot = tmp+cr*1; // (≠𝕩)|𝕨 + usz* xcv = tmp+cr*2; // sizes to skip by in x + + usz ri=0, xi=0; // current index in r & x + usz rSkip = csz*l0; + usz ccsz = rSkip; + for (usz i = cr; i-->0; ) { + usz xshc = xsh[i]; + i64 v = WRAP_ROT(o2i64(GetU(w, i)), xshc); + if (inv && v!=0) v = xshc-v; + pos[i] = rot[i] = v; + xi+= v*ccsz; + xcv[i] = ccsz; + ccsz*= xshc; + } + + MAKE_MUT(rm, IA(x)); mut_init(rm, TI(x,elType)); + MUTG_INIT(rm); + + usz n0 = csz*rot0; + usz n1 = csz*(l0-rot0); + while (true) { + mut_copyG(rm, ri+n1, x, xi, n0); + mut_copyG(rm, ri, x, xi+n0, n1); + usz c = cr-1; + while (true) { + if (xsh[c] == ++pos[c]) { xi-=xcv[c]*pos[c]; pos[c]=0; } + xi+= xcv[c]; + if (pos[c]!=rot[c]) break; + if (c==0) goto endCopy; + c--; + } + ri+= rSkip; + } + endCopy:; + + TFREE(tmp); + B xf = getFillE(x); + r = withFill(mut_fcd(rm, x), xf); + + decW_ret: decG(w); + return r; + badlen: thrF("⌽%U: Length of list 𝕨 must be at most rank of 𝕩 (%s ≡ ≠𝕨, %H ≡ ≢𝕩⟩", INV, wia, x); + #undef INV +} B reverse_c2(B t, B w, B x) { - if (isArr(w)) return c2rt(reverse, w, x); + if (isArr(w)) return rotate_highrank(0, w, x); if (isAtm(x) || RNK(x)==0) thrM("⌽: 𝕩 must have rank at least 1 for atom 𝕨"); usz xia = IA(x); if (xia==0) return x; - B xf = getFillQ(x); usz cam = SH(x)[0]; usz csz = arr_csz(x); - i64 am = o2i64(w); - if ((u64)am >= (u64)cam) { am%= (i64)cam; if(am<0) am+= cam; } + i64 am = WRAP_ROT(o2i64(w), cam); + if (am==0) return x; am*= csz; MAKE_MUT(r, xia); mut_init(r, TI(x,elType)); MUTG_INIT(r); mut_copyG(r, 0, x, am, xia-am); mut_copyG(r, xia-am, x, 0, am); + B xf = getFillQ(x); return withFill(mut_fcd(r, x), xf); } @@ -1368,11 +1456,16 @@ B select_ucw(B t, B o, B w, B x); B transp_uc1(B t, B o, B x) { return transp_im(t, c1(o, transp_c1(t, x))); } B reverse_uc1(B t, B o, B x) { return reverse_c1(t, c1(o, reverse_c1(t, x))); } -B reverse_ucw(B t, B o, B w, B x) { - B r = c1(o, reverse_c2(t, inc(w), x)); - return reverse_c2(t, c1(bi_sub, w), r); // above reverse_c2 call asserts the -𝕨 is fine + +B reverse_ix(B t, B w, B x) { + if (isAtm(x) || RNK(x)==0) thrM("⌽⁼: 𝕩 must have rank at least 1"); + if (isF64(w)) return reverse_c2(t, m_f64(-o2fG(w)), x); + if (isAtm(w)) thrM("⌽⁼: 𝕨 must consist of integers"); + return rotate_highrank(1, w, x); } +B reverse_ucw(B t, B o, B w, B x) { return reverse_ix(t, w, c1(o, reverse_c2(t, inc(w), x))); } + NOINLINE B enclose_im(B t, B x) { if (isAtm(x) || RNK(x)!=0) thrM("<⁼: Argument wasn't a rank 0 array"); B r = IGet(x, 0); @@ -1385,6 +1478,8 @@ B enclose_uc1(B t, B o, B x) { void sfns_init() { c(BFn,bi_pick)->uc1 = pick_uc1; + c(BFn,bi_reverse)->im = reverse_c1; + c(BFn,bi_reverse)->ix = reverse_ix; c(BFn,bi_reverse)->uc1 = reverse_uc1; c(BFn,bi_reverse)->ucw = reverse_ucw; c(BFn,bi_pick)->ucw = pick_ucw;