fancy Singeli accumulator; use for integer sum

This commit is contained in:
dzaima 2025-03-20 21:32:11 +02:00
parent 3b1239d499
commit e1b62b95b6
3 changed files with 224 additions and 4 deletions

View File

@ -44,11 +44,23 @@
#include "../utils/mut.h"
#include "../utils/calls.h"
static const usz sum_small_max = 1<<16;
#if SINGELI
extern uint64_t* const si_spaced_masks;
#define get_spaced_mask(i) si_spaced_masks[i-1]
#define SINGELI_FILE fold
#include "../utils/includeSingeli.h"
#else
#define SUM_SMALL(T,W) \
static i64 sum_small_##T(void* xv, usz ia) { \
W s=0; \
for (usz i=0; i<ia; i++) s+=((T*)xv)[i]; \
return s; \
}
SUM_SMALL(i8 ,i32)
SUM_SMALL(i16,i32)
SUM_SMALL(i32,i64)
#undef SUM_SMALL
#endif
static u64 xor_words(u64* x, u64 l) {
@ -72,11 +84,7 @@ static i64 bit_diff(u64* x, u64 am) {
// It's safe to sum a block of integers as long as the current total
// is far enough from +-1ull<<53 (and integer, in dyadic fold).
static const usz sum_small_max = 1<<16;
#define DEF_INT_SUM(T,W,M,A) \
static i64 sum_small_##T(void* xv, usz ia) { \
i##A s=0; for (usz i=0; i<ia; i++) s+=((T*)xv)[i]; return s; \
} \
static f64 sum_##T(void* xv, usz ia, f64 init) { \
usz b=1<<(M-W); i64 lim = (1ull<<53) - (1ull<<M); \
T* xp = xv; \

View File

@ -0,0 +1,185 @@
# accumulator operations:
# acc{'flush'} - flush vector accumulator (both unrolled and non-unrolled)
# acc{'flush_min'} - how many 'acc' calls can be done between flushes for worst-case arguments
# acc{'acc', val} - add value(s) to the accumulator
# acc{'acc', M, val} - add value(s) to the accumulator based on the mask
# acc{'from_unr'} - must be invoked when transitioning from accumulating ≥2-elt tuples (1-elt tuples / plain vectors are fine any time before 'to_scal')
# acc{'to_scal'} - must be called when transitioning from vectors to scalar accumulates or taking 'scal_result'
# acc{'to_scal', n} - expanded 'to_scal', required for some ops; n must be the number of elements accumulated
# acc{'scal_result'} - get result of scalar accumulates
# acc{'vec_result'} - get result of vector accumulates (i.e. 'to_scal' + 'scal_result')
# acc{'vec_result', n} - 'vec_result', passing n to 'to_scal'
local def extend finish{me} = {
def me{'vec_result', ...rest} = {
me{'to_scal', ...rest}
me{'scal_result'}
}
}
local def int_els{DE, SE} = isint{DE} and isint{SE} and quality{DE}==quality{SE} and DE>=SE
def mu_extra{...args} = {
def get_all{prop} = each{{c}=>c{prop}, args}
tup{
{} => get_all{'from_unr'},
fold{min, get_all{'flush_min'}},
{} => get_all{'flush'},
}
}
local def mask_ident{A, ident}{M, acc_vec, v:V=[_]_} = {
if (M{0}==0) A{acc_vec, v}
else if (is{ident,0}) A{acc_vec, M{v}}
else blend_hom{acc_vec, A{acc_vec, v}, M{V, 'to homogeneous bits'}}
}
local def acc_impl{A if kgen{A}, M if kgen{M}, unr if knum{unr}, VT=[k]E, S, init} = {
acc_scal:S = init
def vinit = VT**init
def acc_tup = @collect(unr) { acc_var:=vinit }
def acc_vec = select{acc_tup, 0} # shared with acc_tup to allow single-vector accumulates before from_unr
def me{'acc', v} = {
match (v) {
{_:[_]_} => acc_vec = A{mask_none, acc_vec, v}
{{v0:_}} => acc_vec = A{mask_none, acc_vec, v0}
{{..._}} => each{{a,c} => a = A{mask_none, a,c}, acc_tup, v}
{_:T if isprim{T}} => acc_scal = A{mask_none, acc_scal, v}
}
}
def me{'acc', M if kgen{M}, v} = {
match (v) {
{{v0}} => me{'acc', M, v0}
{_ if M{0}==0} => me{'acc', v}
{_:[_]_ if M{0}==1} => acc_vec = A{M, acc_vec, v}
}
}
def me{'from_unr'} = {
acc_vec = tree_fold{M, acc_tup}
}
def me{'to_scal', ...rest} = {
acc_scal = vfold{M, acc_vec}
}
def me{'flush_min'} = 1/0
def me{'flush'} = {}
def me{'scal_result'} = acc_scal
tup{me, acc_scal, acc_tup, acc_vec}
}
def assoc_accumulator{F if kgen{F}, unr if knum{unr}, VT=[_]E, ident} = {
def {me, ..._} = acc_impl{mask_ident{F, ident}, F, unr, VT, E, ident}
extend finish{me}
me
}
def assoc_accumulator{F==min, unr, VT=[_]E} = assoc_accumulator{F, unr, VT, if (isfloat{E}) E~~1/0 else maxvalue{E}}
def assoc_accumulator{F==max, unr, VT=[_]E} = assoc_accumulator{F, unr, VT, if (isfloat{E}) -E~~1/0 else minvalue{E}}
def assoc_accumulator{F==__add, unr, VT=[_]E} = assoc_accumulator{F, unr, VT, 0}
def sum_accumulator{DE, unr, VT=[_]SE if int_els{DE,SE} and DE>SE and hasarch{'AARCH64'}} = {
def A = if (width{DE}>width{ux}) DE else primtype{quality{DE}, width{ux}}
def VM = el_m{VT}
def [_]ME = VM
def exact = DE == ME
def addpwa{a:(A), x:E if isprim{E}} = a + promote{A,x}
def {acc, acc_scal, acc_tup, acc_vec} = acc_impl{mask_ident{addpwa,0}, +, unr, VM, A, 0}
def me{...} = acc
def into_scal{vs} = {
def ps = each{if (exact) vfold{+,.} else fold_addw, vs}
acc_scal+= tree_fold{+, each{promote{A,.}, ps}}
}
def me{'from_unr'} = {
into_scal{slice{acc_tup, 1}}
}
def me{'to_scal', ...rest} = {
into_scal{tup{acc_vec}}
}
def me{'flush_min'} = if (exact) {
1/0
} else {
def p{G} = __floor{G{ME} / G{SE} / 2} # divided by 2 because addpwa adds two elements per iter
if (issigned{SE}) min{p{maxvalue}, p{minvalue}} else p{maxvalue}
}
def me{'flush'} = if (not exact) {
into_scal{acc_tup}
acc_tup = VM**0
}
def me{'scal_result'} = cast_i{DE, acc_scal}
extend finish{me}
}
# TODO: AVX-512 could use dpbusd/dpwssd
def sum_accumulator{DE, unr, VT=[k]SE if int_els{DE,SE} and width{SE}==8 and hasarch{'X86_64'}} = {
def VM = [k/8]u64
def VU = [k]u8
def adda{M, a:(VM), c:(VT)} = {
def cu = if (SE==i8) ty_u{c} ^ VU**128 else c
a + absdiff_sum{8, M{cu}, VU**0}
}
def {acc, acc_scal, acc_tup, acc_vec} = acc_impl{adda, +, unr, VM, DE, 0}
def me{...} = acc
def me{'to_scal', ...rest} = {
acc_scal+= cast_i{DE, match (SE, ...rest) {
{(u8),..._} => vfold{+, acc_vec}
{(i8), n} => vfold{+, acc_vec} - n*128
}}
}
extend finish{me}
}
def sum_accumulator{DE, unr, VT=[k]SE if int_els{DE,SE} and SE==i16 and DE==i32 and hasarch{'X86_64'}} = { # TODO could probably extend to u16? and to 64-bit DE via flush
def VM = [k/2]i32
def adda{M, a:(VM), c:(VT)} = {
if (M{0}) a - mul_sum{2, c, VT~~M{VT, 'to homogeneous bits'}}
else a + mul_sum{2, c, VT**1}
}
def {acc, acc_scal, acc_tup, acc_vec} = acc_impl{adda, +, unr, VM, DE, 0}
def me{...} = acc
extend finish{me}
}
local include 'util/perv'
def sum_accumulator{DE==i64, unr, VT=[k]SE==i32 if hasarch{'X86_64'}} = { # TODO extending to u32 should be trivial
acc:DE = 0
def {accl, _, accl_tup, accl_vec} = acc_impl{mask_ident{+,0}, +, unr, VT, SE, 0}
def {acch, _, acch_tup, acch_vec} = acc_impl{mask_ident{+,0}, +, unr, VT, SE, 0}
extend perv2{__shr}
def me{'acc', ...M, v} = match (v) {
{v:T if primtype{T} and M{0}==0} => acc += promote{DE, v}
{_} => {
accl{'acc', ...M, v}
acch{'acc', ...M, v >> 16}
}
}
def me{'from_unr'} = {
accl{'from_unr'}
acch{'from_unr'}
}
def into_scal{} = {
def ls = vfold{+, accl_vec}
def hs = vfold{+, acch_vec}
def ld = DE~~promote{u64, ty_u{ls} - ty_u{hs<<16}}
def hd = promote{DE, hs}<<16
acc+= ld + hd
}
def me{'to_scal', ...rest} = {
into_scal{}
}
def me{'flush_min'} = 65536/(unr*k)
def me{'flush'} = {
accl_vec = tree_fold{+, accl_tup}
acch_vec = tree_fold{+, acch_tup}
into_scal{}
accl_tup = VT**0
acch_tup = VT**0
}
def me{'scal_result'} = acc
extend finish{me}
}
def sum_accumulator{E, unr, VT=[_]E} = assoc_accumulator{__add, unr, VT, 0}

View File

@ -3,6 +3,7 @@ include './mask'
if_inline (hasarch{'BMI2'}) include './bmi'
include './spaced'
include './scan_common'
include './vecfold'
def opsh64{op}{v:[_](f64), ...perm} = op{v, shuf{v, ...perm}}
def mix{op, v:([4]f64) if hasarch{'AVX'}} = { def sh=opsh64{op}; sh{sh{v, 1,0}, 2,3,0,1} }
@ -514,3 +515,29 @@ export{'si_xor_rows_bit', xor_rows_bit}
export{'si_or_rows_bit', or_rows_bit}
export{'si_select_cells_bit_lt64', extract_column_bit_lt64}
export{'si_select_cells_byte', extract_column}
fn sum_small{T}(xv:*void, ia:usz) : i64 = {
xv:= *T~~xv
def A = if (width{T}<=16) i32 else i64
r:A = 0
@for (xv over ia) r += promote{A, xv}
promote{i64, r}
}
include './accumulator'
fn sum_small{T if has_simd}(xv:*void, ia:usz) : i64 = {
def A = if (width{T}<=16) i32 else i64
xv:= *T~~xv
def unr = 4
def bulk = arch_defvw / width{T}
def V = [bulk]T
def acc = sum_accumulator{A, unr, V}
@for_mu{bulk,unr,mu_extra{acc}}(x in tup{V,xv}, M in 'm' over ia) {
acc{'acc', M, x}
}
promote{i64, acc{'vec_result', ia}}
}
each{{w} => export{merge{'sum_small_i', fmtnat{w}}, sum_small{primtype{'i', w}}}, tup{8, 16, 32}}