diff --git a/src/singeli/src/mask.singeli b/src/singeli/src/mask.singeli index 5c614051..4f962444 100644 --- a/src/singeli/src/mask.singeli +++ b/src/singeli/src/mask.singeli @@ -122,29 +122,59 @@ def for_masked_pos{bulk}{vars,begin==0,end:L,iter} = { # masked unrolled loop # bulk: vector count # unr: unroll amount -# fromunr (optional): {}=>{transition from unrolled to non-unrolled} +# extra (optional): +# either fromunr - {}=>{transition from unrolled to non-unrolled} +# or tup{fromunr, flush_max, {}=>{flush}}, where every flush_max iter calls flush must be called # loop args: # begin must be 0 # end is scalar element count # index given is a tuple of batch indexes to process -def for_mu{bulk, unr, fromunr}{vars,begin==0,end,iter} = { +def for_mu{bulk, unr, extra}{vars,begin==0,end,iter} = { + def {fromunr, flush_max, flush} = match (extra) { + {a if kgen{a}} => tup{a, 1/0, {}=>{}} + {{a, b, c}} => tup{a, b, c} + } + l:u64 = promote{u64, end} m:u64 = l / bulk if (unr==1) { - @for (i from 0 to m) ml_exec{tup{i}, iter, vars, bulk, mask_none} - - left:= l & (bulk-1) - if (left!=0) ml_exec{tup{m}, iter, vars, bulk, mask_first{left}} + if (same{flush_max, 1/0}) { + @for (i from 0 to m) ml_exec{tup{i}, iter, vars, bulk, mask_none} + } else if (m > 0) { + def done = makelabel{} + cs:u64 = 0 + while (1) { + def ce = min{m, cs + flush_max-1} + @for (i from cs to ce) ml_exec{tup{i}, iter, vars, bulk, mask_none} + if (ce == m) goto{done} + cs = ce + flush{} + } + setlabel{done} + } } else { if (m > 0) { i:u64 = 0 - if (unr <= m) { - while ((i+unr) <= m) { + if (m >= unr) { + def unr_iter{} = { def is = each{{j}=>i+j, iota{unr}} ml_exec{is, iter, vars, bulk, mask_none} i+= unr } + if (same{flush_max, 1/0}) { + while (i+unr <= m) unr_iter{} + } else { + def done = makelabel{} + def unr_end = m-unr + 1 + while (1) { + def ce = min{unr_end, i + (flush_max-unr)*unr} + while (i < ce) unr_iter{} + if (ce == unr_end) goto{done} + flush{} + } + setlabel{done} + } fromunr{} } if (unr==2) { @@ -153,9 +183,9 @@ def for_mu{bulk, unr, fromunr}{vars,begin==0,end,iter} = { @for(j from i to m) ml_exec{tup{j}, iter, vars, bulk, mask_none} } } - - left:= l & (bulk-1) - if (left!=0) ml_exec{tup{m}, iter, vars, bulk, mask_first{left}} } + + left:= l & (bulk-1) + if (left!=0) ml_exec{tup{m}, iter, vars, bulk, mask_first{left}} } def for_mu{bulk, unr} = for_mu{bulk, unr, {}=>0}