Merge pull request #117 from mlochbaum/adaptive-minmax-scan

Adaptive min/max scan
This commit is contained in:
dzaima 2024-08-20 23:40:54 +03:00 committed by GitHub
commit 413fb8893b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 4 deletions

View File

@ -11,6 +11,7 @@
// =≤≥>- in terms of ≠<∨∧+ with adjustments
// Arithmetic operand, rank 1:
// ⌈⌊ Scalar, SSE, AVX in log(vector width) steps (SHOULD add NEON)
// Check in 6-vector blocks to quickly write result if constant
// + Overflow-checked scalar or AVX2
// Ad-hoc boolean-valued handling for ≠∨
// SHOULD extend rank 1 special cases to cell bound 1

View File

@ -23,20 +23,42 @@ def scan_loop{init, x:*T, r:*T, len:(u64), scan, scan_last} = {
q:= len & (step-1)
if (q!=0) homMaskStoreF{rv+e, maskOf{V, q}, scan_last{load{xv,e}, p}}
}
def scan_post{init, x:*T, r:*T, len:(u64), op, pre} = {
def get_scan_last{op, pre} = {
def last{v, p} = op{pre{v}, p}
def scan{v, p} = {
n:= last{v, p}
p = toLast{n}
n
}
scan_loop{init, x, r, len, scan, last}
tup{scan, last}
}
# Associative scan ?` if a?b?a = a?b = b?a, used for ⌊⌈
def scan_idem = scan_scal
fn scan_idem{T, op if hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = {
scan_post{init, x, r, len, op, make_scan_idem{T, op}}
def {scan, last} = get_scan_last{op, make_scan_idem{T, op}}
def cmp = match (op) { {(min)} => (>); {(max)} => (<) }
def step = arch_defvw/width{T}
def V = [step]T
p:= V**init
xv:= *V ~~ x
rv:= *V ~~ r
e:= len/step
# Check k vectors at a time to see if they can all be ignored
def k = 6
ek := e / k
@for (ik to ek) { i := ik * k
def ii = iota{k}
xvi := each{load{xv + i, .}, ii}
if (not homAny{cmp{p, tree_fold{op, xvi}}}) {
each{store{rv+i, ., p}, ii}
} else @unroll (rv in rv+i over j to k) {
rv = scan{select{xvi,j}, p}
}
}
@for (xv, rv over _ from ek*k to e) rv = scan{xv,p}
q:= len & (step-1)
if (q!=0) homMaskStoreF{rv+e, maskOf{V, q}, last{load{xv,e}, p}}
}
export{'si_scan_min_init_i8', scan_idem{i8 , min}}; export{'si_scan_max_init_i8', scan_idem{i8 , max}}
@ -67,7 +89,7 @@ def scan_plus = scan_assoc{+}
def scan_assoc_0 = scan_scal
fn scan_assoc_0{T, op if hasarch{'X86_64'}}(x:*T, r:*T, len:u64, init:T) : void = {
# Prefix op on entire AVX register
scan_post{init, x, r, len, op, scan_plus}
scan_loop{init, x, r, len, ...get_scan_last{op, scan_plus}}
}
export{'si_scan_pluswrap_u8', scan_assoc_0{u8 , +}}
export{'si_scan_pluswrap_u16', scan_assoc_0{u16, +}}