(hom|top)Mask → (hom|top)_to_int

This commit is contained in:
dzaima 2025-01-26 01:33:54 +02:00
parent e9e1574d28
commit e3b30e5db7
16 changed files with 66 additions and 66 deletions

View File

@ -170,11 +170,11 @@ Homogeneous definitions (i.e. ones with `hom` in their name) assume that each el
- `blend_hom{f:V, t:V, m:mt{V}} : V` - blend by `m`, setting to `f` where `0` and `t` where `1`
- `blend_top{f:V, t:V, m:V} : V` - blend by top bit of `m`
- `blend_bit{f:V, t:V, m:M} : V` - bitwise blend
- `homMask{a:VI} : uint` - integer mask of whether each element is set (assumes each element has all its bits equal)
- `homMask{...vs} : uint` - merged mask of `each{homMask,vs}`
- `topMask{a:VI} : uint` - integer mask of the top bit of each element
- `homMaskX{a:VI} : tup{knum, uint}` - integer mask where each element is represented by `knum` bits (possibly more efficient to calculate than `homMask`)
- `ctzX{tup{knum, uint}}` - count trailing zeroes from a result of `homMaskX`
- `hom_to_int{a:VI} : uint` - integer mask of whether each element is set (assumes each element has all its bits equal)
- `hom_to_int{...vs} : uint` - merged mask of `each{hom_to_int,vs}`
- `top_to_int{a:VI} : uint` - integer mask of the top bit of each element
- `hom_to_int_ext{a:VI} : tup{knum, uint}` - integer mask where each element is represented by `knum` bits (possibly more efficient to calculate than `hom_to_int`)
- `ctz_ext{tup{knum, uint}}` - count trailing zeroes from a result of `hom_to_int_ext`
## Load/store

View File

@ -49,12 +49,12 @@ def shufInd{a:T, b:T=[4]E, {...is} if width{E}==64 and length{is}==4} = T~~shufI
# mask stuff
def andAllZero{x:T, y:T if w256i{T}} = emit{u1, '_mm256_testz_si256', x, y}
def topMask{x:T if w256{T, 32}} = emit{u8, '_mm256_movemask_ps', v2f{x}}
def topMask{x:T if w256{T, 64}} = emit{u8, '_mm256_movemask_pd', v2d{x}}
def homMask{x:T if w256{T}} = topMask{x}
def top_to_int{x:T if w256{T, 32}} = emit{u8, '_mm256_movemask_ps', v2f{x}}
def top_to_int{x:T if w256{T, 64}} = emit{u8, '_mm256_movemask_pd', v2d{x}}
def hom_to_int{x:T if w256{T}} = top_to_int{x}
def any_hom{x:T if w256i{T} and elwidth{T}>=32} = homMask{[8]u32 ~~ x} != 0
def all_hom{x:T if w256i{T} and elwidth{T}>=32} = homMask{[8]u32 ~~ x} == 0xff
def any_hom{x:T if w256i{T} and elwidth{T}>=32} = hom_to_int{[8]u32 ~~ x} != 0
def all_hom{x:T if w256i{T} and elwidth{T}>=32} = hom_to_int{[8]u32 ~~ x} == 0xff
def any_top{x:T=[_]E if w256i{T} and width{E}>=32} = topMask{x} != 0
def all_top{x:T=[k]E if w256i{T} and width{E}>=32} = topMask{x} == (1<<k)-1
def any_top{x:T=[_]E if w256i{T} and width{E}>=32} = top_to_int{x} != 0
def all_top{x:T=[k]E if w256i{T} and width{E}>=32} = top_to_int{x} == (1<<k)-1

View File

@ -33,17 +33,17 @@ def homMaskStoreF{p:*T, m:M, v:T if w256i{M} and elwidth{T}>=32} = topMaskStore{
def homMaskStoreF{p:*T, m:M, v:T if w256i{M} and elwidth{T}<=16 and w256{T,elwidth{M}}} = store{p, 0, blend_hom{load{p}, v, m}}
# mask stuff
def topMask{x:T if w256{T, 8}} = emit{u32, '_mm256_movemask_epi8', x}
def topMask{x:T if w256{T, 16}} = {
msk:u32 = topMask{emit{[32]u8, '_mm256_packs_epi16', x, [16]u16**0}}
def top_to_int{x:T if w256{T, 8}} = emit{u32, '_mm256_movemask_epi8', x}
def top_to_int{x:T if w256{T, 16}} = {
msk:u32 = top_to_int{emit{[32]u8, '_mm256_packs_epi16', x, [16]u16**0}}
(msk&255) | (msk>>8)
}
def any_hom{x:T if w256i{T}} = ~emit{u1, '_mm256_testz_si256', v2i{x}, v2i{x}}
def all_hom{x:T if w256i{T}} = homMask{[32]u8 ~~ x} == 0xffff_ffff
def any_top{x:T if w256i{T}} = topMask{x} != 0
def all_top{x:T=[k]_ if w256i{T}} = topMask{x} == (1<<k)-1
def homMask{a:T, b:T if w256i{T,16}} = homMask{vec_shuffle{[4]u64, packQ{ty_s{a},ty_s{b}}, 0,2,1,3}}
def all_hom{x:T if w256i{T}} = hom_to_int{[32]u8 ~~ x} == 0xffff_ffff
def any_top{x:T if w256i{T}} = top_to_int{x} != 0
def all_top{x:T=[k]_ if w256i{T}} = top_to_int{x} == (1<<k)-1
def hom_to_int{a:T, b:T if w256i{T,16}} = hom_to_int{vec_shuffle{[4]u64, packQ{ty_s{a},ty_s{b}}, 0,2,1,3}}
def any_top{x:T if w256i{T,32}} = ~emit{u1, '_mm256_testz_ps', v2f{x}, v2f{x}}
def any_top{x:T if w256i{T,64}} = ~emit{u1, '_mm256_testz_pd', v2d{x}, v2d{x}}

View File

@ -31,8 +31,8 @@ def maskStore{p:*V, m:[l](u1), v:V=[l]_ if has512e{V}} = {
}
def topMaskReg{x:V=[k]_} = emit{[k]u1, merge{pref{V},'mov',suf{V},'_mask'}, x}
def topMask{x:V=[k]_ if 512==width{V}} = ty_u{k}~~topMaskReg{x}
def homMask{x:V=[_]_ if 512==width{V}} = topMask{x}
def top_to_int{x:V=[k]_ if 512==width{V}} = ty_u{k}~~topMaskReg{x}
def hom_to_int{x:V=[_]_ if 512==width{V}} = top_to_int{x}
def maskToHom{V=[l]_, x:[l](u1)} = emit{V, merge{pref{V},'movm_',suf{V}}, x}

View File

@ -175,8 +175,8 @@ def lvec = match { {[n]T, n, (width{T})} => 1; {T, n, w} => 0 }
def {
absu,andAllZero,andnz,b_getBatch,blend,blend_units,clmul,cvt,extract,fold_addw,half,
all_bit,any_bit,blend_bit,
all_hom,any_hom,blend_hom,homMask,homMaskStore,homMaskStoreF,
all_top,any_top,blend_top,topMask,topMaskStore,topMaskStoreF,
all_hom,any_hom,blend_hom,hom_to_int,homMaskStore,homMaskStoreF,
all_top,any_top,blend_top,top_to_int,topMaskStore,topMaskStoreF,
loadBatchBit,loadLow,make,maskStore,maskToHom,mulw,mulh,narrow,narrowTrunc,narrowPair,
packQ,pair,pdep,pext,popcRand,rbit,rev,sel,shl,shr,shufInd,storeLow,
unord,unzip,vfold,vec_select,vec_shuffle,widen,widenUpper,multishift,
@ -186,19 +186,19 @@ def {
def blend_bit{f:T, t:T, m:M if width{T}==width{M}} = T ~~ ((M~~t & m) | (M~~f &~ m))
def blend_hom{f:T, t:T, m:M} = blend_bit{f, t, m}
def homMaskX{a:T} = tup{1, homMask{a}} # tup{n,mask}; mask with each bit repeated n times
def ctzX{{n,v}} = ctz{v}/n # ctz for a result of homMaskX
def homMask{...vs if length{vs}>1} = {
def hom_to_int_ext{a:T} = tup{1, hom_to_int{a}} # tup{n,mask}; mask with each bit repeated n times
def ctz_ext{{n,v}} = ctz{v}/n # ctz for a result of homMaskX
def hom_to_int{...vs if length{vs}>1} = {
def n = length{vs}
def [k]_ = oneType{vs}
def RT = ty_u{max{8,k*n}}
def sl{...a} = promote{RT, homMask{...slice{vs,...a}}}
def sl{...a} = promote{RT, hom_to_int{...slice{vs,...a}}}
def h = n/2
def lo = sl{0,h}
def hi = sl{h}
(hi << (h*k)) | lo
}
def homMask{x if ktup{x}} = homMask{...x}
def hom_to_int{x if ktup{x}} = hom_to_int{...x}
if_inline (hasarch{'X86_64'}) {
include 'arch/iintrinsic/basic'

View File

@ -143,7 +143,7 @@ def bins_vectab_i8{up, w, wn, x, xn, rp, t0, t, done if hasarch{'AVX2'}} = {
def nb = 256/vl
nu:u8 = 0; def addu{b} = { nu+=cast_i{u8,popc{b}}; b } # Number of uniques
vb := U~~make{[nb](ty_u{vl}),
@collect (t in *V~~t0 over nb) addu{homMask{t > V**0}}
@collect (t in *V~~t0 over nb) addu{hom_to_int{t > V**0}}
}
dup := promote{u64,nu} < wn
# Unique index to w index conversion
@ -290,7 +290,7 @@ def bin_search_vec{prim, T, w:*T, wn, x:*T, xn, rp, maxwn if hasarch{'AVX2'}} =
if (isvec{type{rn}}) store{rnp, 0, rn}
else storeu{rnp, rn}
} else {
def B = ty_u{vl}; out := cast_i{B, homMask{b}}
def B = ty_u{vl}; out := cast_i{B, hom_to_int{b}}
store{*B~~rp, cdiv{j,vl}, out>>((-j)%vl)}
}
}

View File

@ -73,7 +73,7 @@ def any2bit{VT=[k]T, unr, op0, wS, wV, xS, xV, dst:(*u64), len:(ux)} = {
{(__ne) if isint{T}} => __eq
{_} => op0
}
def mask = if (same{op0, op}) homMask else ({x} => ~homMask{x})
def mask = if (same{op0, op}) hom_to_int else ({x} => ~hom_to_int{x})
@forNZ (ri to cdiv{len,bulk}) {
b_setBatch{bulk, dst, ri, mask{each{{j}=>op{wV{xi+j}, xV{xi+j}}, iota{unr}}}}

View File

@ -49,7 +49,7 @@ fn copy{X, R}(r: *void, x: *void, l:u64, xRaw: *void) : void = {
def bulk2 = bulk*unr
xi:ux = 0
@forNZ (i to cdiv{l,bulk2}) {
b_setBatch{bulk2, rp, i, homMask{each{{i} => op{loadBatch{xp, xi+i, XV}}, iota{unr}}}}
b_setBatch{bulk2, rp, i, hom_to_int{each{{i} => op{loadBatch{xp, xi+i, XV}}, iota{unr}}}}
xi+= unr
}
} else if (width{X}<=width{R}) {

View File

@ -87,7 +87,7 @@ fn flush_counts(tab:*u16, ov:*u16, n:usz) : usz = {
def bot = 1<<15 - 1
on:usz = 0
@for (t in *V~~tab over jv to cdiv{n, vl}) if (rare{any_top{t}}) {
o := if (hasarch{'X86_64'}) topMask{t} else homMask{t > V**bot}
o := if (hasarch{'X86_64'}) top_to_int{t} else hom_to_int{t > V**bot}
if (jv == n/vl) o &= type{o}~~1<<(n%vl) - 1
while (o > 0) {
jv := jv*vl + cast_i{usz, ctz{o}}
@ -157,7 +157,7 @@ def mark_run_ends{x:*T, m:(ux)} = {
@unroll (j to width{ux} / vec) {
def jv = j*vec
def lv{k} = load{*V~~(x + k)}
m |= promote{ux, homMask{lv{jv} != lv{jv+1}}} << jv
m |= promote{ux, hom_to_int{lv{jv} != lv{jv+1}}} << jv
}
}
def inc_marked_runs{x, tab:*T, m, m0} = {

View File

@ -41,8 +41,8 @@ fn equal{W, X}(w:*void, x:*void, l:u64, d:u64) : u1 = {
def T = [bulk]X
def sh{c} = c << (width{X}-1)
def sh{c if X==u8} = T ~~ (re_el{u16,c}<<7)
def mask{x:X if hasarch{'X86_64'}} = topMask{x}
def mask{x:X if hasarch{'AARCH64'}} = homMask{andnz{x, ~T**0}}
def mask{x:X if hasarch{'X86_64'}} = top_to_int{x}
def mask{x:X if hasarch{'AARCH64'}} = hom_to_int{andnz{x, ~T**0}}
# TODO compare with doing the comparison in vector registers
badBits:= T ** ~(X~~1)

View File

@ -22,7 +22,7 @@ def anyneBit{x:T, y:T, M} = ~M{x^y, 'all bits zeroes'}
def anynePositive{x:T, y:T, M if M{0}==0} = anyne{x, y, M}
def anynePositive{x:T, y:T, M if M{0}==1 and isvec{T}} = {
def {n,m} = homMaskX{x==y}
def {n,m} = hom_to_int_ext{x==y}
def E = tern{type{m}==u64, u64, u32}
(promote{E,~m} << (width{E}-M{'count'}*n)) != 0
}

View File

@ -139,14 +139,14 @@ def any_top{x:V if nvec{V}} = fold_min{ty_s{x}}<0
def all_top{x:V if nvec{V}} = fold_max{ty_s{x}}<0
def homMask{x:T=[k]E if nvecu{T} and width{E}>=k} = {
def hom_to_int{x:T=[k]E if nvecu{T} and width{E}>=k} = {
truncBits{k, fold_add{x & make{T, 1<<iota{k}}}}
}
def homMask{x:T=[16]E if width{E}==8} = {
def hom_to_int{x:T=[16]E if width{E}==8} = {
t:= [8]u16~~sel{[16]u8, x, make{[16]u8, tr_iota{3,0,1,2}}}
fold_add{t & make{[8]u16, (1<<iota{8})*0x0101}}
}
def homMask{a:T,b:T=[16]E if width{E}==8} = {
def hom_to_int{a:T,b:T=[16]E if width{E}==8} = {
m:= make{[16]u8, 1<<(iota{16}&7)}
s:= make{[16]u8, (range{16}>>2) | ((range{16}&3)<<2)}
# fold_add{addpw{addpw{addp{ty_u{a}&m, ty_u{b}&m}}}<<make{[4]u32,iota{4}*8}}
@ -156,21 +156,21 @@ def homMask{a:T,b:T=[16]E if width{E}==8} = {
# t:= shrm{l, 4, h} & make{[16]u8, (1<<(range{16}>>2)) * 0x11}
# fold_add{[4]u32~~t}
}
def homMask{a:T,b:T,c:T,d:T=[16]E if width{E}==8} = {
def hom_to_int{a:T,b:T,c:T,d:T=[16]E if width{E}==8} = {
m:= make{[16]u8, 1<<(iota{16}&7)}
t1:= addp{ty_u{a}&m, ty_u{b}&m}
t2:= addp{ty_u{c}&m, ty_u{d}&m}
t3:= addp{t1, t2}
extract{[2]u64~~addp{t3,t3},0}
}
def homMask{...as={a0:[_]E, _, ..._} if width{E}>=32} = homMask{...each{{i}=>narrowPair{select{as,i*2},select{as,i*2+1}}, iota{length{as}/2}}}
def homMask{a:T,b:T=[k]E if k*2<=width{E}} = {
def hom_to_int{...as={a0:[_]E, _, ..._} if width{E}>=32} = hom_to_int{...each{{i}=>narrowPair{select{as,i*2},select{as,i*2+1}}, iota{length{as}/2}}}
def hom_to_int{a:T,b:T=[k]E if k*2<=width{E}} = {
truncBits{k*2, fold_add{shrm{a,width{E}-k,b} & make{T, (1<<iota{k}) | (1<<(iota{k}+k))}}}
}
def andAllZero{x:T, y:T if nveci{T}} = ~any_bit{x&y}
def homMaskX{a:T=[k]E if E!=u64} = {
def hom_to_int_ext{a:T=[k]E if E!=u64} = {
def h = width{E}/2
tup{h, truncBits{k*h, extract{[1]u64~~shrn{el_m{T}~~a, h}, 0}}}
}

View File

@ -128,7 +128,7 @@ fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u
def exor64 = clmul{., sse{1<<64 - 2}, 0}
@for (xv in *V~~x, rv in *V~~r over i to cdiv{nw,vcount{V}}) {
x8 := xor8{xv}
hb := sse{topMask{[64]u8~~x8}}
hb := sse{top_to_int{[64]u8~~x8}}
xh := exor64{hb} # Exclusive xor of high bits (xh ^ hb for inclusive)
xc := xh ^ carry
v := x8 ^ V~~maskToHom{[64]u8, [64]u1~~extract{xc,0}}

View File

@ -23,7 +23,7 @@ def search{E, x, n:(u64), OP} = {
def VT = [bulk]E
def end = makeBranch{
tup{u64, ty_u{VT}},
{i,c} => return{i*bulk + promote{u64, ctzX{homMaskX{c}}}}
{i,c} => return{i*bulk + promote{u64, ctz_ext{hom_to_int_ext{c}}}}
}
@muLoop{bulk, tern{arch_defvw>=256, 1, 2}}(x in tup{VT,*E~~x}, M in 'm' over is to n) {
eq:= each{OP, x}
@ -95,13 +95,13 @@ def bittab_selector{loadtab} = {
top := hi4 + VI~~((arch_vec{u32}~~(x&~low))>>3)
byte:= sel{[16]i8, t0, hi4^top} | sel{[16]i8, t1, top}
mask:= sel{[16]i8, b, x & low}
homMask{(mask & byte) == mask}
hom_to_int{(mask & byte) == mask}
}
def selector{x:([16]i8) if hasarch{'AARCH64'}} = {
byte:= [16]u8~~sel{tup{t0,t1}, ty_u{x}>>3}
mask:= [16]u8**1 << ty_u{x & low}
res:= homMask{(mask & byte) == mask}
res:= hom_to_int{(mask & byte) == mask}
}
def reload{} = { tup{t0,t1} = loadtab{} }
@ -112,11 +112,11 @@ def readbytes{vtab}{} = {
def [k]_ = VI; def l = 128/k
def side{i} = {
def U = arch_vec{ty_u{k}}
def m = @collect (vtab over _ from i to i+l) homMask{vtab} # TODO multi-value homMask
def m = @collect (vtab over _ from i to i+l) hom_to_int{vtab} # TODO multi-value hom_to_int
VI~~make{U, if (vcount{U}>l) merge{m,m} else m}
}
def side{i if hasarch{'AARCH64'}} = {
def m = each{homMask, split{4, @collect (vtab over _ from i to i+l) vtab}}
def m = each{hom_to_int, split{4, @collect (vtab over _ from i to i+l) vtab}}
VI~~make{[2]u64, m}
}
each{side, l*iota{2}}
@ -220,8 +220,8 @@ def do_bittab{x0:(*void), n:(u64), tab:(*void), u:(u8), t, mode, r0} = {
# Filter out values equal to the previous, or first new
def pind = (iota{k}&15) - 1
prev:= make{VI, each{max{0, .}, pind}}
e:= ~homMask{v == VI**TI~~xi}
e&= base{2,pind<0} | ~homMask{v == sel{[16]i8, v, prev}}
e:= ~hom_to_int{v == VI**TI~~xi}
e&= base{2,pind<0} | ~hom_to_int{v == sel{[16]i8, v, prev}}
if (rbit) rv&= e | -m # Don't remove first bit
m&= e
while (m != 0) {

View File

@ -279,14 +279,14 @@ exportT{'si_select_tab', join{table{select_fn,
@maskedLoop{16}(cw0 in w, r in *u16~~r0, M in 'm' over i to wl) {
def cw = ty_u{wrapChk{cw0, xlf, M}}
def byte = shuf{[16]u8, xrev, cw>>3}
r = homMask{ty_s{byte << (cw & VU**7)} < VI**0}
r = hom_to_int{ty_s{byte << (cw & VU**7)} < VI**0}
}
} else {
if (wl>32 and xl<=16) {
xb:= shuf{[4]u64, spreadBits{[32]u8, load{*u32~~x0}}, 0,1,0,1}
@maskedLoop{32}(cw0 in w, sr in *u32~~r0, M in 'm' over wl) {
cw:= wrapChk{cw0, xlf, M}
sr = homMask{shuf{[16]i8, xb, cw}}
sr = hom_to_int{shuf{[16]i8, xb, cw}}
}
} else {
x:= shuf{[4]u64, load{*VI ~~ x0}, 0,1,0,1}
@ -296,7 +296,7 @@ exportT{'si_select_tab', join{table{select_fn,
cw:= wrapChk{cw0, xlf, M}
byte:= shuf{[16]i8, x, VI~~(([8]u32~~(cw&~low))>>3)}
mask:= shuf{[16]i8, b, cw & low}
sr = homMask{(mask & byte) == mask}
sr = hom_to_int{(mask & byte) == mask}
}
}
}

View File

@ -42,19 +42,19 @@ def rcpE{a:([4]f32)} = emit{[4]f32, '_mm_rcp_ps', a}
# mask stuff
def andAllZero{x:T, y:T if w128i{T}} = all_hom{(x & y) == T**0}
def topMask{x:T if w128{T, 8}} = emit{u16, '_mm_movemask_epi8', x}
def topMask{x:T if w128{T, 16}} = topMask{packs{[8]i16~~x, [8]i16**0}}
def topMask{x:T if w128{T, 32}} = emit{u8, '_mm_movemask_ps', v2f{x}}
def topMask{x:T if w128{T, 64}} = emit{u8, '_mm_movemask_pd', v2d{x}}
def homMask{x:T if w128{T}} = topMask{x}
def homMaskX{a:[_]T if width{T}==16} = tup{2, homMask{re_el{u8,a}}}
def homMask{a:T, b:T if w128i{T,16}} = homMask{packs{ty_s{a},ty_s{b}}}
def top_to_int{x:T if w128{T, 8}} = emit{u16, '_mm_movemask_epi8', x}
def top_to_int{x:T if w128{T, 16}} = top_to_int{packs{[8]i16~~x, [8]i16**0}}
def top_to_int{x:T if w128{T, 32}} = emit{u8, '_mm_movemask_ps', v2f{x}}
def top_to_int{x:T if w128{T, 64}} = emit{u8, '_mm_movemask_pd', v2d{x}}
def hom_to_int{x:T if w128{T}} = top_to_int{x}
def hom_to_int_ext{a:[_]T if width{T}==16} = tup{2, hom_to_int{re_el{u8,a}}}
def hom_to_int{a:T, b:T if w128i{T,16}} = hom_to_int{packs{ty_s{a},ty_s{b}}}
def any_hom{x:T if w128i{T}} = homMask{[16]u8 ~~ x} != 0
def all_hom{x:T if w128i{T}} = homMask{[16]u8 ~~ x} == 0xffff
def any_hom{x:T if w128i{T}} = hom_to_int{[16]u8 ~~ x} != 0
def all_hom{x:T if w128i{T}} = hom_to_int{[16]u8 ~~ x} == 0xffff
def any_top{x:T if w128i{T}} = topMask{x} != 0
def all_top{x:T=[k]_ if w128i{T}} = topMask{x} == (1<<k)-1
def any_top{x:T if w128i{T}} = top_to_int{x} != 0
def all_top{x:T=[k]_ if w128i{T}} = top_to_int{x} == (1<<k)-1
def any_top{x:T if w128i{T, 16}} = any_hom{[8]i16~~x < [8]i16**0}
def all_top{x:T if w128i{T, 16}} = all_hom{[8]i16~~x < [8]i16**0}