support requesting a flush every k iterations in for_mu

This commit is contained in:
dzaima 2025-03-20 20:38:28 +02:00
parent 9c2ea18e22
commit 3b1239d499

View File

@ -122,29 +122,59 @@ def for_masked_pos{bulk}{vars,begin==0,end:L,iter} = {
# masked unrolled loop
# bulk: vector count
# unr: unroll amount
# fromunr (optional): {}=>{transition from unrolled to non-unrolled}
# extra (optional):
# either fromunr - {}=>{transition from unrolled to non-unrolled}
# or tup{fromunr, flush_max, {}=>{flush}}, where every flush_max iter calls flush must be called
# loop args:
# begin must be 0
# end is scalar element count
# index given is a tuple of batch indexes to process
def for_mu{bulk, unr, fromunr}{vars,begin==0,end,iter} = {
def for_mu{bulk, unr, extra}{vars,begin==0,end,iter} = {
def {fromunr, flush_max, flush} = match (extra) {
{a if kgen{a}} => tup{a, 1/0, {}=>{}}
{{a, b, c}} => tup{a, b, c}
}
l:u64 = promote{u64, end}
m:u64 = l / bulk
if (unr==1) {
@for (i from 0 to m) ml_exec{tup{i}, iter, vars, bulk, mask_none}
left:= l & (bulk-1)
if (left!=0) ml_exec{tup{m}, iter, vars, bulk, mask_first{left}}
if (same{flush_max, 1/0}) {
@for (i from 0 to m) ml_exec{tup{i}, iter, vars, bulk, mask_none}
} else if (m > 0) {
def done = makelabel{}
cs:u64 = 0
while (1) {
def ce = min{m, cs + flush_max-1}
@for (i from cs to ce) ml_exec{tup{i}, iter, vars, bulk, mask_none}
if (ce == m) goto{done}
cs = ce
flush{}
}
setlabel{done}
}
} else {
if (m > 0) {
i:u64 = 0
if (unr <= m) {
while ((i+unr) <= m) {
if (m >= unr) {
def unr_iter{} = {
def is = each{{j}=>i+j, iota{unr}}
ml_exec{is, iter, vars, bulk, mask_none}
i+= unr
}
if (same{flush_max, 1/0}) {
while (i+unr <= m) unr_iter{}
} else {
def done = makelabel{}
def unr_end = m-unr + 1
while (1) {
def ce = min{unr_end, i + (flush_max-unr)*unr}
while (i < ce) unr_iter{}
if (ce == unr_end) goto{done}
flush{}
}
setlabel{done}
}
fromunr{}
}
if (unr==2) {
@ -153,9 +183,9 @@ def for_mu{bulk, unr, fromunr}{vars,begin==0,end,iter} = {
@for(j from i to m) ml_exec{tup{j}, iter, vars, bulk, mask_none}
}
}
left:= l & (bulk-1)
if (left!=0) ml_exec{tup{m}, iter, vars, bulk, mask_first{left}}
}
left:= l & (bulk-1)
if (left!=0) ml_exec{tup{m}, iter, vars, bulk, mask_first{left}}
}
def for_mu{bulk, unr} = for_mu{bulk, unr, {}=>0}