use tail{...} much more

This commit is contained in:
dzaima 2025-04-28 17:23:12 +03:00
parent d1f3efe8db
commit a5c6e3271c
15 changed files with 34 additions and 32 deletions

View File

@ -26,4 +26,4 @@ def any_hom{x:T if w256i{T} and elwidth{T}>=32} = hom_to_int{[8]u32 ~~ x} != 0
def all_hom{x:T if w256i{T} and elwidth{T}>=32} = hom_to_int{[8]u32 ~~ x} == 0xff
def any_top{x:T=[_]E if w256i{T} and width{E}>=32} = top_to_int{x} != 0
def all_top{x:T=[k]E if w256i{T} and width{E}>=32} = top_to_int{x} == (1<<k)-1
def all_top{x:T=[k]E if w256i{T} and width{E}>=32} = top_to_int{x} == tail{k}

View File

@ -25,7 +25,7 @@ def top_to_int{x:T if w256{T, 16}} = {
def any_hom{x:T if w256i{T}} = ~emit{u1, '_mm256_testz_si256', v2i{x}, v2i{x}}
def all_hom{x:T if w256i{T}} = hom_to_int{[32]u8 ~~ x} == 0xffff_ffff
def any_top{x:T if w256i{T}} = top_to_int{x} != 0
def all_top{x:T=[k]_ if w256i{T}} = top_to_int{x} == (1<<k)-1
def all_top{x:T=[k]_ if w256i{T}} = top_to_int{x} == tail{k}
def hom_to_int{a:T, b:T if w256i{T,16}} = hom_to_int{vec_shuffle{[4]u64, packs128{ty_s{a},ty_s{b}}, 0,2,1,3}}
def any_top{x:T if w256i{T,32}} = ~emit{u1, '_mm256_testz_ps', v2f{x}, v2f{x}}

View File

@ -97,9 +97,9 @@ def broadcast{n, v if knum{n}} = each{{_}=>v, range{n}}
# type stats
def minvalue{T if isunsigned{T}} = 0
def maxvalue{T if isunsigned{T}} = (1<<width{T})-1
def maxvalue{T if isunsigned{T}} = tail{width{T}}
def minvalue{T if issigned{T}} = - (1<<(width{T}-1))
def maxvalue{T if issigned{T}} = (1<<(width{T}-1))-1
def maxvalue{T if issigned{T}} = tail{width{T}-1}
# vector type checks (all disallow u1 element)
def genchk{W, F} = match {
@ -335,6 +335,8 @@ def tail{n} = (1<<n) - 1 # mask of the n least significant bits
def zlow{n,x} = (x >> n) << n # zero out n least significant bits
def tail{n,x} = x & tail{n} # get the n least significant bits
def bit {k,x} = x & (1<<k) # get the k-th bit
def tail{R,n:N if primt{R} and primt{N}} = { assert{n<width{R}}; ((R~~1)<<n) - 1 }
def tail{R,n if primt{R} and knum{n}} = R~~((1<<n) - 1)
def bit_lut{bits, idx if length{bits}<64 and all{(bits&1) == bits}} = ((base{2,bits} >> idx) & 1) != 0

View File

@ -148,7 +148,7 @@ def bitalign{{2,8,s}, 8, G if hasarch{'AVX512VBMI'}} = G{s, {a:V=[k](u8)} => {
def cyc = make{V, cycle{k, range{8}}}
def b = sel{V, a, cyc + muls{make{V, replicate{8, range{k/8}}}}}
def c = multishift{re_el{u64,b}, muls{cyc}}
c & V**cast_i{u8, tail{s}}
c & V**tail{u8,s}
}}
@ -162,7 +162,7 @@ def bitalign{{2,8,s}, 8, G if hasarch{'AARCH64'}} = G{s, {a:V==[16]u8} => {
def r0 = sel{[16]u8, a, shuf1} << shift1
def r1 = sel{[16]u8, a, shuf2} << shift2
(r0 | r1) & V**cast_i{u8, tail{s}}
(r0 | r1) & V**tail{u8,s}
}}
def bitalign{8, {2,8,d}, G if hasarch{'AARCH64'}} = {
@ -200,7 +200,7 @@ def bitalign{8, {2,8,d}, G if hasarch{'AARCH64'}} = {
def run{do_blend}{a:V==[16]u8} = {
def shuf1 = shuf0 + V**1
def shift1 = shift0 + [16]i8**cast_i{i8,d}
def b = a & V**cast_i{u8, tail{d}}
def b = a & V**tail{u8,d}
def r0 = sel{[16]u8, b, shuf0} << shift0
def r1 = sel{[16]u8, b, shuf1} << shift1
def r01 = r0 | r1

View File

@ -51,7 +51,7 @@ def store_bits{sz, x:(*u64), n:(ux), v} = match (sz) {
am:u64 = 64/sz
w:u64 = load{x,n/am}
sh:u64 = (n&(am-1)) * sz
w&= ~(u64~~tail{sz} << sh)
w&= ~(tail{u64,sz} << sh)
w|= (vc<<sh)
store{x, n/am, w}
}

View File

@ -79,11 +79,11 @@ fn count{T if T<=i16}(tab:*u16, ov:*u16, xp:*void, n:u64, min_allowed:T) : T = {
fn flush_counts(tab:*u16, ov:*u16, n:usz) : usz = {
def vl = arch_defvw/16
def V = [vl]u16
def bot = 1<<15 - 1
def bot = tail{15}
on:usz = 0
@for (t in *V~~tab over jv to cdiv{n, vl}) if (rare{any_top{t}}) {
o := if (hasarch{'X86_64'}) top_to_int{t} else hom_to_int{t > V**bot}
if (jv == n/vl) o &= type{o}~~1<<(n%vl) - 1
if (jv == n/vl) o &= tail{type{o}, n%vl}
while (o > 0) {
jv := jv*vl + cast_i{usz, ctz{o}}
store{tab, jv, load{tab, jv} & bot}

View File

@ -105,7 +105,7 @@ def extract_column_pow2{T, x0, r0, nv, k} = {
def f = tree_fold{unzip0{32}, .}
def proc{hx} = {
ri := D~~f{hx}
top := D**(1<<15); m := D**(1<<16 - 1)
top := D**(1<<15); m := D**tail{16}
(ri & m) | (D~~(ri&top == top) &~ m)
}
r = V~~packs128{each{proc, split{k/2, xs}}}
@ -128,7 +128,7 @@ def extract_column_modperm{x0, r0, nv, l, el, vl} = {
e := p2 + promote{ux,el} # Absorb into element size for most computation
l8 := cast_i{u8, l}
li := cast_i{u8, l + 2 * ((l-1) + (l&2))} # Inverse mod vl
elo:= V**(u8~~1<<e - 1)
elo:= V**tail{u8, e}
ie := iota{V} & elo
kmul := make{H, 2*iota{vl/2}} &~ H~~elo
def mu16{k} = {
@ -204,7 +204,7 @@ def extract_column_modperm{x0, r0, nv, l, el, vl} = {
def __shr{x:(V), sh:T if hasarch{'AARCH64'} and not vect{T}} = x << V**cast_i{u8,-sh}
def __shl{x:(V), sh:T if hasarch{'AARCH64'} and not vect{T}} = x << V**cast_i{u8, sh}
def uz_lane = {
l := V**(u8~~1<<el - 1)
l := V**tail{u8, el}
h := V**(16 - u8~~16>>p2)
dz := (il & l) | (h &~ il)>>(4-e) # low, high->middle
dz |= (il &~ (l | h))<<p2 # middle->high
@ -222,7 +222,7 @@ def extract_column_modperm{x0, r0, nv, l, el, vl} = {
shuf{[8]u32, ., cr}
}
# Run, writing every 1<<p2 steps
plo := usz~~1<<p2 - 1
plo := tail{usz, p2}
loop{{r, i, write} => {
def ra = add_res{r}
if ((plo &~ i) == 0) write{cross{uz_lane{ra}}}
@ -274,7 +274,7 @@ def fold_rows_bit_lt64{
# repeated shift-or-mask
def step{x, {m, sh}} = { def a=x&m; a|a>>sh }
def ss = 1<<iota{5}
def ms = (1<<64 - 1) / (1 + 1<<ss)
def ms = tail{64} / (1 + 1<<ss)
fold{step, ., each{tup, ms, ss}}
}
run_loop2{{op} => loop_T{u32, {x} => extract{op{x, x>>1}}}}
@ -351,7 +351,7 @@ def fold_rows_bit_lt64{
assert{l > 4}
ld:= l-1; lld:= l*ld
{mult0, _} := unaligned_spaced_mask_mod{ld}
mult0 &= u64~~1<<lld - 1
mult0 &= tail{u64, lld}
{mult1, _} := unaligned_spaced_mask_mod{lld}
ll:= l*l
{tk, tkd} := unaligned_spaced_mask_mod{ll}; tk <<= tkd
@ -359,7 +359,7 @@ def fold_rows_bit_lt64{
loop{tup{mult0,mult1}, tup{topk}}
} else {
{mult, _} := unaligned_spaced_mask_mod{l-1}
if (l==8) mult &= 1<<(7*8) - 1
if (l==8) mult &= tail{7*8}
loop{tup{mult}, tup{}}
}
}

View File

@ -57,7 +57,7 @@ def hash_alloc{logsz, msz, ext, Ts, v0s, has_radix, ordered} = {
each{memset{., ., sze}, ptrs, v0s}
def hash_resize{cc, m} = {
dif := sz*((1<<m)-1) # Number of elements of space to add
dif := sz*tail{m} # Number of elements of space to add
sh -= m; sz <<= m
set_thresh{}
cc = 0 # Collision counter

View File

@ -1,6 +1,6 @@
local def bit_mask_init{w} = {
apply{merge, each{{x} => {
merge{(w/8-1)**255, (1<<x)-1, (w/8)**0}
merge{(w/8-1)**255, tail{x}, (w/8)**0}
}, iota{8}}}
}
mask256_1:*u8 = bit_mask_init{256}; def mask_of_first_bits{T,n if width{T}==256} = load{*[32]u8 ~~ (mask256_1 + (n>>3)^31 + 64*(n&7))}
@ -37,7 +37,7 @@ def mask_first{n} = {
def mask{'count'} = n
def mask{{x}} = tup{mask{x}}
def mask{x:X if vect{X}} = x & (X~~mask_of_first{X,n})
def mask{x:X if any_int{x}} = x & ((1<<n) - 1)
def mask{x:X if any_int{x}} = tail{n, x}
def mask{0} = 1
}

View File

@ -650,7 +650,7 @@ def rep_const_bool_odd_mask4{
# Carry: shifting and word-crossing is done on the initial permuted x
# No need to carry across input words since they align with output words
# First bit of each word in xo below is wrong, but it doesn't matter!
mr := scal{u64~~1<<k - 1} # Mask out carry bit before output
mr := scal{tail{u64, k}} # Mask out carry bit before output
def sub_carry{a, c} = match (M) {
{[l](u64)} => {
ca := if (hasarch{'SSE4.2'} or hasarch{'AARCH64'}) { def S = [l]i64; S~~c > S**0 }

View File

@ -264,7 +264,7 @@ fn scan_neq{if hasarch{'AVX512BW', 'VPCLMULQDQ', 'GFNI'}}(init:u64, x:*u64, r:*u
def sse{a} = make{[2]u64, a, 0}
carry := sse{init}
# xor-scan on bytes
xmat := V**base{256, 1<<(8-iota{8}) - 1}
xmat := V**base{256, tail{8-iota{8}}}
def xor8 = emit{V, '_mm512_gf2p8affine_epi64_epi8', ., xmat, 0}
# Exclusive xor-scan on one word
def exor64 = clmul{., sse{1<<64 - 2}, 0}
@ -600,7 +600,7 @@ fn scan_rows_andor{id}(src:*u64, dst:*u64, nl:usz, l:usz) : void = {
i :usz = 0 # row bit index
wn:usz = 0 # starting word of next row
c:u64 = id # carry
def word{bit} = bit * ((1<<64) - 1)
def word{bit} = bit * tail{64}
we:= nl/64; while (wn < we) {
iw:= wn
r := res_m1{load{src, iw}, c, u64~~1 << (i%64)}

View File

@ -276,7 +276,7 @@ fn simd_member_u8(w0:*void, nw:u64, x0:*void, nx:u64, r0:*void, tab:*void) : voi
u = fill_bittab(w0, nw, tab, u, -1)
if (u == 0) { # All found!
@for (r in *u64~~r0 over cdiv{nx,64}) r = maxvalue{u64}
@for (r in *u64~~r0 over cdiv{nx,64}) r = tail{64}
} else {
bittab_lookup{x0, nx, r0, tab}
}

View File

@ -223,7 +223,7 @@ fn select_fn{rw, TI, TD}(w0:*void, x0:*void, r0:*void, wl:u64, xl:u64) : u1 = {
@for_masked{bulk}(cw0 in tup{VI,w}, sr in tup{'g',r}, M in 'm' over wl) {
cw:= wrapChk{cw0, xlf, M}
got:= gather{VD**0, x, cw, M}
if (TDE!=TD) got&= VD**((1<<wd)-1)
if (TDE!=TD) got&= VD**tail{wd}
sr{got}
}
} else {

View File

@ -23,7 +23,7 @@ itab:*u64 = maketab{8,8} # 256 elts, 2KB; shared by many methods
# Recover popcount, for when POPCNT isn't there
def has_popc = hasarch{'POPCNT'}
def tab_popc{i:I, w} = (i>>(width{I}-w) + 1) & (1<<w - 1)
def tab_popc{i:I, w} = tail{w, i>>(width{I}-w) + 1}
def popc_alt{v, i, w} = if (has_popc) popc{v} else tab_popc{i, w}
# slash{c, T} defines:
@ -289,7 +289,7 @@ def pext_popc{x:T, m:T} = {
tup{ x - (x>>1 & z0), zm + z0 }
} else if (hasarch{'AVX2'} and vect{T} and k >= 32) {
# We have variable shifts at these sizes
lh := scal{low_s*(1<<h - 1)}
lh := scal{low_s*tail{h}}
zl := z & lh
def S = re_el{ty_u{k}, T}
tup{T~~(S~~(x&~lh) >> S~~zl) | (x&lh), T~~(S~~z >> h) + zl}
@ -303,7 +303,7 @@ def pext_popc{x:T, m:T} = {
# Shift high x group down by low z, then add halves of z
odd:T = scal{low_s*(1<<k - 1<<h)} # Top half
ze := z&~odd
z1 := ze + scal{low_s*(1<<(k-1) - 1)} # z-1, as signed k-bit
z1 := ze + scal{low_s*tail{k-1}} # z-1, as signed k-bit
move := odd &~ (z1<<1) # Only groups where z>0 move
tup{
(x&~move) | shift{1, z1, x&move}>>1,
@ -314,7 +314,7 @@ def pext_popc{x:T, m:T} = {
# Compose k/g groups with k/g-1 regular shifts
def multi_shift{x, z, g, k, sc} = {
o := z * sc{lowbits{k,g}} # Offsets by prefix sum
def s = 1<<g - 1
def s = tail{g}
def s0 = sc{s}
def oo{sh} = if (sh==g) z else o>>(sh-g) # Offset for group
def gr{sh} = (x & sc{s<<sh}) >> (oo{sh} & s0) # Shifted group
@ -341,7 +341,7 @@ def pext_popc{x0:V, m0:V if hasarch{'PCLMUL'} and V==[2]u64} = {
m := m0
x := x0 & m
d := ~m << 1 # One bit of the position difference at x
c := V**(1<<64-1)
c := V**tail{64}
@unroll (i to lb{scalwidth{V}}) {
def sh = 1 << i
def shift_at{v, s} = { v = (v&~s) | (v&s)>>sh }
@ -400,6 +400,6 @@ export{'si_thresh_compress_bool', u64~~thresh_bool{}}
# Mask i is the smallest possible mask containing 1 every i bits:
# there would be a 1 just past the top bit.
# The number of trailing zeros is 64%i , and the popcount is 64/i .
def get_spaced_masks{i} = (1<<64 - 1<<(64%i)) / (1<<i - 1)
def get_spaced_masks{i} = (1<<64 - 1<<(64%i)) / tail{i}
spaced_masks:*u64 = get_spaced_masks{1 + iota{64}}
export{'si_spaced_masks', spaced_masks}

View File

@ -37,7 +37,7 @@ def any_hom{x:T if w128i{T}} = hom_to_int{[16]u8 ~~ x} != 0
def all_hom{x:T if w128i{T}} = hom_to_int{[16]u8 ~~ x} == 0xffff
def any_top{x:T if w128i{T}} = top_to_int{x} != 0
def all_top{x:T=[k]_ if w128i{T}} = top_to_int{x} == (1<<k)-1
def all_top{x:T=[k]_ if w128i{T}} = top_to_int{x} == tail{k}
def any_top{x:T if w128i{T, 16}} = any_hom{[8]i16~~x < [8]i16**0}
def all_top{x:T if w128i{T, 16}} = all_hom{[8]i16~~x < [8]i16**0}