diff --git a/src/builtins/grade.h b/src/builtins/grade.h index 5d7d634c..fd2ebeed 100644 --- a/src/builtins/grade.h +++ b/src/builtins/grade.h @@ -30,51 +30,72 @@ rp[j] = xi; \ } +#if SINGELI +extern void (*const avx2_scan_max8)(int8_t* v0,int8_t* v1,uint64_t v2); +extern void (*const avx2_scan_min8)(int8_t* v0,int8_t* v1,uint64_t v2); +extern void (*const avx2_scan_max16)(int16_t* v0,int16_t* v1,uint64_t v2); +extern void (*const avx2_scan_min16)(int16_t* v0,int16_t* v1,uint64_t v2); +#define COUNT_THRESHOLD 32 +#define WRITE_SPARSE_i8 \ + for (usz i=0; i>=), PRE_UD(K,>>,<<=)) +#define INC(P,I) GRADE_UD((P+1)[I]++,P[I]--) +#define ROFF GRADE_UD(1,0) // Radix offset #define CHOOSE_SG_SORT(S,G) S #define CHOOSE_SG_GRADE(S,G) G #define RADIX_SORT_i8(T, TYP) \ - TALLOC(T, c0, 256); T *c0o=c0+128; \ + TALLOC(T, c0, 256+ROFF); T* c0o=c0+128; \ for (usz j=0; j<256; j++) c0[j]=0; \ - for (usz i=0; i>8)]++; } \ + c1[0]=GRADE_UD(-n,c0[0]=n); \ + for (usz i=0; i>8)); } \ RADIX_SUM_2_##T; \ i16 *r0 = (i16*)(c0+2*256); \ CHOOSE_SG_##TYP( \ @@ -86,18 +107,15 @@ for (usz i=0; i>8)]++]=g0[i]; } \ ) \ TFREE(alloc) -#define RADIX_SUM_2_u8 u64 s0=0, s1=0; FOR(j,256/8) { PRE64(0); PRE64(1); } -#define RADIX_SUM_2(T) T s0=0, s1=0; FOR(j,256) { PRE(T,0); PRE(T,1); } -#define RADIX_SUM_2_usz RADIX_SUM_2(usz) -#define RADIX_SUM_2_u32 RADIX_SUM_2(u32) #define RADIX_SORT_i32(T, TYP, I) \ - TALLOC(u8, alloc, 4*256*sizeof(T) + n*(4 + CHOOSE_SG_##TYP(0,4+sizeof(I)))); \ + TALLOC(u8, alloc, (4*256+ROFF)*sizeof(T) + n*(4 + CHOOSE_SG_##TYP(0,4+sizeof(I)))); \ T *c0=(T*)alloc, *c1=c0+256, *c2=c1+256, *c3=c2+256, *c3o=c3+128; \ for (usz j=0; j<4*256; j++) c0[j]=0; \ + c1[0]=c2[0]=c3[0]=GRADE_UD(-n,c0[0]=n); \ for (usz i=0; i> 8)]++; \ - c2 [(u8)(v>>16)]++; c3o[(i8)(v>>24)]++; } \ + INC(c0 ,(u8)v ); INC(c1 ,(u8)(v>> 8)); \ + INC(c2 ,(u8)(v>>16)); INC(c3o,(i8)(v>>24)); } \ RADIX_SUM_4_##T; \ i32 *r0 = (i32*)(c0+4*256); \ CHOOSE_SG_##TYP( \ @@ -113,10 +131,37 @@ for (usz i=0; i>24)]++; rp[c]=g0[i]; } \ ) \ TFREE(alloc) -#define RADIX_SUM_4_u8 u64 s0=0, s1=0, s2=0, s3=0; FOR(j, 256/8) { PRE64(0); PRE64(1); PRE64(2); PRE64(3); } -#define RADIX_SUM_4(T) T s0=0, s1=0, s2=0, s3=0; FOR(j, 256) { PRE(u32,0); PRE(u32,1); PRE(u32,2); PRE(u32,3); } -#define RADIX_SUM_4_usz RADIX_SUM_4(usz) -#define RADIX_SUM_4_u32 RADIX_SUM_4(u32) + +#define PRE(K) s##K=c##K[j]+=s##K +#define RADIX_SUM_1(T) T s0=0; for(usz j=0;j<256;j++) { PRE(0); } +#define RADIX_SUM_2(T) GRADE_UD(c1[0]=0;,) T s0=0, s1=0; for(usz j=0;j<256;j++) { PRE(0); PRE(1); } +#define RADIX_SUM_4(T) GRADE_UD(c1[0]=c2[0]=c3[0]=0;,) T s0=0, s1=0, s2=0, s3=0; for(usz j=0;j<256;j++) { PRE(0); PRE(1); PRE(2); PRE(3); } + +#if SINGELI +extern void (*const avx2_scan_pluswrap_u8)(uint8_t* v0,uint8_t* v1,uint64_t v2,uint8_t v3); +extern void (*const avx2_scan_pluswrap_u32)(uint32_t* v0,uint32_t* v1,uint64_t v2,uint32_t v3); +#define RADIX_SUM_1_u8 avx2_scan_pluswrap_u8 (c0,c0, 256,0); +#define RADIX_SUM_2_u8 avx2_scan_pluswrap_u8 (c0,c0,2*256,0); +#define RADIX_SUM_2_u32 avx2_scan_pluswrap_u32(c0,c0,2*256,0); +#define RADIX_SUM_4_u8 avx2_scan_pluswrap_u8 (c0,c0,4*256,0); +#define RADIX_SUM_4_u32 avx2_scan_pluswrap_u32(c0,c0,4*256,0); +#else +#define RADIX_SUM_1_u8 RADIX_SUM_1(u8) +#define RADIX_SUM_2_u8 RADIX_SUM_2(u8) +#define RADIX_SUM_2_u32 RADIX_SUM_2(u32) +#define RADIX_SUM_4_u8 RADIX_SUM_4(u8) +#define RADIX_SUM_4_u32 RADIX_SUM_4(u32) +#endif + +#if SINGELI && !USZ_64 +#define RADIX_SUM_1_usz avx2_scan_pluswrap_u32(c0,c0, 256,0); +#define RADIX_SUM_2_usz avx2_scan_pluswrap_u32(c0,c0,2*256,0); +#define RADIX_SUM_4_usz avx2_scan_pluswrap_u32(c0,c0,4*256,0); +#else +#define RADIX_SUM_1_usz RADIX_SUM_1(usz) +#define RADIX_SUM_2_usz RADIX_SUM_2(usz) +#define RADIX_SUM_4_usz RADIX_SUM_4(usz) +#endif #define SORT_C1 CAT(GRADE_UD(and,or),c1) B SORT_C1(B t, B x) { @@ -140,9 +185,9 @@ B SORT_C1(B t, B x) { } else if (xe==el_i8) { i8* xp = i8any_ptr(x); i8* rp; r = m_i8arrv(&rp, n); - if (n<16) { + if (n < 16) { INSERTION_SORT(i8); - } else if (n<256) { + } else if (n < 256) { RADIX_SORT_i8(u8, SORT); } else { COUNTING_SORT(i8); @@ -150,7 +195,7 @@ B SORT_C1(B t, B x) { } else if (xe==el_i16) { i16* xp = i16any_ptr(x); i16* rp; r = m_i16arrv(&rp, n); - if (n < 24) { + if (n < 20) { INSERTION_SORT(i16); } else if (n < 256) { RADIX_SORT_i16(u8, SORT,); @@ -162,7 +207,7 @@ B SORT_C1(B t, B x) { } else if (xe==el_i32) { i32* xp = i32any_ptr(x); i32* rp; r = m_i32arrv(&rp, n); - if (n < 40) { + if (n < 32) { INSERTION_SORT(i32); } else if (n < 256) { RADIX_SORT_i32(u8, SORT,); @@ -181,6 +226,10 @@ B SORT_C1(B t, B x) { #undef SORT_C1 #undef INSERTION_SORT #undef COUNTING_SORT +#if SINGELI +#undef WRITE_SPARSE_i8 +#undef WRITE_SPARSE_i16 +#endif #define GRADE_CHR GRADE_UD("ā‹","ā’") @@ -356,21 +405,23 @@ done: #undef LT #undef FOR #undef PRE -#undef PRE_UD +#undef INC +#undef ROFF #undef PRE64 #undef CHOOSE_SG_SORT #undef CHOOSE_SG_GRADE #undef RADIX_SORT_i8 +#undef RADIX_SORT_i16 +#undef RADIX_SORT_i32 +#undef RADIX_SUM_1 +#undef RADIX_SUM_2 +#undef RADIX_SUM_4 #undef RADIX_SUM_1_u8 #undef RADIX_SUM_1_usz -#undef RADIX_SORT_i16 #undef RADIX_SUM_2_u8 -#undef RADIX_SUM_2 #undef RADIX_SUM_2_usz #undef RADIX_SUM_2_u32 -#undef RADIX_SORT_i32 #undef RADIX_SUM_4_u8 -#undef RADIX_SUM_4 #undef RADIX_SUM_4_usz #undef RADIX_SUM_4_u32 #undef GRADE_CAT diff --git a/src/builtins/sfns.c b/src/builtins/sfns.c index 7d7ab988..d9640efc 100644 --- a/src/builtins/sfns.c +++ b/src/builtins/sfns.c @@ -104,6 +104,13 @@ static B truncReshape(B x, usz xia, usz nia, ur nr, ShArr* sh) { // consumes all arr_shSetU(ra, nr, sh); return r; } +static void fill_words(void* rp, u64 v, u64 bytes) { + usz wds = bytes/8; + usz ext = bytes%8; + u64* p = rp; + for (usz i=0; i> (64-b)); + do { v |= v<64 && nia>64) rq[1] = v>>(64-b/2); + } else { + memcpy(rq, xp, (b+7)/8); + } + for (; b%8; b*=2) { + if (b>nw*32) { + if (b=bf since bf is rounded up + break; + } + bit_cpy(rq, b, rq, 0, b); + } + } else { + memcpy(rp, xp, b/8); + } + bi = b/8; + bf = 8*nw; + if (bi == 1) { memset(rp, rp[0], bf); bi=bf; } + } else { + if (TI(x,elType) == el_B) { + B xf = getFillQ(x); + MAKE_MUT(m, nia); mut_init(m, el_B); + MUTG_INIT(m); + i64 div = nia/xia; + i64 mod = nia%xia; + for (i64 i = 0; i < div; i++) mut_copyG(m, i*xia, x, 0, xia); + mut_copyG(m, div*xia, x, 0, mod); + decG(x); + Arr* ra = mut_fp(m); + arr_shSetU(ra, nr, sh); + return withFill(taga(ra), xf); + } + u8 xk = xl - 3; + rp = m_tyarrp(&r, 1<> (64-b)); + while (b<64) { v |= v<bf) l=bf; + for (; bi<=l/2; bi+=bi) memcpy(rp+bi, rp, bi); + u64 e=bi; for (; e+bi<=bf; e+=bi) memcpy(rp+e, rp, bi); + if (e>63, js=-(xx&1); xx^=xx<<1; \ + for (usz k=0, j=0, ij=WV; ; ) { \ + usz e = b>=1; j++; if (j%64==0) { u64 v=xp[j/64]; xx=v^(v<<1)^xs; xs=v>>63; } \ + rp[ij/64]^=(-(xx&1))<<(ij%64); ij+=WV; \ + } \ + for (usz i=k/64; i>63); \ + if (e==s) {break;} k=e; \ + } + +// Basic boolean loop with overwriting +#define BOOL_REP_OVER(WV, LEN) \ + u64 ri=0, rc=0, xc=0; usz j=0; \ + for (usz i = 0; i < LEN; i++) { \ + u64 v = -(u64)bitp_get(xp,i); \ + rc ^= (v^xc) << (ri%64); \ + xc = v; \ + ri += WV; usz e = ri/64; \ + if (j < e) { \ + rp[j++] = rc; \ + while (j < e) rp[j++] = v; \ + rc = v; \ + } \ + } \ + if (ri%64) rp[j] = rc; + extern B rt_slash; B slash_c1(B t, B x) { if (RARE(isAtm(x)) || RARE(RNK(x)!=1)) thrF("/: Argument must have rank 1 (%H ≔ ā‰¢š•©)", x); @@ -465,17 +532,8 @@ B slash_c1(B t, B x) { for (u64 j = 0; j < c; j++) *rp++ = i; } } else { - if (s/16 <= xia) { // Sparse case: type of x matters - #define SPARSE_IND(T) \ - T* xp = T##any_ptr(x); \ - usz b = 1<<10; \ - for (usz k=0, j=0, js=0, ij=xp[0]; ; ) { \ - usz e = b1) goto base; + ur wr = RNK(w); + if (wr>1) thrF("/: Simple š•Ø must have rank 0 or 1 (%i≔=š•Ø)", wr); + if (wr<1) { B v=IGet(w, 0); decG(w); w=v; goto atom; } + wia = IA(w); if (wia==0) { decG(w); return isArr(x)? x : m_atomUnit(x); } - if (isAtm(x) || RNK(x)==0) thrM("/: š•© must have rank at least 1 for simple š•Ø"); - ur xr = RNK(x); - usz xlen = *SH(x); + } else { + atom: + if (!q_i32(w)) goto base; + wv = o2i(w); + } + if (isAtm(x) || RNK(x)==0) thrM("/: š•© must have rank at least 1 for simple š•Ø"); + ur xr = RNK(x); + usz xlen = *SH(x); + u8 xl = cellWidthLog(x); + u8 xt = arrNewType(TY(x)); + + B r; + if (wv < 0) { // Array w if (RARE(wia!=xlen)) thrF("/: Lengths of components of š•Ø must match š•© (%s ≠ %s)", wia, xlen); - u8 xl = cellWidthLog(x); - u8 xt = arrNewType(TY(x)); - u8 we = TI(w,elType); if (!elInt(we)) { w=any_squeeze(w); we=TI(w,elType); @@ -555,30 +625,16 @@ B slash_c2(B t, B w, B x) { // Make shape if needed; all cases below use it usz* rsh = NULL; if (xr > 1) { - usz* sh = rsh = m_shArr(xr)->a; - sh[0] = s; - shcpy(sh+1, SH(x)+1, xr-1); + rsh = m_shArr(xr)->a; + rsh[0] = s; + shcpy(rsh+1, SH(x)+1, xr-1); } if (xl == 0) { u64* xp = bitarr_ptr(x); u64* rp; r = m_bitarrv(&rp, s); if (rsh) { SPRNK(a(r),xr); SH(r) = rsh; } - if (s/256 <= wia) { - #define SPARSE_REP(T) \ - T* wp = T##any_ptr(w); \ - usz b = 1<<12; \ - u64 xx=xp[0], xs=xx>>63, js=-(xx&1); xx^=xx<<1; \ - for (usz k=0, j=0, ij=wp[0]; ; ) { \ - usz e = b>=1; j++; if (j%64==0) { u64 v=xp[j/64]; xx=v^(v<<1)^xs; xs=v>>63; } \ - rp[ij/64]^=(-(xx&1))<<(ij%64); ij+=wp[j]; \ - } \ - for (usz i=k/64; i>63); \ - if (e==s) {break;} k=e; \ - } + if (s/1024 <= wia) { + #define SPARSE_REP(T) T* wp=T##any_ptr(w); BOOL_REP_XOR_SCAN(wp[j]) if (we==el_i8 ) { SPARSE_REP(i8 ); } else if (we==el_i16) { SPARSE_REP(i16); } else { SPARSE_REP(i32); } @@ -586,37 +642,15 @@ B slash_c2(B t, B w, B x) { } else { if (we < el_i32) w = taga(cpyI32Arr(w)); i32* wp = i32any_ptr(w); - u64 ri=0, rc=0, xc=0; usz j=0; - for (usz i = 0; i < wia; i++) { - u64 v = -(u64)bitp_get(xp,i); - rc ^= (v^xc) << (ri%64); - xc = v; - ri += wp[i]; usz e = ri/64; - if (j < e) { - rp[j++] = rc; - while (j < e) rp[j++] = v; - rc = v; - } - } - if (ri%64) rp[j] = rc; + BOOL_REP_OVER(wp[i], wia) } } else { u8 xk = xl-3; void* rv = m_tyarrv(&r, 1<6 || (xl<3 && xl!=0) || TI(x,elType)==el_B) { + if (xr != 1) goto base; SLOW2("š•Ø/š•©", w, x); B xf = getFillQ(x); - HArr_p r0 = m_harrUv(xia*wv); + HArr_p r0 = m_harrUv(s); SGetU(x) - for (usz i = 0; i < xia; i++) { + for (usz i = 0; i < xlen; i++) { B cx = incBy(GetU(x, i), wv); for (i64 j = 0; j < wv; j++) *r0.a++ = cx; } r = withFill(r0.b, xf); goto decX_ret; } + if (xl == 0) { + u64* xp = bitarr_ptr(x); + u64* rp; r = m_bitarrv(&rp, s); + if (wv <= 256) { BOOL_REP_XOR_SCAN(wv) } + else { BOOL_REP_OVER(wv, xlen) } + goto decX_ret; + } else { + u8 xk = xl-3; + void* rv = m_tyarrv(&r, 1< 1) { + usz* rsh = m_shArr(xr)->a; + rsh[0] = s; + shcpy(rsh+1, SH(x)+1, xr-1); + Arr* ra=a(r); SPRNK(ra,xr); PSH(ra)=rsh; PIA(ra)=s*arr_csz(x); + } + goto decX_ret; } base: return c2(rt_slash, w, x); diff --git a/src/singeli/src/scan.singeli b/src/singeli/src/scan.singeli index d3ca6649..3719a612 100644 --- a/src/singeli/src/scan.singeli +++ b/src/singeli/src/scan.singeli @@ -10,20 +10,42 @@ def sel8{v, t & istup{t} & tuplen{t}==16} = sel8{v, merge{t,t}} def base{b,l} = { if (0==tuplen{l}) 0; else tupsel{0,l}+b*base{b,slice{l,1}} } def shuf{T, v, n & istup{n}} = shuf{T, v, base{4,n}} +# Fill last 4 bytes with last element, in each lane +def spread{a:VT} = { + def w = width{eltype{VT}} + def b = w/8 + if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a +} + +def scan_loop{T, init, x:*T, r:*T, len:u64, scan, scan_last} = { + def step = 256/width{T} + def V = [step]T + p:= broadcast{V, init} + xv:= *V ~~ x + rv:= *V ~~ r + e:= len/step + @for (xv, rv over e) rv = scan{xv,p} + q:= len & (step-1) + if (q) maskstoreF{rv, maskOf{V, q}, e, scan_last{load{xv,e}, p}} +} +def scan_post{T, init, x:*T, r:*T, len:u64, op, pre} = { + def last{v, p} = op{pre{v}, p} + def scan{v, p} = { + n:= last{v, p} + p = sel{[8]i32, spread{n}, broadcast{[8]i32, 7}} + n + } + scan_loop{T, init, x, r, len, scan, last} +} + # Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈ avx2_scan_idem{T, op, id}(x:*T, r:*T, len:u64) : void = { - def w = width{T} - # Within each lane, scan using shifts by powers of 2. First k elements # when shifting by k don't need to change, so leave them alone. + def w = width{T} def shift{k,l} = merge{iota{k},iota{l-k}} def c8 {k, a} = op{a, shuf{[4]u32, a, shift{k,4}}} def c32{k, a} = (if (w<=8*k) op{a, sel8{a, shift{k,16}}}; else a) - # Fill last 4 bytes with last element, in each lane - def spread{a} = { - def b = w/8 - if (w<=16) sel8{a,merge{iota{12},(16-b)+iota{4}%b}}; else a - } # Prefix op on entire AVX register def pre{a} = { b:= c8{2, c8{1, c32{2, c32{1, a}}}} @@ -31,19 +53,7 @@ avx2_scan_idem{T, op, id}(x:*T, r:*T, len:u64) : void = { op{b, sel{[8]i32, spread{b}, make{[8]i32, 3*(3