use any_hom / all_hom for (any|all)_accumulator

This commit is contained in:
dzaima 2025-03-21 04:05:10 +02:00
parent 05d87dd7df
commit e2b4dbcf20

View File

@ -57,9 +57,10 @@ local def mask_ident{A, ident}{M, acc_vec, v:V=[_]_} = {
else blend_hom{acc_vec, A{acc_vec, v}, M{V, 'to homogeneous bits'}}
}
local def acc_impl{A if kgen{A}, M if kgen{M}, unr if knum{unr}, VT=[k]E, S, init} = {
acc_scal:S = init
def vinit = VT**init
local def acc_impl{A, M, unr, VT, S, init} = acc_impl{A, M, unr, VT, init, S, init}
local def acc_impl{A if kgen{A}, M if kgen{M}, unr if knum{unr}, VT=[k]E, vinit0, S, sinit} = {
acc_scal:S = sinit
def vinit = VT**vinit0
def acc_tup = @collect(unr) { acc_var:=vinit }
def acc_vec = select{acc_tup, 0} # shared with acc_tup to allow single-vector accumulates before from_unr
@ -96,8 +97,11 @@ def assoc_accumulator{F if kgen{F}, unr if knum{unr}, VT=[_]E, ident} = {
}
def bool_accumulator{F, unr, VT=[k]SE, ident if isunsigned{SE}} = {
def acc = assoc_accumulator{F, unr, VT, ident}
def {acc, acc_scal, acc_tup, acc_vec} = acc_impl{mask_ident{F,ident}, F, unr, VT, ident * maxvalue{SE}, u1, ident}
def me{...} = acc
def me{'acc', ...M, v} = acc{'acc', ...M, assert_hom{v}}
def me{'to_scal', ..._ if is{F,__or}} = acc_scal|= any_hom{acc_vec}
def me{'to_scal', ..._ if is{F,__and}} = acc_scal&= all_hom{acc_vec}
def me{'scal_result'} = cast_i{u1, acc{'scal_result'}}
extend finish{me}
}
@ -272,4 +276,4 @@ def assoc_accumulator{F==max, unr, T if ktyp{T}} = assoc_accumulator{F, unr, T
def assoc_accumulator{F==__add, unr, T if ktyp{T}} = assoc_accumulator{F, unr, T, 0}
def any_accumulator{unr, T} = bool_accumulator{|, unr, T, 0}
def all_accumulator{unr, T} = bool_accumulator{&, unr, T, if (scal_bool{T}) 1 else if (isvec{T}) maxvalue{eltype{T}} else assert{0,'bad T for all_accumulator',T} }
def all_accumulator{unr, T} = bool_accumulator{&, unr, T, 1}