move groupstat to fancy accumulator

This commit is contained in:
dzaima 2025-03-21 01:04:34 +02:00
parent 97260e0051
commit 05d87dd7df
2 changed files with 138 additions and 94 deletions

View File

@ -10,6 +10,9 @@
# acc{'vec_result'} - get result of vector accumulates (i.e. 'to_scal' + 'scal_result') # acc{'vec_result'} - get result of vector accumulates (i.e. 'to_scal' + 'scal_result')
# acc{'vec_result', n} - 'vec_result', passing n to 'to_scal' # acc{'vec_result', n} - 'vec_result', passing n to 'to_scal'
# general things
local def extend finish{me} = { local def extend finish{me} = {
def me{'vec_result', ...rest} = { def me{'vec_result', ...rest} = {
me{'to_scal', ...rest} me{'to_scal', ...rest}
@ -17,7 +20,6 @@ local def extend finish{me} = {
} }
} }
local def int_els{DE, SE} = isint{DE} and isint{SE} and quality{DE}==quality{SE} and DE>=SE local def int_els{DE, SE} = isint{DE} and isint{SE} and quality{DE}==quality{SE} and DE>=SE
def mu_extra{...args} = { def mu_extra{...args} = {
def get_all{prop} = each{{c}=>c{prop}, args} def get_all{prop} = each{{c}=>c{prop}, args}
tup{ tup{
@ -27,6 +29,28 @@ def mu_extra{...args} = {
} }
} }
# scalar accumulators
local def scalar_acc_impl{F, T, ident} = {
acc:T = ident
def me{'acc', v} = {
acc = F{acc, v}
}
def me{'vec_result'} = me{'scal_result'}
def me{'scal_result'} = acc
}
local def scal_bool{T} = is{T,'!'} or isunsigned{T}
local def add_promote{a:A,b} = a + promote{A, b}
def assoc_accumulator{F, '!', T if isprim{T}, ident} = scalar_acc_impl{F, T, ident}
def count_accumulator{DE, '!', T if scal_bool{T} and isunsigned{DE}} = scalar_acc_impl{add_promote, DE, 0}
def bool_accumulator {F, '!', T if scal_bool{T}, ident} = scalar_acc_impl{F, u1, ident}
def sum_accumulator {DE, '!', T if isprim{T}} = scalar_acc_impl{add_promote, DE, 0}
# vector accumulators
local def mask_ident{A, ident}{M, acc_vec, v:V=[_]_} = { local def mask_ident{A, ident}{M, acc_vec, v:V=[_]_} = {
if (M{0}==0) A{acc_vec, v} if (M{0}==0) A{acc_vec, v}
else if (is{ident,0}) A{acc_vec, M{v}} else if (is{ident,0}) A{acc_vec, M{v}}
@ -57,7 +81,7 @@ local def acc_impl{A if kgen{A}, M if kgen{M}, unr if knum{unr}, VT=[k]E, S, ini
def me{'from_unr'} = { def me{'from_unr'} = {
acc_vec = tree_fold{M, acc_tup} acc_vec = tree_fold{M, acc_tup}
} }
def me{'to_scal', ...rest} = { def me{'to_scal', ..._} = {
acc_scal = vfold{M, acc_vec} acc_scal = vfold{M, acc_vec}
} }
def me{'flush_min'} = 1/0 def me{'flush_min'} = 1/0
@ -71,12 +95,16 @@ def assoc_accumulator{F if kgen{F}, unr if knum{unr}, VT=[_]E, ident} = {
me me
} }
def assoc_accumulator{F==min, unr, VT=[_]E} = assoc_accumulator{F, unr, VT, if (isfloat{E}) E~~1/0 else maxvalue{E}} def bool_accumulator{F, unr, VT=[k]SE, ident if isunsigned{SE}} = {
def assoc_accumulator{F==max, unr, VT=[_]E} = assoc_accumulator{F, unr, VT, if (isfloat{E}) -E~~1/0 else minvalue{E}} def acc = assoc_accumulator{F, unr, VT, ident}
def assoc_accumulator{F==__add, unr, VT=[_]E} = assoc_accumulator{F, unr, VT, 0} def me{...} = acc
def me{'scal_result'} = cast_i{u1, acc{'scal_result'}}
extend finish{me}
}
# aarch64-specific accumulators
def sum_accumulator{DE, unr, VT=[_]SE if int_els{DE,SE} and DE>SE and hasarch{'AARCH64'}} = { def sum_accumulator{DE, unr, VT=[_]SE if int_els{DE,SE} and DE>SE and hasarch{'AARCH64'}} = {
def A = if (width{DE}>width{ux}) DE else primtype{quality{DE}, width{ux}} def A = if (width{DE}>width{ux}) DE else primtype{quality{DE}, width{ux}}
def VM = el_m{VT} def VM = el_m{VT}
@ -90,12 +118,8 @@ def sum_accumulator{DE, unr, VT=[_]SE if int_els{DE,SE} and DE>SE and hasarch{'A
def ps = each{if (exact) vfold{+,.} else fold_addw, vs} def ps = each{if (exact) vfold{+,.} else fold_addw, vs}
acc_scal+= tree_fold{+, each{promote{A,.}, ps}} acc_scal+= tree_fold{+, each{promote{A,.}, ps}}
} }
def me{'from_unr'} = { def me{'from_unr'} = into_scal{slice{acc_tup, 1}}
into_scal{slice{acc_tup, 1}} def me{'to_scal', ..._} = into_scal{tup{acc_vec}}
}
def me{'to_scal', ...rest} = {
into_scal{tup{acc_vec}}
}
def me{'flush_min'} = if (exact) { def me{'flush_min'} = if (exact) {
1/0 1/0
} else { } else {
@ -112,7 +136,7 @@ def sum_accumulator{DE, unr, VT=[_]SE if int_els{DE,SE} and DE>SE and hasarch{'A
# TODO: AVX-512 could use dpbusd/dpwssd # x86_64-specific accumulators; TODO: AVX-512 could use dpbusd/dpwssd
def sum_accumulator{DE, unr, VT=[k]SE if int_els{DE,SE} and width{SE}==8 and hasarch{'X86_64'}} = { def sum_accumulator{DE, unr, VT=[k]SE if int_els{DE,SE} and width{SE}==8 and hasarch{'X86_64'}} = {
def VM = [k/8]u64 def VM = [k/8]u64
def VU = [k]u8 def VU = [k]u8
@ -167,7 +191,7 @@ def sum_accumulator{DE==i64, unr, VT=[k]SE==i32 if hasarch{'X86_64'}} = { # TODO
def hd = promote{DE, hs}<<16 def hd = promote{DE, hs}<<16
acc+= ld + hd acc+= ld + hd
} }
def me{'to_scal', ...rest} = { def me{'to_scal', ..._} = {
into_scal{} into_scal{}
} }
def me{'flush_min'} = 65536/(unr*k) def me{'flush_min'} = 65536/(unr*k)
@ -183,9 +207,7 @@ def sum_accumulator{DE==i64, unr, VT=[k]SE==i32 if hasarch{'X86_64'}} = { # TODO
} }
def assoc_accumulator{F if is{F,min} or is{F,max}, unr if knum{unr}, VT=([16]i8), ident if hasarch{'X86_64'} and not hasarch{'SSE4.1'}} = { def assoc_accumulator{F if is{F,min} or is{F,max}, unr if knum{unr}, VT=([16]i8), ident if hasarch{'X86_64'} and not hasarch{'SSE4.1'}} = {
def adda{M, a, c:([16]i8)} = { def adda{M, a, c:([16]i8)} = mask_ident{F,ident}{M, a, ty_u{c} ^ [16]u8**128}
mask_ident{F,ident}{M, a, ty_u{c} ^ [16]u8**128}
}
def {acc, ..._} = acc_impl{adda, F, unr, [16]u8, u8, (ident%256) ^ 128} def {acc, ..._} = acc_impl{adda, F, unr, [16]u8, u8, (ident%256) ^ 128}
def me{...} = acc def me{...} = acc
def me{'scal_result'} = i8~~(acc{'scal_result'} ^ 128) def me{'scal_result'} = i8~~(acc{'scal_result'} ^ 128)
@ -193,3 +215,61 @@ def assoc_accumulator{F if is{F,min} or is{F,max}, unr if knum{unr}, VT=([16]i8)
} }
def sum_accumulator{E, unr, VT=[_]E} = assoc_accumulator{__add, unr, VT, 0} def sum_accumulator{E, unr, VT=[_]E} = assoc_accumulator{__add, unr, VT, 0}
def count_accumulator{DE, unr, VT=[k]SE if isunsigned{DE} and int_els{DE,SE}} = {
def exact = DE==SE
def widen_sum = SE==u8 and DE>u8
def {acc, acc_scal, acc_tup, acc_vec} = acc_impl{mask_ident{-,0}, +, unr, VT, DE, 0}
def into_scal{vs} = {
def curr = if (widen_sum) {
def op{v if hasarch{'X86_64'}} = absdiff_sum{8, v, VT**0}
def op{v if hasarch{'AARCH64'}} = {
assert{unr*k <= 256}
addpw{v}
}
cast_i{DE, vfold{+, tree_fold{+, each{op, vs}}}}
} else {
promote{DE, tree_fold{+, each{vfold{+,.}, vs}}}
}
acc_scal+= curr
}
def me{...} = acc
def me{'acc', ...M, v} = acc{'acc', ...M, assert_hom{v}}
def me{'from_unr'} = into_scal{slice{acc_tup, 1}}
def me{'to_scal', ..._} = into_scal{tup{acc_vec}}
def me{'flush_min'} = {
if (exact) 1/0
else if (widen_sum) maxvalue{SE}
else __floor{maxvalue{SE}/(unr*k)}
}
def me{'flush'} = if (not exact) {
into_scal{acc_tup}
acc_tup = VT**0
}
extend finish{me}
}
# def count_accumulator{DE, unr, VT=[_]SE==u8 if int_els{DE,SE} and DE>SE} = { # takes 2 instrs in core loop on x86; and still needs flushing on NEON
# def acc = sum_accumulator{u64, unr, VT}
# def me{...} = acc
# extend perv1{__neg}
# def me{'acc', ...M, v} = acc{'acc', ...M, assert_hom{v}}
# def me{'scal_result'} = cast_i{DE, acc{'scal_result'} * 0xfefefefefefefeff}
# extend finish{me}
# }
# implicit identity values
local def of_e{[_]E, G} = G{E}
local def of_e{E if isprim{E}, G} = G{E}
def assoc_accumulator{F==min, unr, T if ktyp{T}} = assoc_accumulator{F, unr, T, of_e{T, {E} => if (isfloat{E}) E~~1/0 else maxvalue{E}}}
def assoc_accumulator{F==max, unr, T if ktyp{T}} = assoc_accumulator{F, unr, T, of_e{T, {E} => if (isfloat{E}) -E~~1/0 else minvalue{E}}}
def assoc_accumulator{F==__add, unr, T if ktyp{T}} = assoc_accumulator{F, unr, T, 0}
def any_accumulator{unr, T} = bool_accumulator{|, unr, T, 0}
def all_accumulator{unr, T} = bool_accumulator{&, unr, T, if (scal_bool{T}) 1 else if (isvec{T}) maxvalue{eltype{T}} else assert{0,'bad T for all_accumulator',T} }

View File

@ -1,9 +1,7 @@
include './base' include './base'
include './vecfold' include './vecfold'
include './mask'
if_inline (hasarch{'SSE2'}) { include './accumulator'
def fold_addw{v:V} = vfold{+, fold{+, mzip128{v, V**0}}}
}
def vec_merge_shift_right{a:V=[n]_, b:V, s if hasarch{'SSE2'} and not hasarch{'SSSE3'}} = { def vec_merge_shift_right{a:V=[n]_, b:V, s if hasarch{'SSE2'} and not hasarch{'SSSE3'}} = {
vec_shift_left{a, n-s} | vec_shift_right{b, s} vec_shift_left{a, n-s} | vec_shift_right{b, s}
@ -14,88 +12,54 @@ def vec_merge_shift_right{a:V, b:V, 1 if width{V}>128} = {
vec_merge_shift_right_128{p, b, 1} vec_merge_shift_right_128{p, b, 1}
} }
def __add{a:(usz), b:(u1)} = a + promote{usz,b}
def __lt{a:V=[_]_, b if knum{b}} = a < V**b def __lt{a:V=[_]_, b if knum{b}} = a < V**b
def __eq{a:V=[_]_, b if knum{b}} = a == V**b def __eq{a:V=[_]_, b if knum{b}} = a == V**b
def group_statistics{T} = { def group_statistics{T} = {
def store{p:(*u8), 0, b:(u1)} = store{p, 0, promote{u8, b}} def usz_accumulator = count_accumulator{usz, ...}
def max_accumulator = assoc_accumulator{max, ..., -1}
def widen_sum = width{T} <= 8 def {types, acc_gen, ops} = each{tup,
def sum_vec = if (widen_sum) fold_addw else vfold{+, .} tup{u8, any_accumulator, {_,w} => w < -1}, # bad
tup{usz, usz_accumulator, {_,w} => w == -1}, # neg
tup{u8, all_accumulator, {p,w} => p <= w }, # sort
tup{usz, usz_accumulator, {p,w} => p != w }, # change
tup{T, max_accumulator, {_,w} => w } # max
}
def var{op, get} = { fn group_statistics{T}(w:*void, xn:usz, outs:each{__pnt,types}) : void = {
# Identity, type def w = *T~~w
def id = match (op) { {(max)} => -1; {(&)} => 1; {_} => 0 } def accs = if (has_simd) {
def S = match (op) { {(max)} => T; {(+)} => usz; {_} => u1 } def bulk = arch_defvw/width{T}
def V = [bulk]T
def VU = ty_u{V}
# Scalar accumulator def unr = 2
def updater{v,op}{...a} = { v = op{v, get{...a}} } def accs = each{{a,T} => a{unr, if (quality{T}=='u') VU else V}, acc_gen, types}
def scal{val} = {
v:S = val
tup{v, updater{v, op}}
}
def scal{} = scal{id}
# Vector accumulator prev_v:V = V ** -1
def vec{l} = { @for_mu{bulk, unr, mu_extra{...accs}}(curr_vs in tup{V,w}, M in 'm' over xn) {
def V = match (S) { {(T)} => [l]T; {_} => [l]ty_u{T} } def prev_vs = shiftright{tup{prev_v}, curr_vs}
v := V**(if (id==1) maxvalue{ty_u{T}} else id) def prev_es = each{vec_merge_shift_right{..., 1}, prev_vs, curr_vs}
def u = updater{v, if (same{op,+}) (-) else op} each{{a, F} => {
def {flush, get} = if (S!=usz) { a{'acc', M, each{F, prev_es, curr_vs}}
def get = match (op) { }, accs, ops}
{(&)} => all_hom prev_v = select{curr_vs,-1}
{(|)} => any_hom
{(max)} => vfold{max, .}
}
tup{{}=>{}, {} => get{v}}
} else {
f:usz = 0
def flush{} = { f += cast_i{usz, sum_vec{v}}; v = V**0 }
tup{flush, {} => f}
} }
tup{u, flush, get} accs
} } else {
tup{if (S==u1) u8 else S, scal, vec} p:T = -1
} def accs = each{{a,T} => a{'!', T}, acc_gen, types}
def {types, init_scal, init_vec} = each{tup, @for (c in w over xn) {
var{|, {_,w} => w < -1}, # bad each{{a, F} => a{'acc', F{p, c}}, accs, ops}
var{+, {_,w} => w == -1}, # neg p = c
var{&, {p,w} => p <= w }, # sort
var{+, {p,w} => p != w }, # change
var{max, {_,w} => w } # max
}
def run{g, ...par} = g{...par}
def runvars{gens, ...par} = each{run{., ...par}, gens}
fn group_statistics(w0:*void, xn:usz, outs:each{__pnt,types}) : void = {
def {start, init} = if (not has_simd) tup{0, tup{}} else {
def vl = arch_defvw/width{T}; def V = [vl]T
def {accum, flush, get} = flip{runvars{init_vec, vl}}
e:= xn / vl
i:usz = 0
prev:V = V**(-1)
while (i < e) {
def lmax = 1 << (width{T}-1 - (not widen_sum)*lb{vl})
l:= min{usz~~lmax, e-i}
@for (w in *V~~w0 + i over l) {
runvars{accum, vec_merge_shift_right{prev, w, 1}, w}
prev = w
}
i+= l
runvars{flush}
} }
tup{e*vl, tup{runvars{get}}} accs
} }
def {vals, accum} = flip{each{run, init_scal, ...init}} def results = each{{a} => a{'vec_result'}, accs}
prev:T = -1 each{{out:*T, r} => store{out, 0, promote{T,r}}, outs, results}
if (start > 0) prev = load{*T~~w0, start-1}
@for (w in *T~~w0 over _ from start to xn) {
runvars{accum, prev,w}
prev = w
}
each{store{.,0,.}, outs, vals}
} }
group_statistics{T}
} }
export{'si_group_statistics_i8', group_statistics{i8}} export{'si_group_statistics_i8', group_statistics{i8}}