From e2b4dbcf205dff520ca0ad8d2126f8dc8781d3c4 Mon Sep 17 00:00:00 2001 From: dzaima Date: Fri, 21 Mar 2025 04:05:10 +0200 Subject: [PATCH] use any_hom / all_hom for (any|all)_accumulator --- src/singeli/src/accumulator.singeli | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/singeli/src/accumulator.singeli b/src/singeli/src/accumulator.singeli index c891ac8b..c3b2cdb6 100644 --- a/src/singeli/src/accumulator.singeli +++ b/src/singeli/src/accumulator.singeli @@ -57,9 +57,10 @@ local def mask_ident{A, ident}{M, acc_vec, v:V=[_]_} = { else blend_hom{acc_vec, A{acc_vec, v}, M{V, 'to homogeneous bits'}} } -local def acc_impl{A if kgen{A}, M if kgen{M}, unr if knum{unr}, VT=[k]E, S, init} = { - acc_scal:S = init - def vinit = VT**init +local def acc_impl{A, M, unr, VT, S, init} = acc_impl{A, M, unr, VT, init, S, init} +local def acc_impl{A if kgen{A}, M if kgen{M}, unr if knum{unr}, VT=[k]E, vinit0, S, sinit} = { + acc_scal:S = sinit + def vinit = VT**vinit0 def acc_tup = @collect(unr) { acc_var:=vinit } def acc_vec = select{acc_tup, 0} # shared with acc_tup to allow single-vector accumulates before from_unr @@ -96,8 +97,11 @@ def assoc_accumulator{F if kgen{F}, unr if knum{unr}, VT=[_]E, ident} = { } def bool_accumulator{F, unr, VT=[k]SE, ident if isunsigned{SE}} = { - def acc = assoc_accumulator{F, unr, VT, ident} + def {acc, acc_scal, acc_tup, acc_vec} = acc_impl{mask_ident{F,ident}, F, unr, VT, ident * maxvalue{SE}, u1, ident} def me{...} = acc + def me{'acc', ...M, v} = acc{'acc', ...M, assert_hom{v}} + def me{'to_scal', ..._ if is{F,__or}} = acc_scal|= any_hom{acc_vec} + def me{'to_scal', ..._ if is{F,__and}} = acc_scal&= all_hom{acc_vec} def me{'scal_result'} = cast_i{u1, acc{'scal_result'}} extend finish{me} } @@ -272,4 +276,4 @@ def assoc_accumulator{F==max, unr, T if ktyp{T}} = assoc_accumulator{F, unr, T def assoc_accumulator{F==__add, unr, T if ktyp{T}} = assoc_accumulator{F, unr, T, 0} def any_accumulator{unr, T} = bool_accumulator{|, unr, T, 0} -def all_accumulator{unr, T} = bool_accumulator{&, unr, T, if (scal_bool{T}) 1 else if (isvec{T}) maxvalue{eltype{T}} else assert{0,'bad T for all_accumulator',T} } +def all_accumulator{unr, T} = bool_accumulator{&, unr, T, 1}