use fancy accumulator for getRange

This commit is contained in:
dzaima 2025-03-20 21:34:16 +02:00
parent e1b62b95b6
commit f503880cc3

View File

@ -2,6 +2,7 @@ include './base'
include './mask' include './mask'
include './vecfold' include './vecfold'
include './hashtab' include './hashtab'
include './accumulator'
def find_first{C, M, F, ...v1} = { def find_first{C, M, F, ...v1} = {
def exit = makelabel{} def exit = makelabel{}
@ -296,18 +297,6 @@ export{'simd_deduplicate_u8', simd_deduplicate_u8}
export{'simd_member_u8', simd_member_u8} export{'simd_member_u8', simd_member_u8}
export{'simd_index_tab_u8', simd_index_tab_u8{usz}} export{'simd_index_tab_u8', simd_index_tab_u8{usz}}
def acc{unr, init:T} = {
a0v := init
def a0 = tup{a0v}
def a1 = @collect(unr) { reg:=init }
def op{'get'} = a0v
def op{'tr', F} = { a0v = tree_fold{F, a1} }
def op{'upd', is, F} = {
if (length{is}==1) a0 = F{a0}
else a1 = F{a1}
}
}
# following RangeFn's specification in calls.h, whether it's necessary to return 0 or otherwise accumulating & converting to i64 may produce incorrect results # following RangeFn's specification in calls.h, whether it's necessary to return 0 or otherwise accumulating & converting to i64 may produce incorrect results
def bad_float_i64{x:T=[_](f64)} = { def bad_float_i64{x:T=[_](f64)} = {
a:= abs{x} a:= abs{x}
@ -317,28 +306,26 @@ def bad_float_i64{x:T=[_](f64)} = {
def bad_float_i64{x:T=[_](f64) if hasarch{'SSE4.1'}} = (x!=floor{x}) | (abs{x}>T**(1<<53)) def bad_float_i64{x:T=[_](f64) if hasarch{'SSE4.1'}} = (x!=floor{x}) | (abs{x}>T**(1<<53))
def bad_float_i64{x:T=[_](f64) if hasarch{'AARCH64'}} = x != cvt{f64, cvt{i64, x}} def bad_float_i64{x:T=[_](f64) if hasarch{'AARCH64'}} = x != cvt{f64, cvt{i64, x}}
def mask_blend{b:T, x:T, M} = x
def mask_blend{b:T, x:T, M if M{0}} = blend_hom{b, x, M{T, 'to homogeneous bits'}}
fn getRange{E}(x0:*void, res:*i64, n:u64) : u1 = { fn getRange{E}(x0:*void, res:*i64, n:u64) : u1 = {
assert{n>0} assert{n>0}
x:= *E~~x0 x:= *E~~x0
min1:E = *x min1:= undefined{E}
max1:E = *x max1:= undefined{E}
if (has_simd) { if (has_simd) {
def bulk = arch_defvw/width{E} def bulk = arch_defvw/width{E}
def VT = [bulk]E def VT = [bulk]E
def unr = tern{E==f64 and hasarch{'X86_64'}, 1, 2} def unr = tern{E==f64 and hasarch{'X86_64'}, 1, 2}
def minA = acc{unr, VT**min1} def min_a = assoc_accumulator{min, unr, VT}
def maxA = acc{unr, VT**min1} def max_a = assoc_accumulator{max, unr, VT}
@for_mu{bulk, unr, {} => { minA{'tr',min}; maxA{'tr',max} }}(cx in tup{VT,x}, M in 'm' over is to n) { @for_mu{bulk, unr, mu_extra{min_a,max_a}}(cx in tup{VT,x}, M in 'm' over is to n) {
if (E==f64 and any_hom{M, ...each{bad_float_i64, cx}}) return{0} if (E==f64 and any_hom{M, ...each{bad_float_i64, cx}}) return{0}
minA{'upd', is, {a} => eachx{mask_blend, a, each{min, a, cx}, M}} # blend min_a{'acc', M, cx}
maxA{'upd', is, {a} => eachx{mask_blend, a, each{max, a, cx}, M}} # blend max_a{'acc', M, cx}
} }
min1 = vfold{min, minA{'get'}} min1 = min_a{'vec_result'}
max1 = vfold{max, maxA{'get'}} max1 = max_a{'vec_result'}
} else { } else {
min1 = max1 = *x
@for (x over i to n) { @for (x over i to n) {
if (E==f64 and rare{x != emit{f64, '', emit{i64, '', x}}}) return{0} if (E==f64 and rare{x != emit{f64, '', emit{i64, '', x}}}) return{0}
min1 = min{min1, x} min1 = min{min1, x}