move groupstat to fancy accumulator

This commit is contained in:
dzaima 2025-03-21 01:04:34 +02:00
parent 97260e0051
commit 05d87dd7df
2 changed files with 138 additions and 94 deletions

View File

@ -10,6 +10,9 @@
# acc{'vec_result'} - get result of vector accumulates (i.e. 'to_scal' + 'scal_result')
# acc{'vec_result', n} - 'vec_result', passing n to 'to_scal'
# general things
local def extend finish{me} = {
def me{'vec_result', ...rest} = {
me{'to_scal', ...rest}
@ -17,7 +20,6 @@ local def extend finish{me} = {
}
}
local def int_els{DE, SE} = isint{DE} and isint{SE} and quality{DE}==quality{SE} and DE>=SE
def mu_extra{...args} = {
def get_all{prop} = each{{c}=>c{prop}, args}
tup{
@ -27,6 +29,28 @@ def mu_extra{...args} = {
}
}
# scalar accumulators
local def scalar_acc_impl{F, T, ident} = {
acc:T = ident
def me{'acc', v} = {
acc = F{acc, v}
}
def me{'vec_result'} = me{'scal_result'}
def me{'scal_result'} = acc
}
local def scal_bool{T} = is{T,'!'} or isunsigned{T}
local def add_promote{a:A,b} = a + promote{A, b}
def assoc_accumulator{F, '!', T if isprim{T}, ident} = scalar_acc_impl{F, T, ident}
def count_accumulator{DE, '!', T if scal_bool{T} and isunsigned{DE}} = scalar_acc_impl{add_promote, DE, 0}
def bool_accumulator {F, '!', T if scal_bool{T}, ident} = scalar_acc_impl{F, u1, ident}
def sum_accumulator {DE, '!', T if isprim{T}} = scalar_acc_impl{add_promote, DE, 0}
# vector accumulators
local def mask_ident{A, ident}{M, acc_vec, v:V=[_]_} = {
if (M{0}==0) A{acc_vec, v}
else if (is{ident,0}) A{acc_vec, M{v}}
@ -57,7 +81,7 @@ local def acc_impl{A if kgen{A}, M if kgen{M}, unr if knum{unr}, VT=[k]E, S, ini
def me{'from_unr'} = {
acc_vec = tree_fold{M, acc_tup}
}
def me{'to_scal', ...rest} = {
def me{'to_scal', ..._} = {
acc_scal = vfold{M, acc_vec}
}
def me{'flush_min'} = 1/0
@ -71,12 +95,16 @@ def assoc_accumulator{F if kgen{F}, unr if knum{unr}, VT=[_]E, ident} = {
me
}
def assoc_accumulator{F==min, unr, VT=[_]E} = assoc_accumulator{F, unr, VT, if (isfloat{E}) E~~1/0 else maxvalue{E}}
def assoc_accumulator{F==max, unr, VT=[_]E} = assoc_accumulator{F, unr, VT, if (isfloat{E}) -E~~1/0 else minvalue{E}}
def assoc_accumulator{F==__add, unr, VT=[_]E} = assoc_accumulator{F, unr, VT, 0}
def bool_accumulator{F, unr, VT=[k]SE, ident if isunsigned{SE}} = {
def acc = assoc_accumulator{F, unr, VT, ident}
def me{...} = acc
def me{'scal_result'} = cast_i{u1, acc{'scal_result'}}
extend finish{me}
}
# aarch64-specific accumulators
def sum_accumulator{DE, unr, VT=[_]SE if int_els{DE,SE} and DE>SE and hasarch{'AARCH64'}} = {
def A = if (width{DE}>width{ux}) DE else primtype{quality{DE}, width{ux}}
def VM = el_m{VT}
@ -90,12 +118,8 @@ def sum_accumulator{DE, unr, VT=[_]SE if int_els{DE,SE} and DE>SE and hasarch{'A
def ps = each{if (exact) vfold{+,.} else fold_addw, vs}
acc_scal+= tree_fold{+, each{promote{A,.}, ps}}
}
def me{'from_unr'} = {
into_scal{slice{acc_tup, 1}}
}
def me{'to_scal', ...rest} = {
into_scal{tup{acc_vec}}
}
def me{'from_unr'} = into_scal{slice{acc_tup, 1}}
def me{'to_scal', ..._} = into_scal{tup{acc_vec}}
def me{'flush_min'} = if (exact) {
1/0
} else {
@ -112,7 +136,7 @@ def sum_accumulator{DE, unr, VT=[_]SE if int_els{DE,SE} and DE>SE and hasarch{'A
# TODO: AVX-512 could use dpbusd/dpwssd
# x86_64-specific accumulators; TODO: AVX-512 could use dpbusd/dpwssd
def sum_accumulator{DE, unr, VT=[k]SE if int_els{DE,SE} and width{SE}==8 and hasarch{'X86_64'}} = {
def VM = [k/8]u64
def VU = [k]u8
@ -167,7 +191,7 @@ def sum_accumulator{DE==i64, unr, VT=[k]SE==i32 if hasarch{'X86_64'}} = { # TODO
def hd = promote{DE, hs}<<16
acc+= ld + hd
}
def me{'to_scal', ...rest} = {
def me{'to_scal', ..._} = {
into_scal{}
}
def me{'flush_min'} = 65536/(unr*k)
@ -183,9 +207,7 @@ def sum_accumulator{DE==i64, unr, VT=[k]SE==i32 if hasarch{'X86_64'}} = { # TODO
}
def assoc_accumulator{F if is{F,min} or is{F,max}, unr if knum{unr}, VT=([16]i8), ident if hasarch{'X86_64'} and not hasarch{'SSE4.1'}} = {
def adda{M, a, c:([16]i8)} = {
mask_ident{F,ident}{M, a, ty_u{c} ^ [16]u8**128}
}
def adda{M, a, c:([16]i8)} = mask_ident{F,ident}{M, a, ty_u{c} ^ [16]u8**128}
def {acc, ..._} = acc_impl{adda, F, unr, [16]u8, u8, (ident%256) ^ 128}
def me{...} = acc
def me{'scal_result'} = i8~~(acc{'scal_result'} ^ 128)
@ -193,3 +215,61 @@ def assoc_accumulator{F if is{F,min} or is{F,max}, unr if knum{unr}, VT=([16]i8)
}
def sum_accumulator{E, unr, VT=[_]E} = assoc_accumulator{__add, unr, VT, 0}
def count_accumulator{DE, unr, VT=[k]SE if isunsigned{DE} and int_els{DE,SE}} = {
def exact = DE==SE
def widen_sum = SE==u8 and DE>u8
def {acc, acc_scal, acc_tup, acc_vec} = acc_impl{mask_ident{-,0}, +, unr, VT, DE, 0}
def into_scal{vs} = {
def curr = if (widen_sum) {
def op{v if hasarch{'X86_64'}} = absdiff_sum{8, v, VT**0}
def op{v if hasarch{'AARCH64'}} = {
assert{unr*k <= 256}
addpw{v}
}
cast_i{DE, vfold{+, tree_fold{+, each{op, vs}}}}
} else {
promote{DE, tree_fold{+, each{vfold{+,.}, vs}}}
}
acc_scal+= curr
}
def me{...} = acc
def me{'acc', ...M, v} = acc{'acc', ...M, assert_hom{v}}
def me{'from_unr'} = into_scal{slice{acc_tup, 1}}
def me{'to_scal', ..._} = into_scal{tup{acc_vec}}
def me{'flush_min'} = {
if (exact) 1/0
else if (widen_sum) maxvalue{SE}
else __floor{maxvalue{SE}/(unr*k)}
}
def me{'flush'} = if (not exact) {
into_scal{acc_tup}
acc_tup = VT**0
}
extend finish{me}
}
# def count_accumulator{DE, unr, VT=[_]SE==u8 if int_els{DE,SE} and DE>SE} = { # takes 2 instrs in core loop on x86; and still needs flushing on NEON
# def acc = sum_accumulator{u64, unr, VT}
# def me{...} = acc
# extend perv1{__neg}
# def me{'acc', ...M, v} = acc{'acc', ...M, assert_hom{v}}
# def me{'scal_result'} = cast_i{DE, acc{'scal_result'} * 0xfefefefefefefeff}
# extend finish{me}
# }
# implicit identity values
local def of_e{[_]E, G} = G{E}
local def of_e{E if isprim{E}, G} = G{E}
def assoc_accumulator{F==min, unr, T if ktyp{T}} = assoc_accumulator{F, unr, T, of_e{T, {E} => if (isfloat{E}) E~~1/0 else maxvalue{E}}}
def assoc_accumulator{F==max, unr, T if ktyp{T}} = assoc_accumulator{F, unr, T, of_e{T, {E} => if (isfloat{E}) -E~~1/0 else minvalue{E}}}
def assoc_accumulator{F==__add, unr, T if ktyp{T}} = assoc_accumulator{F, unr, T, 0}
def any_accumulator{unr, T} = bool_accumulator{|, unr, T, 0}
def all_accumulator{unr, T} = bool_accumulator{&, unr, T, if (scal_bool{T}) 1 else if (isvec{T}) maxvalue{eltype{T}} else assert{0,'bad T for all_accumulator',T} }

View File

@ -1,9 +1,7 @@
include './base'
include './vecfold'
if_inline (hasarch{'SSE2'}) {
def fold_addw{v:V} = vfold{+, fold{+, mzip128{v, V**0}}}
}
include './mask'
include './accumulator'
def vec_merge_shift_right{a:V=[n]_, b:V, s if hasarch{'SSE2'} and not hasarch{'SSSE3'}} = {
vec_shift_left{a, n-s} | vec_shift_right{b, s}
@ -14,88 +12,54 @@ def vec_merge_shift_right{a:V, b:V, 1 if width{V}>128} = {
vec_merge_shift_right_128{p, b, 1}
}
def __add{a:(usz), b:(u1)} = a + promote{usz,b}
def __lt{a:V=[_]_, b if knum{b}} = a < V**b
def __eq{a:V=[_]_, b if knum{b}} = a == V**b
def group_statistics{T} = {
def store{p:(*u8), 0, b:(u1)} = store{p, 0, promote{u8, b}}
def widen_sum = width{T} <= 8
def sum_vec = if (widen_sum) fold_addw else vfold{+, .}
def var{op, get} = {
# Identity, type
def id = match (op) { {(max)} => -1; {(&)} => 1; {_} => 0 }
def S = match (op) { {(max)} => T; {(+)} => usz; {_} => u1 }
# Scalar accumulator
def updater{v,op}{...a} = { v = op{v, get{...a}} }
def scal{val} = {
v:S = val
tup{v, updater{v, op}}
}
def scal{} = scal{id}
# Vector accumulator
def vec{l} = {
def V = match (S) { {(T)} => [l]T; {_} => [l]ty_u{T} }
v := V**(if (id==1) maxvalue{ty_u{T}} else id)
def u = updater{v, if (same{op,+}) (-) else op}
def {flush, get} = if (S!=usz) {
def get = match (op) {
{(&)} => all_hom
{(|)} => any_hom
{(max)} => vfold{max, .}
}
tup{{}=>{}, {} => get{v}}
} else {
f:usz = 0
def flush{} = { f += cast_i{usz, sum_vec{v}}; v = V**0 }
tup{flush, {} => f}
def usz_accumulator = count_accumulator{usz, ...}
def max_accumulator = assoc_accumulator{max, ..., -1}
def {types, acc_gen, ops} = each{tup,
tup{u8, any_accumulator, {_,w} => w < -1}, # bad
tup{usz, usz_accumulator, {_,w} => w == -1}, # neg
tup{u8, all_accumulator, {p,w} => p <= w }, # sort
tup{usz, usz_accumulator, {p,w} => p != w }, # change
tup{T, max_accumulator, {_,w} => w } # max
}
fn group_statistics{T}(w:*void, xn:usz, outs:each{__pnt,types}) : void = {
def w = *T~~w
def accs = if (has_simd) {
def bulk = arch_defvw/width{T}
def V = [bulk]T
def VU = ty_u{V}
def unr = 2
def accs = each{{a,T} => a{unr, if (quality{T}=='u') VU else V}, acc_gen, types}
prev_v:V = V ** -1
@for_mu{bulk, unr, mu_extra{...accs}}(curr_vs in tup{V,w}, M in 'm' over xn) {
def prev_vs = shiftright{tup{prev_v}, curr_vs}
def prev_es = each{vec_merge_shift_right{..., 1}, prev_vs, curr_vs}
each{{a, F} => {
a{'acc', M, each{F, prev_es, curr_vs}}
}, accs, ops}
prev_v = select{curr_vs,-1}
}
tup{u, flush, get}
}
tup{if (S==u1) u8 else S, scal, vec}
}
def {types, init_scal, init_vec} = each{tup,
var{|, {_,w} => w < -1}, # bad
var{+, {_,w} => w == -1}, # neg
var{&, {p,w} => p <= w }, # sort
var{+, {p,w} => p != w }, # change
var{max, {_,w} => w } # max
}
def run{g, ...par} = g{...par}
def runvars{gens, ...par} = each{run{., ...par}, gens}
fn group_statistics(w0:*void, xn:usz, outs:each{__pnt,types}) : void = {
def {start, init} = if (not has_simd) tup{0, tup{}} else {
def vl = arch_defvw/width{T}; def V = [vl]T
def {accum, flush, get} = flip{runvars{init_vec, vl}}
e:= xn / vl
i:usz = 0
prev:V = V**(-1)
while (i < e) {
def lmax = 1 << (width{T}-1 - (not widen_sum)*lb{vl})
l:= min{usz~~lmax, e-i}
@for (w in *V~~w0 + i over l) {
runvars{accum, vec_merge_shift_right{prev, w, 1}, w}
prev = w
}
i+= l
runvars{flush}
accs
} else {
p:T = -1
def accs = each{{a,T} => a{'!', T}, acc_gen, types}
@for (c in w over xn) {
each{{a, F} => a{'acc', F{p, c}}, accs, ops}
p = c
}
tup{e*vl, tup{runvars{get}}}
accs
}
def {vals, accum} = flip{each{run, init_scal, ...init}}
prev:T = -1
if (start > 0) prev = load{*T~~w0, start-1}
@for (w in *T~~w0 over _ from start to xn) {
runvars{accum, prev,w}
prev = w
}
each{store{.,0,.}, outs, vals}
def results = each{{a} => a{'vec_result'}, accs}
each{{out:*T, r} => store{out, 0, promote{T,r}}, outs, results}
}
group_statistics{T}
}
export{'si_group_statistics_i8', group_statistics{i8}}