use tail{...} much more
This commit is contained in:
parent
d1f3efe8db
commit
a5c6e3271c
@ -26,4 +26,4 @@ def any_hom{x:T if w256i{T} and elwidth{T}>=32} = hom_to_int{[8]u32 ~~ x} != 0
|
||||
def all_hom{x:T if w256i{T} and elwidth{T}>=32} = hom_to_int{[8]u32 ~~ x} == 0xff
|
||||
|
||||
def any_top{x:T=[_]E if w256i{T} and width{E}>=32} = top_to_int{x} != 0
|
||||
def all_top{x:T=[k]E if w256i{T} and width{E}>=32} = top_to_int{x} == (1<<k)-1
|
||||
def all_top{x:T=[k]E if w256i{T} and width{E}>=32} = top_to_int{x} == tail{k}
|
||||
|
||||
@ -25,7 +25,7 @@ def top_to_int{x:T if w256{T, 16}} = {
|
||||
def any_hom{x:T if w256i{T}} = ~emit{u1, '_mm256_testz_si256', v2i{x}, v2i{x}}
|
||||
def all_hom{x:T if w256i{T}} = hom_to_int{[32]u8 ~~ x} == 0xffff_ffff
|
||||
def any_top{x:T if w256i{T}} = top_to_int{x} != 0
|
||||
def all_top{x:T=[k]_ if w256i{T}} = top_to_int{x} == (1<<k)-1
|
||||
def all_top{x:T=[k]_ if w256i{T}} = top_to_int{x} == tail{k}
|
||||
def hom_to_int{a:T, b:T if w256i{T,16}} = hom_to_int{vec_shuffle{[4]u64, packs128{ty_s{a},ty_s{b}}, 0,2,1,3}}
|
||||
|
||||
def any_top{x:T if w256i{T,32}} = ~emit{u1, '_mm256_testz_ps', v2f{x}, v2f{x}}
|
||||
|
||||
@ -97,9 +97,9 @@ def broadcast{n, v if knum{n}} = each{{_}=>v, range{n}}
|
||||
|
||||
# type stats
|
||||
def minvalue{T if isunsigned{T}} = 0
|
||||
def maxvalue{T if isunsigned{T}} = (1<<width{T})-1
|
||||
def maxvalue{T if isunsigned{T}} = tail{width{T}}
|
||||
def minvalue{T if issigned{T}} = - (1<<(width{T}-1))
|
||||
def maxvalue{T if issigned{T}} = (1<<(width{T}-1))-1
|
||||
def maxvalue{T if issigned{T}} = tail{width{T}-1}
|
||||
|
||||
# vector type checks (all disallow u1 element)
|
||||
def genchk{W, F} = match {
|
||||
@ -335,6 +335,8 @@ def tail{n} = (1<<n) - 1 # mask of the n least significant bits
|
||||
def zlow{n,x} = (x >> n) << n # zero out n least significant bits
|
||||
def tail{n,x} = x & tail{n} # get the n least significant bits
|
||||
def bit {k,x} = x & (1<<k) # get the k-th bit
|
||||
def tail{R,n:N if primt{R} and primt{N}} = { assert{n<width{R}}; ((R~~1)<<n) - 1 }
|
||||
def tail{R,n if primt{R} and knum{n}} = R~~((1<<n) - 1)
|
||||
|
||||
def bit_lut{bits, idx if length{bits}<64 and all{(bits&1) == bits}} = ((base{2,bits} >> idx) & 1) != 0
|
||||
|
||||
|
||||
@ -148,7 +148,7 @@ def bitalign{{2,8,s}, 8, G if hasarch{'AVX512VBMI'}} = G{s, {a:V=[k](u8)} => {
|
||||
def cyc = make{V, cycle{k, range{8}}}
|
||||
def b = sel{V, a, cyc + muls{make{V, replicate{8, range{k/8}}}}}
|
||||
def c = multishift{re_el{u64,b}, muls{cyc}}
|
||||
c & V**cast_i{u8, tail{s}}
|
||||
c & V**tail{u8,s}
|
||||
}}
|
||||
|
||||
|
||||
@ -162,7 +162,7 @@ def bitalign{{2,8,s}, 8, G if hasarch{'AARCH64'}} = G{s, {a:V==[16]u8} => {
|
||||
|
||||
def r0 = sel{[16]u8, a, shuf1} << shift1
|
||||
def r1 = sel{[16]u8, a, shuf2} << shift2
|
||||
(r0 | r1) & V**cast_i{u8, tail{s}}
|
||||
(r0 | r1) & V**tail{u8,s}
|
||||
}}
|
||||
|
||||
def bitalign{8, {2,8,d}, G if hasarch{'AARCH64'}} = {
|
||||
@ -200,7 +200,7 @@ def bitalign{8, {2,8,d}, G if hasarch{'AARCH64'}} = {
|
||||
def run{do_blend}{a:V==[16]u8} = {
|
||||
def shuf1 = shuf0 + V**1
|
||||
def shift1 = shift0 + [16]i8**cast_i{i8,d}
|
||||
def b = a & V**cast_i{u8, tail{d}}
|
||||
def b = a & V**tail{u8,d}
|
||||
def r0 = sel{[16]u8, b, shuf0} << shift0
|
||||
def r1 = sel{[16]u8, b, shuf1} << shift1
|
||||
def r01 = r0 | r1
|
||||
|
||||
@ -51,7 +51,7 @@ def store_bits{sz, x:(*u64), n:(ux), v} = match (sz) {
|
||||
am:u64 = 64/sz
|
||||
w:u64 = load{x,n/am}
|
||||
sh:u64 = (n&(am-1)) * sz
|
||||
w&= ~(u64~~tail{sz} << sh)
|
||||
w&= ~(tail{u64,sz} << sh)
|
||||
w|= (vc<<sh)
|
||||
store{x, n/am, w}
|
||||
}
|
||||
|
||||
@ -79,11 +79,11 @@ fn count{T if T<=i16}(tab:*u16, ov:*u16, xp:*void, n:u64, min_allowed:T) : T = {
|
||||
fn flush_counts(tab:*u16, ov:*u16, n:usz) : usz = {
|
||||
def vl = arch_defvw/16
|
||||
def V = [vl]u16
|
||||
def bot = 1<<15 - 1
|
||||
def bot = tail{15}
|
||||
on:usz = 0
|
||||
@for (t in *V~~tab over jv to cdiv{n, vl}) if (rare{any_top{t}}) {
|
||||
o := if (hasarch{'X86_64'}) top_to_int{t} else hom_to_int{t > V**bot}
|
||||
if (jv == n/vl) o &= type{o}~~1<<(n%vl) - 1
|
||||
if (jv == n/vl) o &= tail{type{o}, n%vl}
|
||||
while (o > 0) {
|
||||
jv := jv*vl + cast_i{usz, ctz{o}}
|
||||
store{tab, jv, load{tab, jv} & bot}
|
||||
|
||||
@ -105,7 +105,7 @@ def extract_column_pow2{T, x0, r0, nv, k} = {
|
||||
def f = tree_fold{unzip0{32}, .}
|
||||
def proc{hx} = {
|
||||
ri := D~~f{hx}
|
||||
top := D**(1<<15); m := D**(1<<16 - 1)
|
||||
top := D**(1<<15); m := D**tail{16}
|
||||
(ri & m) | (D~~(ri&top == top) &~ m)
|
||||
}
|
||||
r = V~~packs128{each{proc, split{k/2, xs}}}
|
||||
@ -128,7 +128,7 @@ def extract_column_modperm{x0, r0, nv, l, el, vl} = {
|
||||
e := p2 + promote{ux,el} # Absorb into element size for most computation
|
||||
l8 := cast_i{u8, l}
|
||||
li := cast_i{u8, l + 2 * ((l-1) + (l&2))} # Inverse mod vl
|
||||
elo:= V**(u8~~1<<e - 1)
|
||||
elo:= V**tail{u8, e}
|
||||
ie := iota{V} & elo
|
||||
kmul := make{H, 2*iota{vl/2}} &~ H~~elo
|
||||
def mu16{k} = {
|
||||
@ -204,7 +204,7 @@ def extract_column_modperm{x0, r0, nv, l, el, vl} = {
|
||||
def __shr{x:(V), sh:T if hasarch{'AARCH64'} and not vect{T}} = x << V**cast_i{u8,-sh}
|
||||
def __shl{x:(V), sh:T if hasarch{'AARCH64'} and not vect{T}} = x << V**cast_i{u8, sh}
|
||||
def uz_lane = {
|
||||
l := V**(u8~~1<<el - 1)
|
||||
l := V**tail{u8, el}
|
||||
h := V**(16 - u8~~16>>p2)
|
||||
dz := (il & l) | (h &~ il)>>(4-e) # low, high->middle
|
||||
dz |= (il &~ (l | h))<<p2 # middle->high
|
||||
@ -222,7 +222,7 @@ def extract_column_modperm{x0, r0, nv, l, el, vl} = {
|
||||
shuf{[8]u32, ., cr}
|
||||
}
|
||||
# Run, writing every 1<<p2 steps
|
||||
plo := usz~~1<<p2 - 1
|
||||
plo := tail{usz, p2}
|
||||
loop{{r, i, write} => {
|
||||
def ra = add_res{r}
|
||||
if ((plo &~ i) == 0) write{cross{uz_lane{ra}}}
|
||||
@ -274,7 +274,7 @@ def fold_rows_bit_lt64{
|
||||
# repeated shift-or-mask
|
||||
def step{x, {m, sh}} = { def a=x&m; a|a>>sh }
|
||||
def ss = 1<<iota{5}
|
||||
def ms = (1<<64 - 1) / (1 + 1<<ss)
|
||||
def ms = tail{64} / (1 + 1<<ss)
|
||||
fold{step, ., each{tup, ms, ss}}
|
||||
}
|
||||
run_loop2{{op} => loop_T{u32, {x} => extract{op{x, x>>1}}}}
|
||||
@ -351,7 +351,7 @@ def fold_rows_bit_lt64{
|
||||
assert{l > 4}
|
||||
ld:= l-1; lld:= l*ld
|
||||
{mult0, _} := unaligned_spaced_mask_mod{ld}
|
||||
mult0 &= u64~~1<<lld - 1
|
||||
mult0 &= tail{u64, lld}
|
||||
{mult1, _} := unaligned_spaced_mask_mod{lld}
|
||||
ll:= l*l
|
||||
{tk, tkd} := unaligned_spaced_mask_mod{ll}; tk <<= tkd
|
||||
@ -359,7 +359,7 @@ def fold_rows_bit_lt64{
|
||||
loop{tup{mult0,mult1}, tup{topk}}
|
||||
} else {
|
||||
{mult, _} := unaligned_spaced_mask_mod{l-1}
|
||||
if (l==8) mult &= 1<<(7*8) - 1
|
||||
if (l==8) mult &= tail{7*8}
|
||||
loop{tup{mult}, tup{}}
|
||||
}
|
||||
}
|
||||
|
||||
@ -57,7 +57,7 @@ def hash_alloc{logsz, msz, ext, Ts, v0s, has_radix, ordered} = {
|
||||
each{memset{., ., sze}, ptrs, v0s}
|
||||
|
||||
def hash_resize{cc, m} = {
|
||||
dif := sz*((1<<m)-1) # Number of elements of space to add
|
||||
dif := sz*tail{m} # Number of elements of space to add
|
||||
sh -= m; sz <<= m
|
||||
set_thresh{}
|
||||
cc = 0 # Collision counter
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
local def bit_mask_init{w} = {
|
||||
apply{merge, each{{x} => {
|
||||
merge{(w/8-1)**255, (1<<x)-1, (w/8)**0}
|
||||
merge{(w/8-1)**255, tail{x}, (w/8)**0}
|
||||
}, iota{8}}}
|
||||
}
|
||||
mask256_1:*u8 = bit_mask_init{256}; def mask_of_first_bits{T,n if width{T}==256} = load{*[32]u8 ~~ (mask256_1 + (n>>3)^31 + 64*(n&7))}
|
||||
@ -37,7 +37,7 @@ def mask_first{n} = {
|
||||
def mask{'count'} = n
|
||||
def mask{{x}} = tup{mask{x}}
|
||||
def mask{x:X if vect{X}} = x & (X~~mask_of_first{X,n})
|
||||
def mask{x:X if any_int{x}} = x & ((1<<n) - 1)
|
||||
def mask{x:X if any_int{x}} = tail{n, x}
|
||||
def mask{0} = 1
|
||||
}
|
||||
|
||||
|
||||
@ -650,7 +650,7 @@ def rep_const_bool_odd_mask4{
|
||||
# Carry: shifting and word-crossing is done on the initial permuted x
|
||||
# No need to carry across input words since they align with output words
|
||||
# First bit of each word in xo below is wrong, but it doesn't matter!
|
||||
mr := scal{u64~~1<<k - 1} # Mask out carry bit before output
|
||||
mr := scal{tail{u64, k}} # Mask out carry bit before output
|
||||
def sub_carry{a, c} = match (M) {
|
||||
{[l](u64)} => {
|
||||
ca := if (hasarch{'SSE4.2'} or hasarch{'AARCH64'}) { def S = [l]i64; S~~c > S**0 }
|
||||
|
||||
@ -264,7 +264,7 @@ fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u
|
||||
def sse{a} = make{[2]u64, a, 0}
|
||||
carry := sse{init}
|
||||
# xor-scan on bytes
|
||||
xmat := V**base{256, 1<<(8-iota{8}) - 1}
|
||||
xmat := V**base{256, tail{8-iota{8}}}
|
||||
def xor8 = emit{V, '_mm512_gf2p8affine_epi64_epi8', ., xmat, 0}
|
||||
# Exclusive xor-scan on one word
|
||||
def exor64 = clmul{., sse{1<<64 - 2}, 0}
|
||||
@ -600,7 +600,7 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = {
|
||||
i :usz = 0 # row bit index
|
||||
wn:usz = 0 # starting word of next row
|
||||
c:u64 = id # carry
|
||||
def word{bit} = bit * ((1<<64) - 1)
|
||||
def word{bit} = bit * tail{64}
|
||||
we:= nl/64; while (wn < we) {
|
||||
iw:= wn
|
||||
r := res_m1{load{src, iw}, c, u64~~1 << (i%64)}
|
||||
|
||||
@ -276,7 +276,7 @@ fn simd_member_u8(w0:*void, nw:u64, x0:*void, nx:u64, r0:*void, tab:*void) : voi
|
||||
u = fill_bittab(w0, nw, tab, u, -1)
|
||||
|
||||
if (u == 0) { # All found!
|
||||
@for (r in *u64~~r0 over cdiv{nx,64}) r = maxvalue{u64}
|
||||
@for (r in *u64~~r0 over cdiv{nx,64}) r = tail{64}
|
||||
} else {
|
||||
bittab_lookup{x0, nx, r0, tab}
|
||||
}
|
||||
|
||||
@ -223,7 +223,7 @@ fn select_fn{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = {
|
||||
@for_masked{bulk}(cw0 in tup{VI,w}, sr in tup{'g',r}, M in 'm' over wl) {
|
||||
cw:= wrapChk{cw0, xlf, M}
|
||||
got:= gather{VD**0, x, cw, M}
|
||||
if (TDE!=TD) got&= VD**((1<<wd)-1)
|
||||
if (TDE!=TD) got&= VD**tail{wd}
|
||||
sr{got}
|
||||
}
|
||||
} else {
|
||||
|
||||
@ -23,7 +23,7 @@ itab:*u64 = maketab{8,8} # 256 elts, 2KB; shared by many methods
|
||||
|
||||
# Recover popcount, for when POPCNT isn't there
|
||||
def has_popc = hasarch{'POPCNT'}
|
||||
def tab_popc{i:I, w} = (i>>(width{I}-w) + 1) & (1<<w - 1)
|
||||
def tab_popc{i:I, w} = tail{w, i>>(width{I}-w) + 1}
|
||||
def popc_alt{v, i, w} = if (has_popc) popc{v} else tab_popc{i, w}
|
||||
|
||||
# slash{c, T} defines:
|
||||
@ -289,7 +289,7 @@ def pext_popc{x:T, m:T} = {
|
||||
tup{ x - (x>>1 & z0), zm + z0 }
|
||||
} else if (hasarch{'AVX2'} and vect{T} and k >= 32) {
|
||||
# We have variable shifts at these sizes
|
||||
lh := scal{low_s*(1<<h - 1)}
|
||||
lh := scal{low_s*tail{h}}
|
||||
zl := z & lh
|
||||
def S = re_el{ty_u{k}, T}
|
||||
tup{T~~(S~~(x&~lh) >> S~~zl) | (x&lh), T~~(S~~z >> h) + zl}
|
||||
@ -303,7 +303,7 @@ def pext_popc{x:T, m:T} = {
|
||||
# Shift high x group down by low z, then add halves of z
|
||||
odd:T = scal{low_s*(1<<k - 1<<h)} # Top half
|
||||
ze := z&~odd
|
||||
z1 := ze + scal{low_s*(1<<(k-1) - 1)} # z-1, as signed k-bit
|
||||
z1 := ze + scal{low_s*tail{k-1}} # z-1, as signed k-bit
|
||||
move := odd &~ (z1<<1) # Only groups where z>0 move
|
||||
tup{
|
||||
(x&~move) | shift{1, z1, x&move}>>1,
|
||||
@ -314,7 +314,7 @@ def pext_popc{x:T, m:T} = {
|
||||
# Compose k/g groups with k/g-1 regular shifts
|
||||
def multi_shift{x, z, g, k, sc} = {
|
||||
o := z * sc{lowbits{k,g}} # Offsets by prefix sum
|
||||
def s = 1<<g - 1
|
||||
def s = tail{g}
|
||||
def s0 = sc{s}
|
||||
def oo{sh} = if (sh==g) z else o>>(sh-g) # Offset for group
|
||||
def gr{sh} = (x & sc{s<<sh}) >> (oo{sh} & s0) # Shifted group
|
||||
@ -341,7 +341,7 @@ def pext_popc{x0:V, m0:V if hasarch{'PCLMUL'} and V==[2]u64} = {
|
||||
m := m0
|
||||
x := x0 & m
|
||||
d := ~m << 1 # One bit of the position difference at x
|
||||
c := V**(1<<64-1)
|
||||
c := V**tail{64}
|
||||
@unroll (i to lb{scalwidth{V}}) {
|
||||
def sh = 1 << i
|
||||
def shift_at{v, s} = { v = (v&~s) | (v&s)>>sh }
|
||||
@ -400,6 +400,6 @@ export{'si_thresh_compress_bool', u64~~thresh_bool{}}
|
||||
# Mask i is the smallest possible mask containing 1 every i bits:
|
||||
# there would be a 1 just past the top bit.
|
||||
# The number of trailing zeros is 64%i , and the popcount is 64/i .
|
||||
def get_spaced_masks{i} = (1<<64 - 1<<(64%i)) / (1<<i - 1)
|
||||
def get_spaced_masks{i} = (1<<64 - 1<<(64%i)) / tail{i}
|
||||
spaced_masks:*u64 = get_spaced_masks{1 + iota{64}}
|
||||
export{'si_spaced_masks', spaced_masks}
|
||||
|
||||
@ -37,7 +37,7 @@ def any_hom{x:T if w128i{T}} = hom_to_int{[16]u8 ~~ x} != 0
|
||||
def all_hom{x:T if w128i{T}} = hom_to_int{[16]u8 ~~ x} == 0xffff
|
||||
|
||||
def any_top{x:T if w128i{T}} = top_to_int{x} != 0
|
||||
def all_top{x:T=[k]_ if w128i{T}} = top_to_int{x} == (1<<k)-1
|
||||
def all_top{x:T=[k]_ if w128i{T}} = top_to_int{x} == tail{k}
|
||||
def any_top{x:T if w128i{T, 16}} = any_hom{[8]i16~~x < [8]i16**0}
|
||||
def all_top{x:T if w128i{T, 16}} = all_hom{[8]i16~~x < [8]i16**0}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user