AVX2-based boolean short-row scans

This commit is contained in:
Marshall Lochbaum 2024-06-04 20:35:46 -04:00
parent 6d27dd391b
commit 40bf3bfd1c
2 changed files with 50 additions and 4 deletions

View File

@ -158,7 +158,7 @@ def lvec = match { {[n]T, n, (width{T})} => 1; {T, n, w} => 0 }
# base cases
def {
absu,andAllZero,andnz,b_getBatch,clmul,cvt,extract,fold_addw,half,
absu,andAllZero,andnz,b_getBatch,blend,clmul,cvt,extract,fold_addw,half,
homAll,homAny,homBlend,homMask,homMaskStore,homMaskStoreF,loadBatchBit,
loadLow,make,maskStore,maskToHom,mulw,mulh,narrow,narrowPair,packHi,packLo,packQ,pair,pdep,
pext,popcRand,sel,shl,shr,shuf,shuf16Hi,shuf16Lo,shufHalves,storeLow,

View File

@ -5,6 +5,7 @@ if_inline (hasarch{'X86_64'}) {
}
include './mask'
include './f64'
include 'util/tup'
include './scan_common'
# Initialized scan, generic implementation
@ -300,10 +301,13 @@ export{'si_scan_plus_i32_f64', plus_scanG{i32, f64}}
# Row-wise boolean scan
def aligned_mask{l} = (~u64~~0) / ((u64~~1 << l)-1)
def unaligned_mask{l} = {
def d = 64 % l
def m = (~u64~~0 >> d) / ((u64~~1 << l)-1)
tup{m<<l | 1, d}
}
def loop_with_unaligned_mask{x, r, nw, l, step} = {
d:usz = 64 % l
m:u64 = (~u64~~0 >> d) / ((u64~~1 << l)-1)
m = m<<l | 1
{m, d} := unaligned_mask{l}
c:u64 = 0 # carry (initial value never matters)
@for (x, r over nw) {
match (step{x, c, m}) {
@ -313,6 +317,21 @@ def loop_with_unaligned_mask{x, r, nw, l, step} = {
m = m>>d | m<<(l-d)
}
}
def avx2_loop_with_unaligned_mask{xp, rp, nw, l, scan_words, apply_carry} = {
{ms, d} := unaligned_mask{l}
def V = [4]u64
d4:usz = width{V} % l
m:= make{V, scan{{a,_} => a>>d | a<<(l-d), tup{ms, ...iota{3}}}}
c:= V**0
@maskedLoop{4} (x in tup{V, xp},
r in tup{V, rp} over promote{u64,nw}) {
s := scan_words{x, m}
pc:= c; c = shuf{V, -(s>>63), 4b2103}
r = apply_carry{s, blend{V, c, pc, 2b0001}, (m-V**1)&~m}
m = m>>d4 | m<<(l-d4)
}
}
fn scan_rows_andor{id}(src:*u64, dst:*u64, n:usz, l:usz) : void = {
def qand = not id
assert{l > 0}
@ -337,6 +356,17 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, n:usz, l:usz) : void = {
}
# could use for l>=8; not much faster and takes up space
# def rowwise{T} = @for (r in *T~~dst, x in *T~~src over (64/width{T})*nw) r = x &~ (x+1)
} else if (hasarch{'AVX2'}) {
def scan_words{x, m:V} = {
mb:= m | V**1
p:= if (qand) (x &~ m) >> 1 else ~(x | m) >> 1
a:= if (qand) p + (mb & x) else p + (mb &~ x)
if (qand) p ^ a else ~(p ^ a)
}
def apply_carry{s, c, f} = {
if (qand) s & (~f | c) else s | (f & c)
}
avx2_loop_with_unaligned_mask{src, dst, nw, l, scan_words, apply_carry}
} else {
loop_with_unaligned_mask{src, dst, nw, l, {x, c, m} => {
s:= (if (qand) (x &~ m) >> 1 else ~(x | m) >> 1 )
@ -398,6 +428,18 @@ fn scan_rows_neq(x:*u64, r:*u64, n:usz, l:usz) : void = {
b:= s<<1 & m # last bit of previous row
r = s ^ (b<<l - b)
}
} else if (hasarch{'AVX2'}) {
def scan_words{x, m} = {
def vec_prefix_byshift{op, sh} = {
def pre{v:V, k} = if (k < elwidth{V}) pre{op{v, sh{v,k}}, 2*k} else v
{v:T} => pre{v, 1}
}
s:= vec_prefix_byshift{^, <<}{x}
b:= s<<1 & m # last bit of previous row
s ^ (b<<l - b)
}
def apply_carry{s, c, f} = s ^ (f & c)
avx2_loop_with_unaligned_mask{x, r, nw, l, scan_words, apply_carry}
} else {
loop_with_unaligned_mask{x, r, nw, l, {x, c, m} => {
s:= scan_word{x}
@ -432,6 +474,10 @@ fn scan_rows_left(x:*u64, r:*u64, n:usz, l:usz) : void = {
if ((l & (l-1)) == 0) {
m:u64 = aligned_mask{l}
@for (r, x over nw) { b:= x & m; r = b<<l - b }
} else if (hasarch{'AVX2'}) {
def scan_words{x, m} = { b:= x&m; b<<l - b }
def apply_carry{s, c, f} = s | (f & c)
avx2_loop_with_unaligned_mask{x, r, nw, l, scan_words, apply_carry}
} else {
loop_with_unaligned_mask{x, r, nw, l, {x, c, m} => {
f:= (m-1)&~m # bits before first full row