From e6f1e04de21d23336c9326ff37a0dfe6f9380547 Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Thu, 13 Jun 2024 16:52:52 -0400 Subject: [PATCH] =?UTF-8?q?Fast=20generic=20and=20pext-based=20=E2=88=A7?= =?UTF-8?q?=CB=9D=CB=98=20and=20=E2=88=A8=CB=9D=CB=98=20on=20width<64?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/builtins/fold.c | 18 ++++- src/singeli/src/fold.singeli | 142 ++++++++++++++++++++++++++++++++--- 2 files changed, 148 insertions(+), 12 deletions(-) diff --git a/src/builtins/fold.c b/src/builtins/fold.c index 2547d828..70587410 100644 --- a/src/builtins/fold.c +++ b/src/builtins/fold.c @@ -14,6 +14,7 @@ #include "../core.h" #include "../builtins.h" #include "../utils/mut.h" +#include "../utils/calls.h" #if SINGELI #define SINGELI_FILE fold @@ -512,12 +513,23 @@ B fold_rows_bit(Md1D* fd, B x) { if (rtid==n_add) return sum_rows_bit(x); #if SINGELI if (rtid==n_ne|rtid==n_eq|rtid==n_or|rtid==n_and) { + bool andor = rtid==n_or|rtid==n_and; usz *sh = SH(x); usz n = sh[0]; usz m = sh[1]; - if (m <= 64) return bi_N; + if (andor && m < 256) while (m%8 == 0) { + usz f = CTZ(m|32); + m >>= f; usz c = m*n; + u64* yp; B y = m_bitarrv(&yp, c); + u8 e = el_i8 + f-3; + CmpASFn cmp = rtid==n_or ? CMP_AS_FN(ne, e) : CMP_AS_FN(eq, e); + CMP_AS_CALL(cmp, yp, bitarr_ptr(x), m_f64((rtid==n_or)-1), c); + decG(x); if (m==1) return y; + x = y; + } + if (!andor && m <= 64) return bi_N; u64* xp = bitarr_ptr(x); u64* rp; B r = m_bitarrv(&rp, n); - if (rtid==n_ne|rtid==n_eq) si_xor_rows_bit(xp, rp, n, m, rtid==n_eq); - else si_or_rows_bit(xp, rp, n, m, rtid==n_and); + if (andor) si_or_rows_bit(xp, rp, n, m, rtid==n_and); + else si_xor_rows_bit(xp, rp, n, m, rtid==n_eq); decG(x); return r; } #endif diff --git a/src/singeli/src/fold.singeli b/src/singeli/src/fold.singeli index c083a6d7..ff257bf4 100644 --- a/src/singeli/src/fold.singeli +++ b/src/singeli/src/fold.singeli @@ -1,5 +1,8 @@ include './base' include './mask' +if_inline (hasarch{'BMI2'}) include './bmi2' + +include 'util/tup' def opsh64{op}{v:([4]f64), perm} = op{v, shuf{[4]u64, v, perm}} def opsh32{op}{v:([2]f64), perm} = op{v, shuf{[4]u32, v, perm}} @@ -120,14 +123,135 @@ fn xor_rows_bit(xp:*u64, rp:*u64, n:usz, m:usz, eq:u1) : void = { flush_bits{n, fixout} } -fn or_rows_bit(xp:*u64, rp:*u64, n:usz, m:usz, op_and:u1) : void = { - def p64 = promote{u64, .} +def unaligned_mask{l} = { + def d = 64 % l + def m = (~u64~~0 >> d) / ((u64~~1 << l)-1) + tup{m< loop_pext{u32, {x}=>op{x, x<<1}, t} + } else { + def all = 1<<64-1 + ms:5**u64 = reverse{slice{scan{{x,s}=>x^((x< { + def T = u32 + @for (x in xp, r in *T~~rp over nw) { + a := op{x, x>>1} + each{{m, sh} => { a &= m; a |= a>>sh }, ms, 1< { + def T = u16 + @for (x in xp, r in *T~~rp over nw) { + a := get{x} & t + a = (a * 4r2b001) & (64w0xf000) + a = (a * fold{+, 1<<(12*iota{4})}) >> (3*16) + r = cast_i{T, a} + } + } + } + } + if (op_and) loop{{x} => x & ((x&~t) + m)} + else loop{{x} => x | ((x| t) - m)} + } else { # odd row length + d:usz = 64 % l + e := ((~u64~~0 >> d) / ((u64~~1 << l)-1)) << (l-1) + r:u64 = 0 + rh := *u32~~rp + ri:ux = 0 + if (fast_BMI2{}) { + @for (xo in xp over i to nw) { + m := e<<1 | 1 + t := e | 1<<63 + x := xo^xx + r |= pext{x | ((x|t) - m), t} << ri + ri += popc{e} + e = e>>d | e<<(l-d) + if (ri >= 32) { + store{rh, 0, cast_i{u32,r^xx}}; ++rh; + r >>= 32; ri -= 32 + } + } + } else { + dm:= 64/l + e0:= e<<1 | 1 + c:u64 = 0 + def loop{...par} = { + @for (xo in xp over nw) { + m := e<<1 | 1 + t := e | (1<<63) + x := xo^xx + s := x | ((x|t) - m) # Fold results + cs:= c; c = (s&~e) >> 63 + nb:= dm + promote{usz, e>=e0} # = popc{e} + def extract = match { + {{...qs, q}, {...bs, b}} => (extract{qs, bs} & b) * q + { {q}, {}} => (s & e) * (q << clz{e}) + } + rb:= extract{...par} >> (64 - nb) + r |= (rb|cs) << ri + ri += promote{ux, nb} + if (ri >= 32) { + store{rh, 0, cast_i{u32,r^xx}}; ++rh; + r >>= 32; ri -= 32 + } + e = e>>d | e<<(l-d) + } + } + if (l == 3) { + mult0:u64 = base{1<< 2, 3**1}; top3:u64 = base{1<<9, 8**(1<<3-1)}>>2 + mult1:u64 = base{1<< 6, 3**1}; top9:u64 = base{1<<27, 3**(1<<9-1)}<<1 + mult2:u64 = base{1<<18, 3**1} + loop{tup{mult0,mult1,mult2}, tup{top3, top9}} + } else if (l < 8) { + assert{l > 4} + ld:= l-1; lld:= l*ld + {mult0, _} := unaligned_mask{ld} + mult0 &= u64~~1<>l; topk|= topk<>ll + loop{tup{mult0,mult1}, tup{topk}} + } else { + {mult, _} := unaligned_mask{l-1} + loop{tup{mult}, tup{}} + } + } + if (ri > 0) store{rh, 0, cast_i{u32,r^xx}} + } +} + +fn or_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, op_and:u1) : void = { def {add_bit, set_out, flush_bits} = bit_output{rp} - if (m < 128) { - c:u64 = (p64{m}-1) &- op_and # a row gives 1 if its sum is >c + if (l < 64) { + or_rows_bit_lt64{xp, rp, n, l, op_and} + return{} + } else if (l < 128) { + c:u64 = (promote{u64, l}-1) &- op_and # a row gives 1 if its sum is >c o:u64 = 0 j:u64 = 0; @for (i to n) { - def {jn, sh} = next_start{i, m} + def {jn, sh} = next_start{i, l} s := o + popc{load{xp,j}} e := load{xp,jn} s += popc{e & (j - jn)} # mask is 0 if j==jn, or -1 @@ -140,15 +264,15 @@ fn or_rows_bit(xp:*u64, rp:*u64, n:usz, m:usz, op_and:u1) : void = { def fixout = ^{rx, .} o:u64 = 0 # Saved bits j:u64 = 0; @for (i to n) { - def {jn, sh} = next_start{i, m} + def {jn, sh} = next_start{i, l} e := load{xp,jn} ^ rx - m := ~(u64~~0) << sh + l := ~(u64~~0) << sh rb:u64 = 1 - if ((o | (e &~ m)) == 0) { # Search for shortcut + if ((o | (e &~ l)) == 0) { # Search for shortcut @for (i from j to jn-1) if (load{xp,i} != id) goto{'found'} rb = 0; setlabel{'found'} } - o = e & m + o = e & l add_bit{i, rb, fixout} j = jn+1 }