From 7284eb7010d65186a7ff48a0b8782361348e12b5 Mon Sep 17 00:00:00 2001 From: dzaima Date: Sun, 22 Jan 2023 22:43:10 +0200 Subject: [PATCH] more manual unrolling in dyarith.singeli --- src/singeli/src/base.singeli | 1 + src/singeli/src/dyarith.singeli | 85 ++++++++++++++++++++++----------- 2 files changed, 57 insertions(+), 29 deletions(-) diff --git a/src/singeli/src/base.singeli b/src/singeli/src/base.singeli index 2c876bff..d1108396 100644 --- a/src/singeli/src/base.singeli +++ b/src/singeli/src/base.singeli @@ -33,6 +33,7 @@ def bit {k,x} = x & (1< [16]i16 ~~ ((v<<8)>>8 != v), rp} if (M{0}) { # masked check - tup{packQ{rp}, homAny{M{packQ{bad}}}} + tup{packQ{rp}, tup{'homAny', M{packQ{bad}}}} } else { # unmasked check; can do check in a simpler way - tup{packQ{rp}, homAny{tupsel{0,bad}|tupsel{1,bad}}} + tup{packQ{rp}, tup{'homAny', tupsel{0,bad}|tupsel{1,bad}}} } } def arithChk2{F, M, w:T, x:T & match{F,__mul} & isvec{T} & i16==eltype{T} & hasarch{'X86_64'}} = { rl:= __mul {w,x} rh:= __mulhi{w,x} - tup{rl, anyne{rh, rl>>15, M}} + tup{rl, tup{'anyne', rh, rl>>15}} } def arithChk2{F, M, w:T, x:T & match{F,__mul} & isvec{T} & i32==eltype{T} & hasarch{'X86_64'}} = { max:= [8]f32 ~~ [8]u32**0x4efffffe def cf32{x} = emit{[8]f32, '_mm256_cvtepi32_ps', x} f32mul:= cf32{w} * cf32{x} - tup{w*x, homAny{M{abs{f32mul} >= max}}} + tup{w*x, tup{'homAny', M{abs{f32mul} >= max}}} # TODO fallback to the below if the above fails # TODO don't do this, but instead shuffle one half, do math, unshuffle that half # def wp = unpackQ{w, T**0} @@ -69,14 +70,14 @@ def arithChk2{F, M, w:T, x:T & match{F,__mul} & isvec{T} & i32==eltype{T} & hasa # def bad = each{{v} => { # ((T2~~v + T2**0x80000000) ^ T2**(cast{i64,1}<<63)) > T2**cast_i{i64, (cast{u64,1}<<63) | 0xFFFFFFFF} # }, rp} - # tup{packQQ{each{{v} => v & T2**0xFFFFFFFF, rp}}, homAny{tupsel{0,bad}|tupsel{1,bad}}} this doesn't use M + # tup{packQQ{each{{v} => v & T2**0xFFFFFFFF, rp}}, tup{'homAny', tupsel{0,bad}|tupsel{1,bad}}} this doesn't use M } def arithChk2{F, M, w:T, x:T & match{F,__mul} & isvec{T} & hasarch{'AARCH64'}} = { def r12 = mul12{w, x} rl:= packLo{r12} rh:= packHi{r12} - tup{rl, homAny{M{rh != (rl >> (elwidth{T}-1))}}} + tup{rl, tup{'homAny', M{rh != (rl >> (elwidth{T}-1))}}} } @@ -84,34 +85,57 @@ def arithChk2{F, M, w:T, x:T & match{F,__mul} & isvec{T} & hasarch{'AARCH64'}} = def runner{u, R, F} = { def c = ~u - def run{F, OO, M, w, x} = { show{'todo', c, R, F, w, x}; emit{void,'__builtin_abort'}; w } + def run{F, M, w, x} = { show{'todo', c, R, F, w, x}; emit{void,'__builtin_abort'}; w } - def run{F, OO, M, w:T, x:T & c & R!=u32} = { - def r2 = arithChk2{F, M, w, x} - if (rare{tupsel{1,r2}}) OO{} - tupsel{0,r2} + def run{F, M, w:T, x:T & c & R!=u32} = { + arithChk2{F, M, w, x} } - def run{F, OO, M, w, x & u} = F{w, x} # trivial base implementation + def run{F, M, w, x & u} = tup{F{w, x}, tup{'none'}} # trivial base implementation def toggleTop{x:X} = x ^ X**(1<<(elwidth{X}-1)) - def run{F==__sub, OO, M, w:VU, x:VU & is_u{VU}} = { # 'b'-'a' + def run{F==__sub, M, w:VU, x:VU & is_u{VU}} = { # 'b'-'a' def VS = ty_s{VU} - run{F, OO, M, VS~~toggleTop{w}, VS~~toggleTop{x}} + run{F, M, VS~~toggleTop{w}, VS~~toggleTop{x}} } - def run{F, OO, M, w:VU, x:VS & is_u{VU} & is_s{VS}} = { # 'a'+3, 'a'-3 - toggleTop{VU~~run{F, OO, M, VS~~toggleTop{w}, x}} + def run{F, M, w:VU, x:VS & is_u{VU} & is_s{VS}} = { # 'a'+3, 'a'-3 + def r = run{F, M, VS~~toggleTop{w}, x} + tup{toggleTop{VU~~tupsel{0,r}}, tupsel{1,r}} } - def run{F==__add, OO, M, w:VS, x:VU & is_s{VS} & is_u{VU}} = run{F, OO, M, x, w} # 3+'a' → 'a'+3 + def run{F==__add, M, w:VS, x:VU & is_s{VS} & is_u{VU}} = run{F, M, x, w} # 3+'a' → 'a'+3 - def run{F, OO, M, w:VW, x:VX & c & R==u32 & (match{F,__add} | match{F,__sub})} = { # 'a'+1, 'a'-1 + def run{F, M, w:VW, x:VX & c & R==u32 & (match{F,__add} | match{F,__sub})} = { # 'a'+1, 'a'-1 r:= F{ty_u{w}, ty_u{x}} - if (homAny{M{r > type{r}**1114111}}) OO{} - to_el{R, VW}~~r + tup{to_el{R, VW}~~r, tup{'homAny', M{r > type{r}**1114111}}} } run } +def runChecks_any{F, vals} = { F{tree_fold{|, each{{c}=>tupsel{1,c}, vals}}} } +def runChecks{type=='homAny', vals, M} = runChecks_any{homAny, vals} +def runChecks{type=='topAny', vals, M} = runChecks_any{topAny, vals} +def runChecks{type=='none', vals, M} = 0 +def runChecks{type=='anyne', vals, M} = { + def cols = flip{vals} + def xs = tupsel{1, cols} + def ys = tupsel{2, cols} + if (tuplen{vals}==1) { + anyne{...xs, ...ys, M} + } else { + assert{M{0} == 0} + ~homAll{tree_fold{&, each{==, xs, ys}}} + } +} + +def arithProcess{F, run, overflow, M, is, cw, cx, TY} = { + def r0 = flip{each{{w1, x1} => run{F, M, w1, x1}, cw, cx}} + def values = tupsel{0, r0} + def checks = tupsel{1, r0} + def ctype = tupsel{0,tupsel{0,checks}} + assert{tree_fold{&, each{{c}=>match{ctype, tupsel{0,c}}, checks}}} + if (rare{runChecks{ctype, checks, M}}) overflow{tupsel{0, is}} + each{{c} => TY~~c, values} +} def arithAAimpl{vw, mode, F, W, X, R, w, x, r, len} = { # show{F, mode, W, X, R} @@ -140,10 +164,11 @@ def arithAAimpl{vw, mode, F, W, X, R, w, x, r, len} = { def run = runner{match{overflow, 0}, R, F} - muLoop{bulk, tern{(mode==0) & hasarch{'AARCH64'}, 2, 1}, len, {is, M} => { + def unr = tern{mode==0, 2, 1} # 2x unroll non-overflowing cases; surpresses clang's default unrolling, which unrolls a lot more; 2x appears to be plenty + muLoop{bulk, unr, len, {is, M} => { def cw = loadBatch{*W~~w, is, ty_sc{W, TY}} def cx = loadBatch{*X~~x, is, ty_sc{X, TY}} - storeBatch{*R~~r, is, each{{w1, x1} => TY~~run{F, {} => overflow{tupsel{0, is}}, M, w1, x1}, cw, cx}, M} + storeBatch{*R~~r, is, arithProcess{F, run, overflow, M, is, cw, cx, TY}, M} }} } } @@ -176,9 +201,11 @@ arithSAf{vw, mode, F, swap, W, X, R}(r:*void, w:u64, x:*void, len:u64) : u64 = { def getW{v & W==f64} = interp_f64{v} cw:= ty_sc{W, TY}**getW{w} - maskedLoop{bulk, len, {i, M} => { - cx:= loadBatch{*X~~x, i, ty_sc{X, TY}} - storeBatch{*R~~r, i, TY~~run{F, {} => overflow{i}, M, tern{swap,cx,cw}, tern{swap,cw,cx}}, M} + def unr = tern{mode==2, 2, 1} # same as in arithAAimpl + muLoop{bulk, unr, len, {is, M} => { + def cx = loadBatch{*X~~x, is, ty_sc{X, TY}} + def cws = tuplen{is}**cw + storeBatch{*R~~r, is, arithProcess{F, run, overflow, M, is, tern{swap,cx,cws}, tern{swap,cws,cx}, TY}, M} }} if (mode==1) 0