From a8b036ad080d93241962ae29715d51b66d45b53c Mon Sep 17 00:00:00 2001 From: Marshall Lochbaum Date: Mon, 17 Jun 2024 09:18:06 -0400 Subject: [PATCH] =?UTF-8?q?Implement=20=E2=89=A0=CB=9D=CB=98=20and=20=3D?= =?UTF-8?q?=CB=9D=CB=98=20like=20=E2=88=A7=E2=88=A8=20on=20width<64=20bool?= =?UTF-8?q?ean?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/builtins/fold.c | 1 - src/singeli/src/fold.singeli | 196 ++++++++++++++++++++++------------- 2 files changed, 122 insertions(+), 75 deletions(-) diff --git a/src/builtins/fold.c b/src/builtins/fold.c index 9d1617f3..8a41e8d1 100644 --- a/src/builtins/fold.c +++ b/src/builtins/fold.c @@ -527,7 +527,6 @@ B fold_rows_bit(Md1D* fd, B x) { decG(x); if (m==1) return y; x = y; } - if (!andor && m <= 64) return bi_N; u64* xp = bitarr_ptr(x); u64* rp; B r = m_bitarrv(&rp, n); if (andor) si_or_rows_bit(xp, rp, n, m, rtid==n_and); diff --git a/src/singeli/src/fold.singeli b/src/singeli/src/fold.singeli index 259bcb11..110e1b0f 100644 --- a/src/singeli/src/fold.singeli +++ b/src/singeli/src/fold.singeli @@ -2,8 +2,8 @@ include './base' include './mask' if_inline (hasarch{'BMI2'}) include './bmi2' include './spaced' - include 'util/tup' +include './scan_common' def opsh64{op}{v:([4]f64), perm} = op{v, shuf{[4]u64, v, perm}} def opsh32{op}{v:([2]f64), perm} = op{v, shuf{[4]u32, v, perm}} @@ -76,63 +76,17 @@ fn fold_assoc_0{T==f64, op if has_simd}(x:*T, len:u64) : T = { export{'si_sum_f64', fold_assoc_0{f64,+}} -fn xor_words(init:u64, x:*u64, l:usz):u64 = { - @for (x over l) init ^= x - init -} -def bit_output{rp:*T} = { - buf:u64 = 0 # Buffer for result bits - def output{i, bit, mod} = { - buf = buf>>1 | promote{u64, bit}<<63 - if ((i+1)%64==0) { store{rp, 0, mod{buf}}; ++rp } - } - def fixbuf{mod} = { buf = mod{buf} } - def flush_bits{n, mod} = { - q:=(-n)%64; if (q!=0) store{rp, 0, mod{buf} >> q} - } - def flush_bits{n} = flush_bits{n, {b}=>b} - tup{output, fixbuf, flush_bits} -} -# word and alignment of start of next row -def next_start{i, m} = { - bn := promote{u64, i+1} * promote{u64, m} - tup{bn/64, bn%64} -} -fn xor_rows_bit(xp:*u64, rp:*u64, n:usz, m:usz, eq:u1) : void = { - def p64 = promote{u64, .} - def fixout = ^{-(p64{eq} &~ p64{m}), .} # ne to eq conversion - def {add_bit, _, flush_bits} = bit_output{rp} - def add_bit{i, bit} = add_bit{i, bit, fixout} - def xor_loop{len} = { - o:u64 = 0 # Carry - j:u64 = 0; @for (i to n) { - def {jn, sh} = next_start{i, m} - s := xor_words(o, xp + j, len) - e := load{xp, jn} - s ^= e & (if (not same{l,1}) -p64{jn >= j + p64{l}} else j - jn) - o = e >> sh - add_bit{i, popc{s ^ o}} - j = jn+1 - } - } - l := m/64 - if (m < 128) xor_loop{1} - else if (m%64==0) { - @for (i to n) add_bit{i, popc{xor_words(0, xp+l*i, l)}} - } - else xor_loop{l} - flush_bits{n, fixout} -} - -def or_rows_bit_lt64{xp, rp, n, l, op_and} = { +def fold_rows_bit_lt64{ + op, run_loop2, run_loop4, pext_res, mult_in, xx, rx, rxs, + xp, rp, n, l +} = { nw := cdiv{n*l,64} - xx := -promote{u64, op_and} def loop_pext{T, get, b} = { - @for (x in xp, r in *T~~rp over nw) r = cast_i{T, pext{get{x}, b}} + @for (x in xp, r in *T~~rp over nw) r = cast_i{T, rxs{pext{get{x}, b}}} } if (l == 2) { t:u64 = 64w2b10 - def loop = { + run_loop2{ if (fast_BMI2{}) { {op} => loop_pext{u32, {x}=>op{x, x<<1}, t} } else { @@ -143,15 +97,14 @@ def or_rows_bit_lt64{xp, rp, n, l, op_and} = { @for (x in xp, r in *T~~rp over nw) { a := op{x, x>>1} each{{m, sh} => { a &= m; a |= a>>sh }, ms, 1<> (3*16) - r = cast_i{T, a} + r = cast_i{T, rxs{a}} } } } } - if (op_and) loop{{x} => x & ((x&~t) + m)} - else loop{{x} => x | ((x| t) - m)} } else { {e0, d} := unaligned_spaced_mask_mod{l} e := e0 << (l-1) @@ -175,38 +126,41 @@ def or_rows_bit_lt64{xp, rp, n, l, op_and} = { rh := *u32~~rp ri:ux = 0 if (fast_BMI2{}) { - @for (xo in xp over i to nw) { - m := e<<1 | 1 - t := e | 1<<63 - x := xo^xx - r |= pext{x | ((x|t) - m), t} << ri + c:u64 = 0 + @for (x in xp over i to nw) { + r |= pext_res{xx{x}, e, c} << ri ri += popc{e} e = e>>d | e<<(l-d) if (ri >= 32) { - store{rh, 0, cast_i{u32,r^xx}}; ++rh; + store{rh, 0, cast_i{u32,rx{r}}}; ++rh; r >>= 32; ri -= 32 } } } else { dm:= cast_i{usz, popc{e}} + e0 |= -promote{u64, l & (l-1) == 0} c:u64 = 0 def loop{...par} = { @for (xo in xp over nw) { - m := e<<1 | 1 - t := e | (1<<63) - x := xo^xx - s := x | ((x|t) - m) # Fold results - cs:= c; c = (s&~e) >> 63 + x:= xx{xo} + def {s, mod_rb} = if (length{select{par,0}} == 3) { + # fast path for l==3 + s:= op{op{x, x<<1}, op{x<<2, c}} + c = op{x>>63, x>>62} + tup{s, {rb}=>rb} + } else { + mult_in{x, e, c} + } nb:= dm + promote{usz, e>=e0} # = popc{e} def extract = match { {{...qs, q}, {...bs, b}} => (extract{qs, bs} & b) * q { {q}, {}} => (s & e) * (q << clz{e}) } rb:= extract{...par} >> (64 - nb) - r |= (rb|cs) << ri + r |= mod_rb{rb} << ri ri += promote{ux, nb} if (ri >= 32) { - store{rh, 0, cast_i{u32,r^xx}}; ++rh; + store{rh, 0, cast_i{u32,rx{r}}}; ++rh; r >>= 32; ri -= 32 } e = e>>d | e<<(l-d) @@ -229,17 +183,111 @@ def or_rows_bit_lt64{xp, rp, n, l, op_and} = { loop{tup{mult0,mult1}, tup{topk}} } else { {mult, _} := unaligned_spaced_mask_mod{l-1} + if (l==8) mult &= 1<<(7*8) - 1 loop{tup{mult}, tup{}} } } - if (ri > 0) store{rh, 0, cast_i{u32,r^xx}} + if (ri > 0) store{rh, 0, cast_i{u32,rx{r}}} + } +} + + +fn xor_words(init:u64, x:*u64, l:usz):u64 = { + @for (x over l) init ^= x + init +} +def bit_output{rp:*T} = { + buf:u64 = 0 # Buffer for result bits + def output{i, bit, mod} = { + buf = buf>>1 | promote{u64, bit}<<63 + if ((i+1)%64==0) { store{rp, 0, mod{buf}}; ++rp } + } + def fixbuf{mod} = { buf = mod{buf} } + def flush_bits{n, mod} = { + q:=(-n)%64; if (q!=0) store{rp, 0, mod{buf} >> q} + } + def flush_bits{n} = flush_bits{n, {b}=>b} + tup{output, fixbuf, flush_bits} +} +# word and alignment of start of next row +def next_start{i, m} = { + bn := promote{u64, i+1} * promote{u64, m} + tup{bn/64, bn%64} +} +fn xor_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, eq:u1) : void = { + def p64 = promote{u64, .} + rx:= -(p64{eq} &~ p64{l}) # ne to eq conversion + if (l <= 64 and not (l%8==0 and l>16)) { + def run_loop2{loop} = loop{^} + def run_loop4{m, t, loop} = loop{{x} => { x2:= x^(x<<1); x2^(x2<<2) }} + def xor_word = prefix_byshift{^, <<} + def pext_in{xo, e, c} = { + x := xor_word{xo} + rb:= x ^ (x<>63))>>(64-l) + rb + } + def pext_res{x, e, c} = pext{pext_in{x, e, c}, e} + def mult_in{x, e, c} = tup{pext_in{x, e, c}, {r}=>r} + fold_rows_bit_lt64{ + ^, run_loop2, run_loop4, pext_res, mult_in, {x}=>x, ^{rx,.}, ^{rx,.}, + xp, rp, n, l + } + } else { + def fixout = ^{rx, .} + def {add_bit, _, flush_bits} = bit_output{rp} + def add_bit{i, bit} = add_bit{i, bit, fixout} + def xor_loop{len} = { + o:u64 = 0 # Carry + j:u64 = 0; @for (i to n) { + def {jn, sh} = next_start{i, l} + s := xor_words(o, xp + j, len) + e := load{xp, jn} + s ^= e & (if (not same{len,1}) -p64{jn >= j + p64{len}} else j - jn) + o = e >> sh + add_bit{i, popc{s ^ o}} + j = jn+1 + } + } + ll := l/64 + if (l <= 64) { + bm:= u64~~2<<(l-1) - 1 + k:= l/8 + @for (i to n) add_bit{i, popc{bm & load{*u64~~(*u8~~xp + k*i)}}} + } else if (l < 128) xor_loop{1} + else if (l%64==0) { + @for (i to n) add_bit{i, popc{xor_words(0, xp+ll*i, ll)}} + } + else xor_loop{ll} + flush_bits{n, fixout} } } fn or_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, op_and:u1) : void = { def {add_bit, set_out, flush_bits} = bit_output{rp} if (l < 64) { - or_rows_bit_lt64{xp, rp, n, l, op_and} + def run_loop2{loop} = if (op_and) loop{&} else loop{|} + def run_loop4{m, t, loop} = { + if (op_and) loop{{x} => x & ((x&~t) + m)} + else loop{{x} => x | ((x| t) - m)} + } + def xor_word = prefix_byshift{^, <<} + def pext_in{x, e, c} = { + m := e<<1 | 1 + t := e | 1<<63 + tup{x | ((x|t) - m), t} + } + def pext_res{x, e, c} = pext{...pext_in{x, e, c}} + def mult_in{x, e, c} = { + {s, t} := pext_in{x, e, c} + cs:= c; c = (s&~e) >> 63 + tup{s, {rb}=>rb|cs} + } + def xx = ^{-promote{u64, op_and}, .} + fold_rows_bit_lt64{ + |, run_loop2, run_loop4, pext_res, mult_in, xx, xx, {r}=>r, + xp, rp, n, l + } return{} } else if (l < 128) { c:u64 = (promote{u64, l}-1) &- op_and # a row gives 1 if its sum is >c