Implement short-row num⊏˘bool (including ⊣˝˘ ⊢˝˘) with fold code

This commit is contained in:
Marshall Lochbaum 2024-06-18 14:05:16 -04:00
parent cb1b72fbb2
commit 4b0f105a7f
2 changed files with 28 additions and 6 deletions

View File

@ -177,6 +177,7 @@ NOINLINE B leading_axis_arith(FC2 fc2, B w, B x, usz* wsh, usz* xsh, ur mr) { //
// fast special-case implementations
extern void (*const si_select_cells_bit_lt64)(uint64_t*,uint64_t*,uint32_t,uint32_t,uint32_t); // from fold.c (fold.singeli)
static NOINLINE B select_cells(usz n, B x, usz cam, usz k, bool leaf) { // n {leaf? <∘⊑; ⊏}⎉¯k x; TODO probably can share some parts with takedrop_highrank and/or call ⊏?
ur xr = RNK(x);
assert(xr>1 && k<xr);
@ -200,7 +201,13 @@ static NOINLINE B select_cells(usz n, B x, usz cam, usz k, bool leaf) { // n {le
void* rp = m_tyarrlbp(&ra, elwBitLog(xe), cam, el2t(xe));
void* xp = tyany_ptr(x);
switch(xe) {
case el_bit: for (usz i=0; i<cam; i++) bitp_set(rp, i, bitp_get(xp, i*jump+n)); break;
case el_bit:
#if SINGELI
if (jump < 64) si_select_cells_bit_lt64(xp, rp, cam, jump, n);
else
#endif
for (usz i=0; i<cam; i++) bitp_set(rp, i, bitp_get(xp, i*jump+n));
break;
case el_i8: case el_c8: PLAINLOOP for (usz i=0; i<cam; i++) ((u8* )rp)[i] = ((u8* )xp)[i*jump+n]; break;
case el_i16: case el_c16: PLAINLOOP for (usz i=0; i<cam; i++) ((u16*)rp)[i] = ((u16*)xp)[i*jump+n]; break;
case el_i32: case el_c32: PLAINLOOP for (usz i=0; i<cam; i++) ((u32*)rp)[i] = ((u32*)xp)[i*jump+n]; break;

View File

@ -79,6 +79,7 @@ export{'si_sum_f64', fold_assoc_0{f64,+}}
# Short-row boolean folds: main challenge is bit packing
def fold_rows_bit_lt64{
op, run_loop2, run_loop4, pext_res, mult_in,
off, # mask offset for generic methods
xx, rx, # input and output xor for cases not specialized to individual functions
rxs, # output xor only, where and/or are specialized
xp, rp, n, l
@ -108,8 +109,9 @@ def fold_rows_bit_lt64{
run_loop4{m, t, {get} => loop_T{u16, {x} => extract{get{x}}}}
} else { # generic width<64
{e0, d} := unaligned_spaced_mask_mod{l}
e := e0 << (l-1) # ending bit of each row
c:u64 = 0 # carry, use depends on algorithm
el:= e0 << (l-1) # ending bit of each row
e := if (same{off,-1}) el else e0<<off # or selected bit
c:u64 = 0; c|0 # carry, use depends on algorithm (unused for select)
def {write_bits, flush_bits} = {
r:u64 = 0
rh := *u32~~rp
@ -135,7 +137,7 @@ def fold_rows_bit_lt64{
# Emulate pext with 1, 2, or 3 multiply/mask steps.
# To move size-a groups spaced at distance b together,
# the multiplier has up to b/a bits spaced by b-a.
dm:= cast_i{usz, popc{e}} # minimum output bits per word
dm:= cast_i{usz, popc{el}} # minimum output bits per word
dm-= promote{usz, l&(l-1) == 0} # for divisors of 64, e0 effectively overflows; subtract 1 to correct
def loop{...par} = {
@for (xo in xp over nw) {
@ -183,6 +185,18 @@ def fold_rows_bit_lt64{
}
}
fn select_rows_bit_lt64(xp:*u64, rp:*u64, n:usz, l:usz, o:usz) : void = {
assert{l < 64}; assert{o < l} # Row length, and offset within row
def run_loop2{loop} = loop{{a,b} => a>>o}
def run_loop4{m, t, loop} = loop{{x} => x<<(l-1-o)}
def pext_res{x, e, c} = pext{x, e}
def mult_in{x, e, c} = tup{x, {r}=>r}
def id{x} = x
fold_rows_bit_lt64{
{a,b}=>a, run_loop2, run_loop4, pext_res, mult_in, o, id, id, id,
xp, rp, n, l
}
}
fn xor_words(init:u64, x:*u64, l:usz):u64 = {
@for (x over l) init ^= x
@ -222,7 +236,7 @@ fn xor_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, eq:u1) : void = {
def pext_res{x, e, c} = pext{pext_in{x, e, c}, e}
def mult_in{x, e, c} = tup{pext_in{x, e, c}, {r}=>r}
fold_rows_bit_lt64{
^, run_loop2, run_loop4, pext_res, mult_in, {x}=>x, ^{rx,.}, ^{rx,.},
^, run_loop2, run_loop4, pext_res, mult_in, -1, {x}=>x, ^{rx,.}, ^{rx,.},
xp, rp, n, l
}
} else {
@ -277,7 +291,7 @@ fn or_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, op_and:u1) : void = {
}
def xx = ^{-promote{u64, op_and}, .}
fold_rows_bit_lt64{
|, run_loop2, run_loop4, pext_res, mult_in, xx, xx, {r}=>r,
|, run_loop2, run_loop4, pext_res, mult_in, -1, xx, xx, {r}=>r,
xp, rp, n, l
}
return{}
@ -316,3 +330,4 @@ fn or_rows_bit(xp:*u64, rp:*u64, n:usz, l:usz, op_and:u1) : void = {
}
export{'si_xor_rows_bit', xor_rows_bit}
export{'si_or_rows_bit', or_rows_bit}
export{'si_select_cells_bit_lt64', select_rows_bit_lt64}