move groupstat to fancy accumulator
This commit is contained in:
parent
97260e0051
commit
05d87dd7df
@ -10,6 +10,9 @@
|
||||
# acc{'vec_result'} - get result of vector accumulates (i.e. 'to_scal' + 'scal_result')
|
||||
# acc{'vec_result', n} - 'vec_result', passing n to 'to_scal'
|
||||
|
||||
|
||||
|
||||
# general things
|
||||
local def extend finish{me} = {
|
||||
def me{'vec_result', ...rest} = {
|
||||
me{'to_scal', ...rest}
|
||||
@ -17,7 +20,6 @@ local def extend finish{me} = {
|
||||
}
|
||||
}
|
||||
local def int_els{DE, SE} = isint{DE} and isint{SE} and quality{DE}==quality{SE} and DE>=SE
|
||||
|
||||
def mu_extra{...args} = {
|
||||
def get_all{prop} = each{{c}=>c{prop}, args}
|
||||
tup{
|
||||
@ -27,6 +29,28 @@ def mu_extra{...args} = {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
# scalar accumulators
|
||||
local def scalar_acc_impl{F, T, ident} = {
|
||||
acc:T = ident
|
||||
def me{'acc', v} = {
|
||||
acc = F{acc, v}
|
||||
}
|
||||
def me{'vec_result'} = me{'scal_result'}
|
||||
def me{'scal_result'} = acc
|
||||
}
|
||||
local def scal_bool{T} = is{T,'!'} or isunsigned{T}
|
||||
|
||||
local def add_promote{a:A,b} = a + promote{A, b}
|
||||
def assoc_accumulator{F, '!', T if isprim{T}, ident} = scalar_acc_impl{F, T, ident}
|
||||
def count_accumulator{DE, '!', T if scal_bool{T} and isunsigned{DE}} = scalar_acc_impl{add_promote, DE, 0}
|
||||
def bool_accumulator {F, '!', T if scal_bool{T}, ident} = scalar_acc_impl{F, u1, ident}
|
||||
def sum_accumulator {DE, '!', T if isprim{T}} = scalar_acc_impl{add_promote, DE, 0}
|
||||
|
||||
|
||||
|
||||
# vector accumulators
|
||||
local def mask_ident{A, ident}{M, acc_vec, v:V=[_]_} = {
|
||||
if (M{0}==0) A{acc_vec, v}
|
||||
else if (is{ident,0}) A{acc_vec, M{v}}
|
||||
@ -57,7 +81,7 @@ local def acc_impl{A if kgen{A}, M if kgen{M}, unr if knum{unr}, VT=[k]E, S, ini
|
||||
def me{'from_unr'} = {
|
||||
acc_vec = tree_fold{M, acc_tup}
|
||||
}
|
||||
def me{'to_scal', ...rest} = {
|
||||
def me{'to_scal', ..._} = {
|
||||
acc_scal = vfold{M, acc_vec}
|
||||
}
|
||||
def me{'flush_min'} = 1/0
|
||||
@ -71,12 +95,16 @@ def assoc_accumulator{F if kgen{F}, unr if knum{unr}, VT=[_]E, ident} = {
|
||||
me
|
||||
}
|
||||
|
||||
def assoc_accumulator{F==min, unr, VT=[_]E} = assoc_accumulator{F, unr, VT, if (isfloat{E}) E~~1/0 else maxvalue{E}}
|
||||
def assoc_accumulator{F==max, unr, VT=[_]E} = assoc_accumulator{F, unr, VT, if (isfloat{E}) -E~~1/0 else minvalue{E}}
|
||||
def assoc_accumulator{F==__add, unr, VT=[_]E} = assoc_accumulator{F, unr, VT, 0}
|
||||
def bool_accumulator{F, unr, VT=[k]SE, ident if isunsigned{SE}} = {
|
||||
def acc = assoc_accumulator{F, unr, VT, ident}
|
||||
def me{...} = acc
|
||||
def me{'scal_result'} = cast_i{u1, acc{'scal_result'}}
|
||||
extend finish{me}
|
||||
}
|
||||
|
||||
|
||||
|
||||
# aarch64-specific accumulators
|
||||
def sum_accumulator{DE, unr, VT=[_]SE if int_els{DE,SE} and DE>SE and hasarch{'AARCH64'}} = {
|
||||
def A = if (width{DE}>width{ux}) DE else primtype{quality{DE}, width{ux}}
|
||||
def VM = el_m{VT}
|
||||
@ -90,12 +118,8 @@ def sum_accumulator{DE, unr, VT=[_]SE if int_els{DE,SE} and DE>SE and hasarch{'A
|
||||
def ps = each{if (exact) vfold{+,.} else fold_addw, vs}
|
||||
acc_scal+= tree_fold{+, each{promote{A,.}, ps}}
|
||||
}
|
||||
def me{'from_unr'} = {
|
||||
into_scal{slice{acc_tup, 1}}
|
||||
}
|
||||
def me{'to_scal', ...rest} = {
|
||||
into_scal{tup{acc_vec}}
|
||||
}
|
||||
def me{'from_unr'} = into_scal{slice{acc_tup, 1}}
|
||||
def me{'to_scal', ..._} = into_scal{tup{acc_vec}}
|
||||
def me{'flush_min'} = if (exact) {
|
||||
1/0
|
||||
} else {
|
||||
@ -112,7 +136,7 @@ def sum_accumulator{DE, unr, VT=[_]SE if int_els{DE,SE} and DE>SE and hasarch{'A
|
||||
|
||||
|
||||
|
||||
# TODO: AVX-512 could use dpbusd/dpwssd
|
||||
# x86_64-specific accumulators; TODO: AVX-512 could use dpbusd/dpwssd
|
||||
def sum_accumulator{DE, unr, VT=[k]SE if int_els{DE,SE} and width{SE}==8 and hasarch{'X86_64'}} = {
|
||||
def VM = [k/8]u64
|
||||
def VU = [k]u8
|
||||
@ -167,7 +191,7 @@ def sum_accumulator{DE==i64, unr, VT=[k]SE==i32 if hasarch{'X86_64'}} = { # TODO
|
||||
def hd = promote{DE, hs}<<16
|
||||
acc+= ld + hd
|
||||
}
|
||||
def me{'to_scal', ...rest} = {
|
||||
def me{'to_scal', ..._} = {
|
||||
into_scal{}
|
||||
}
|
||||
def me{'flush_min'} = 65536/(unr*k)
|
||||
@ -183,9 +207,7 @@ def sum_accumulator{DE==i64, unr, VT=[k]SE==i32 if hasarch{'X86_64'}} = { # TODO
|
||||
}
|
||||
|
||||
def assoc_accumulator{F if is{F,min} or is{F,max}, unr if knum{unr}, VT=([16]i8), ident if hasarch{'X86_64'} and not hasarch{'SSE4.1'}} = {
|
||||
def adda{M, a, c:([16]i8)} = {
|
||||
mask_ident{F,ident}{M, a, ty_u{c} ^ [16]u8**128}
|
||||
}
|
||||
def adda{M, a, c:([16]i8)} = mask_ident{F,ident}{M, a, ty_u{c} ^ [16]u8**128}
|
||||
def {acc, ..._} = acc_impl{adda, F, unr, [16]u8, u8, (ident%256) ^ 128}
|
||||
def me{...} = acc
|
||||
def me{'scal_result'} = i8~~(acc{'scal_result'} ^ 128)
|
||||
@ -193,3 +215,61 @@ def assoc_accumulator{F if is{F,min} or is{F,max}, unr if knum{unr}, VT=([16]i8)
|
||||
}
|
||||
|
||||
def sum_accumulator{E, unr, VT=[_]E} = assoc_accumulator{__add, unr, VT, 0}
|
||||
|
||||
|
||||
|
||||
def count_accumulator{DE, unr, VT=[k]SE if isunsigned{DE} and int_els{DE,SE}} = {
|
||||
def exact = DE==SE
|
||||
def widen_sum = SE==u8 and DE>u8
|
||||
|
||||
def {acc, acc_scal, acc_tup, acc_vec} = acc_impl{mask_ident{-,0}, +, unr, VT, DE, 0}
|
||||
def into_scal{vs} = {
|
||||
def curr = if (widen_sum) {
|
||||
def op{v if hasarch{'X86_64'}} = absdiff_sum{8, v, VT**0}
|
||||
def op{v if hasarch{'AARCH64'}} = {
|
||||
assert{unr*k <= 256}
|
||||
addpw{v}
|
||||
}
|
||||
cast_i{DE, vfold{+, tree_fold{+, each{op, vs}}}}
|
||||
} else {
|
||||
promote{DE, tree_fold{+, each{vfold{+,.}, vs}}}
|
||||
}
|
||||
acc_scal+= curr
|
||||
}
|
||||
|
||||
def me{...} = acc
|
||||
def me{'acc', ...M, v} = acc{'acc', ...M, assert_hom{v}}
|
||||
def me{'from_unr'} = into_scal{slice{acc_tup, 1}}
|
||||
def me{'to_scal', ..._} = into_scal{tup{acc_vec}}
|
||||
def me{'flush_min'} = {
|
||||
if (exact) 1/0
|
||||
else if (widen_sum) maxvalue{SE}
|
||||
else __floor{maxvalue{SE}/(unr*k)}
|
||||
}
|
||||
def me{'flush'} = if (not exact) {
|
||||
into_scal{acc_tup}
|
||||
acc_tup = VT**0
|
||||
}
|
||||
extend finish{me}
|
||||
}
|
||||
|
||||
# def count_accumulator{DE, unr, VT=[_]SE==u8 if int_els{DE,SE} and DE>SE} = { # takes 2 instrs in core loop on x86; and still needs flushing on NEON
|
||||
# def acc = sum_accumulator{u64, unr, VT}
|
||||
# def me{...} = acc
|
||||
# extend perv1{__neg}
|
||||
# def me{'acc', ...M, v} = acc{'acc', ...M, assert_hom{v}}
|
||||
# def me{'scal_result'} = cast_i{DE, acc{'scal_result'} * 0xfefefefefefefeff}
|
||||
# extend finish{me}
|
||||
# }
|
||||
|
||||
|
||||
|
||||
# implicit identity values
|
||||
local def of_e{[_]E, G} = G{E}
|
||||
local def of_e{E if isprim{E}, G} = G{E}
|
||||
def assoc_accumulator{F==min, unr, T if ktyp{T}} = assoc_accumulator{F, unr, T, of_e{T, {E} => if (isfloat{E}) E~~1/0 else maxvalue{E}}}
|
||||
def assoc_accumulator{F==max, unr, T if ktyp{T}} = assoc_accumulator{F, unr, T, of_e{T, {E} => if (isfloat{E}) -E~~1/0 else minvalue{E}}}
|
||||
def assoc_accumulator{F==__add, unr, T if ktyp{T}} = assoc_accumulator{F, unr, T, 0}
|
||||
|
||||
def any_accumulator{unr, T} = bool_accumulator{|, unr, T, 0}
|
||||
def all_accumulator{unr, T} = bool_accumulator{&, unr, T, if (scal_bool{T}) 1 else if (isvec{T}) maxvalue{eltype{T}} else assert{0,'bad T for all_accumulator',T} }
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
include './base'
|
||||
include './vecfold'
|
||||
|
||||
if_inline (hasarch{'SSE2'}) {
|
||||
def fold_addw{v:V} = vfold{+, fold{+, mzip128{v, V**0}}}
|
||||
}
|
||||
include './mask'
|
||||
include './accumulator'
|
||||
|
||||
def vec_merge_shift_right{a:V=[n]_, b:V, s if hasarch{'SSE2'} and not hasarch{'SSSE3'}} = {
|
||||
vec_shift_left{a, n-s} | vec_shift_right{b, s}
|
||||
@ -14,88 +12,54 @@ def vec_merge_shift_right{a:V, b:V, 1 if width{V}>128} = {
|
||||
vec_merge_shift_right_128{p, b, 1}
|
||||
}
|
||||
|
||||
def __add{a:(usz), b:(u1)} = a + promote{usz,b}
|
||||
def __lt{a:V=[_]_, b if knum{b}} = a < V**b
|
||||
def __eq{a:V=[_]_, b if knum{b}} = a == V**b
|
||||
|
||||
def group_statistics{T} = {
|
||||
def store{p:(*u8), 0, b:(u1)} = store{p, 0, promote{u8, b}}
|
||||
|
||||
def widen_sum = width{T} <= 8
|
||||
def sum_vec = if (widen_sum) fold_addw else vfold{+, .}
|
||||
|
||||
def var{op, get} = {
|
||||
# Identity, type
|
||||
def id = match (op) { {(max)} => -1; {(&)} => 1; {_} => 0 }
|
||||
def S = match (op) { {(max)} => T; {(+)} => usz; {_} => u1 }
|
||||
|
||||
# Scalar accumulator
|
||||
def updater{v,op}{...a} = { v = op{v, get{...a}} }
|
||||
def scal{val} = {
|
||||
v:S = val
|
||||
tup{v, updater{v, op}}
|
||||
}
|
||||
def scal{} = scal{id}
|
||||
|
||||
# Vector accumulator
|
||||
def vec{l} = {
|
||||
def V = match (S) { {(T)} => [l]T; {_} => [l]ty_u{T} }
|
||||
v := V**(if (id==1) maxvalue{ty_u{T}} else id)
|
||||
def u = updater{v, if (same{op,+}) (-) else op}
|
||||
def {flush, get} = if (S!=usz) {
|
||||
def get = match (op) {
|
||||
{(&)} => all_hom
|
||||
{(|)} => any_hom
|
||||
{(max)} => vfold{max, .}
|
||||
}
|
||||
tup{{}=>{}, {} => get{v}}
|
||||
} else {
|
||||
f:usz = 0
|
||||
def flush{} = { f += cast_i{usz, sum_vec{v}}; v = V**0 }
|
||||
tup{flush, {} => f}
|
||||
def usz_accumulator = count_accumulator{usz, ...}
|
||||
def max_accumulator = assoc_accumulator{max, ..., -1}
|
||||
|
||||
def {types, acc_gen, ops} = each{tup,
|
||||
tup{u8, any_accumulator, {_,w} => w < -1}, # bad
|
||||
tup{usz, usz_accumulator, {_,w} => w == -1}, # neg
|
||||
tup{u8, all_accumulator, {p,w} => p <= w }, # sort
|
||||
tup{usz, usz_accumulator, {p,w} => p != w }, # change
|
||||
tup{T, max_accumulator, {_,w} => w } # max
|
||||
}
|
||||
|
||||
fn group_statistics{T}(w:*void, xn:usz, outs:each{__pnt,types}) : void = {
|
||||
def w = *T~~w
|
||||
def accs = if (has_simd) {
|
||||
def bulk = arch_defvw/width{T}
|
||||
def V = [bulk]T
|
||||
def VU = ty_u{V}
|
||||
|
||||
def unr = 2
|
||||
def accs = each{{a,T} => a{unr, if (quality{T}=='u') VU else V}, acc_gen, types}
|
||||
|
||||
prev_v:V = V ** -1
|
||||
@for_mu{bulk, unr, mu_extra{...accs}}(curr_vs in tup{V,w}, M in 'm' over xn) {
|
||||
def prev_vs = shiftright{tup{prev_v}, curr_vs}
|
||||
def prev_es = each{vec_merge_shift_right{..., 1}, prev_vs, curr_vs}
|
||||
each{{a, F} => {
|
||||
a{'acc', M, each{F, prev_es, curr_vs}}
|
||||
}, accs, ops}
|
||||
prev_v = select{curr_vs,-1}
|
||||
}
|
||||
tup{u, flush, get}
|
||||
}
|
||||
tup{if (S==u1) u8 else S, scal, vec}
|
||||
}
|
||||
def {types, init_scal, init_vec} = each{tup,
|
||||
var{|, {_,w} => w < -1}, # bad
|
||||
var{+, {_,w} => w == -1}, # neg
|
||||
var{&, {p,w} => p <= w }, # sort
|
||||
var{+, {p,w} => p != w }, # change
|
||||
var{max, {_,w} => w } # max
|
||||
}
|
||||
def run{g, ...par} = g{...par}
|
||||
def runvars{gens, ...par} = each{run{., ...par}, gens}
|
||||
|
||||
fn group_statistics(w0:*void, xn:usz, outs:each{__pnt,types}) : void = {
|
||||
def {start, init} = if (not has_simd) tup{0, tup{}} else {
|
||||
def vl = arch_defvw/width{T}; def V = [vl]T
|
||||
def {accum, flush, get} = flip{runvars{init_vec, vl}}
|
||||
e:= xn / vl
|
||||
i:usz = 0
|
||||
prev:V = V**(-1)
|
||||
while (i < e) {
|
||||
def lmax = 1 << (width{T}-1 - (not widen_sum)*lb{vl})
|
||||
l:= min{usz~~lmax, e-i}
|
||||
@for (w in *V~~w0 + i over l) {
|
||||
runvars{accum, vec_merge_shift_right{prev, w, 1}, w}
|
||||
prev = w
|
||||
}
|
||||
i+= l
|
||||
runvars{flush}
|
||||
accs
|
||||
} else {
|
||||
p:T = -1
|
||||
def accs = each{{a,T} => a{'!', T}, acc_gen, types}
|
||||
@for (c in w over xn) {
|
||||
each{{a, F} => a{'acc', F{p, c}}, accs, ops}
|
||||
p = c
|
||||
}
|
||||
tup{e*vl, tup{runvars{get}}}
|
||||
accs
|
||||
}
|
||||
def {vals, accum} = flip{each{run, init_scal, ...init}}
|
||||
prev:T = -1
|
||||
if (start > 0) prev = load{*T~~w0, start-1}
|
||||
@for (w in *T~~w0 over _ from start to xn) {
|
||||
runvars{accum, prev,w}
|
||||
prev = w
|
||||
}
|
||||
each{store{.,0,.}, outs, vals}
|
||||
def results = each{{a} => a{'vec_result'}, accs}
|
||||
each{{out:*T, r} => store{out, 0, promote{T,r}}, outs, results}
|
||||
}
|
||||
group_statistics{T}
|
||||
}
|
||||
|
||||
export{'si_group_statistics_i8', group_statistics{i8}}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user