uCBQN/src/singeli/src/lut.singeli
2024-07-28 02:47:12 +03:00

198 lines
8.5 KiB
Plaintext

# `a <<{u16} b`: shift a left by b while viewing it as u16 elements, then cast back to a's own type
# the shift happens on 16-bit lanes, so bits can move from the low byte into the high byte of each pair
# for x86's lack of u8 shift
def __shl{(u16)}{a:T, b} = {
  def wide = re_el{u16, a}
  T~~(wide << b)
}
# broadcast case for a target vector type with twice as many elements as x: the result is pair{x, x}
def broadcast{[(n*2)]E, x:[n]E} = pair{x, x}
# 1<<ceil_log2{v}, but never less than least
# least ⌈ ⌈⌾(2⊸⋆⁼) v
def pow2_up{v, least} = {
  def rounded = 1 << ceil_log2{v}
  max{least, rounded}
}
# make a LUT of at least nt elements in tab, to be indexed by [ni_real≥ni]u8
# E must be unsigned
# mode is a hint on expected usage:
# mode=='i': same table is reused for many index batches (list ⊏ list)
# mode=='c': same index is reused across many LUTs (list⊸⊏˘ mat)
# mode=='o': neither table nor indices are reused (mat ⊏˘ mat)
# returns {nt_real≥nt, ni_real≥ni, {tab:*E} => {is:[ni_real]u8} => (tuple of vectors, totalling to ni_real elements)}
# this catch-all definition returns 0 instead of such a tuple; the more specific lut_gen definitions further down the file provide the actual implementations
def lut_gen{mode, E, nt, ni} = 0
# wrap G so that the table can be handed over in any of three forms; proc normalizes each form into a
# "table generator" TG, which is what G receives:
#   TG{V, i}            - the i-th vector of type V of the table
#   TG{'offset', E, en} - a new TG that starts en elements of type E further into the table
local def loader{G} = {
# table given as a pointer: vectors are loaded from memory
def proc{mem:*_} = {
def loader_mtg{V, i} = load{*V~~mem, i}
def loader_mtg{'offset', E, en} = proc{en + *E~~mem} # advance the pointer by en elements of E and wrap it again
}
# table is already a generator: used as-is
def proc{TG if kgen{TG}} = TG
# table given as a tuple of vectors
def proc{vs if ktup{vs}} = {
def S = oneType{vs}
def loader_vtg{Q, i} = {
if (width{Q} == width{S}) Q~~select{vs,i} # same width: the i-th vector, reinterpreted as Q
else if (width{Q}*2 == width{S}) Q~~half{select{vs, i>>1}, i&1} # Q is half as wide: half number i&1 of vector i>>1
else assert{0, S, Q} # any other width ratio is unsupported
}
def loader_vtg{'offset', E, en} = {
def off = (width{E} * en) / width{S} # the offset, measured in vectors of type S
assert{off != length{vs}, vs, E, en} # NOTE(review): only off==length{vs} is rejected; confirm that larger offsets cannot occur
if (off==0.5) proc{tup{half{select{vs,0},1}}} # half a vector in: continue with half 1 of the first vector
else proc{slice{vs, off}}
}
}
# entry point handed back to callers: normalize the arguments with proc, then pass the result to G
def load_accepter{...vs} = G{proc{...vs}}
}
# read the table through TG and unzip it, returning the per-half results as tuples of vectors
# when n elements of E fit within arch_defvw bits, two [n]E vectors are read and unzipped directly;
# otherwise the work is split into two halves (the second one starting n elements of E further in) and their results are merged
def unzip_load{E, n, TG} = {
  if (width{E}*n <= arch_defvw) {
    each{tup, unzip{TG{[n]E, 0}, TG{[n]E, 1}}}
  } else {
    def h = n/2
    def front = unzip_load{E, h, TG}
    def back = unzip_load{E, h, TG{'offset', E, n}}
    each{merge, front, back}
  }
}
# widen a [16]u8 index vector into a tuple of two [8]u32 vectors: one from is itself,
# one from is with its 32-bit groups rearranged by 4b3232 (presumably bringing elements 8-15 to the front - TODO confirm)
# NOTE(review): the first parameter is a plain binding named u32 rather than a (u32) match; visible callers pass u32 - confirm this is intended
# compiler will deduplicate all the repeated calls of this on the same is
def widen_tup{u32, is:([16]u8)} = {
  def first = widen{[8]u32, is}
  def moved = shuf{[4]u32, is, 4b3232}
  def second = widen{[8]u32, moved}
  tup{first, second}
}
# LUT over nt elements made from two lookups, one per half of the table, blended per element
# asks lut_gen for an implementation over nth=nt/2 elements (the destructuring pattern expects it to report exactly nth and ni),
# instantiates it on the table as given (lo) and on the table offset by nth elements (hi),
# and picks between the two results with topBlend, using index bit lb{nth} moved to the top bit of each mask element
# returns tup{nt, ni, loader-wrapped lookup}
def blend_halves{mode, E, nt, ni} = tup{nt, ni, loader{{TG} => {
def nth = nt/2
def {(nth), (ni), prev} = lut_gen{mode, E, nth, ni}
def lo = prev{TG}
def hi = prev{TG{'offset', E, nth}}
# 'raw' and 'xor' are extra entry points alongside the index lookups below; their only visible users are in the commented-out code
def me{'raw'} = tup{lo, hi}
def me{'xor', it} = each{{a,b} => a{'xor',b}, me{'raw'}, it{'raw'}}
# def me{'is', is:[_](u8)} = lo{'is', is}
# def me{...is} = {
# def [_]IE = oneType{is}
# def shl = if (IE==u8) __shl{u16} else __shl
# def bm = shl{is, width{IE}-1 - lb{nth}}
# each{{l,h} => topBlend{l,h,bm}, lo{is}, hi{is}}
# }
# u32 elements on x86-64: the blend masks are the indices widened by widen_tup, with bit lb{nth} shifted up to bit 31
def me{is:[_](u8) if hasarch{'X86_64'} and E==u32} = {
each{{l,h,m} => topBlend{l,h,m}, lo{is}, hi{is}, each{{c} => c << (31 - lb{nth}), widen_tup{E, is}}}
}
# u8 elements on x86-64: one mask, the indices with bit lb{nth} shifted up to bit 7 (via the u16 shift)
def me{is:[_](u8) if hasarch{'X86_64'} and E==u8} = {
def bm = is <<{u16} (7 - lb{nth}) # TODO for outermost bit could do a cmpgt, increasing port diversity
each{{l,h} => topBlend{l,h,bm}, lo{is}, hi{is}}
}
# def me{is:[_](u8) if hasarch{'AARCH64'} and E==u8} = { # only for one outermost blend
# def end = type{is}**(nth-1)
# def bm = is > end
# each{{l,h} => homBlend{l,h,bm}, lo{is & end}, hi{is & end}}
# }
# TODO xor-ing could still be worth it for lower repeated levels where the index transformation can be deduped; and outermost can do a cmpgt
# if (mode=='c') hi{'xor', lo}
# def me{is:I=[_](u8) if mode=='c'} = {
# def bm = (is & I**lb{nth})
# each{{l,h} => topBlend{l,h,bm}, lo{is}, hi{is}}
# }
}}}
# expand the first k indices of x into indices for a table of elements sc times narrower, sc = width{D}/width{S}:
# every index i turns into the sc consecutive values i*sc, i*sc+1, …, i*sc+sc-1
# the result has type [k*sc]S
def raw_widen_inds{[k]D, x:[k0]S if k0>=k} = { # : [k*sc]S
  def sc = width{D} / width{S}
  def RV = [k*sc]S
  def add = make{RV, range{k*sc} % sc} # 0,1,…,sc-1 repeating: position inside each expanded index
  if (hasarch{'AARCH64'} and [k]D == [2]u64 and S==u8) {
    def x16 = undefPromote{[16]u8, x}
    # repeat each of the first two bytes 8 times, then scale by sc==8 (a shift by 3)
    (RV~~sel{[16]u8, x16, make{[16]u8, range{16}>>3}}<<3) + add
  } else if (hasarch{'AVX2'} and [k]D == [4]u64 and S==u32) {
    # repeat each of the first four elements twice, then scale by sc==2
    # the shift amount has to be lb{sc} (==1): shifting by sc itself would scale by 4,
    # disagreeing with the other two branches, which both scale by sc
    (sel{[8]u32, undefPromote{[8]u32, x}, make{[8]u32, range{8}>>1}}<<lb{sc}) + add
  } else {
    def wd = widen{[k]D, x}
    # presumably: multiplying by a constant that has the digit sc in every S-sized position places x*sc into each of the sc sub-elements
    re_el{S, wd * [k]D**base{1<<width{S}, sc**sc}} + add
  }
}
# shorthand for u8 indices: the widened element type is the unsigned type of 8<<sc bits
# NOTE(review): sc acts as a shift amount here (8<<sc), whereas widen_inds treats its sc as a ratio (8*sc) - confirm which callers use this form
def raw_widen_inds{k, sc, x:[_](u8)} = raw_widen_inds{[k]primtype{'u', 8<<sc}, x}
# LUT for E built on top of a LUT for elements sc times narrower: lut_gen is asked for nt0*sc narrow elements and ni0*sc indices,
# every incoming index is expanded by raw_widen_inds, and the narrow results are reinterpreted as E
# if lut_gen returns anything other than a 3-element tuple (such as 0), that result is passed through unchanged
def widen_inds{mode, E, nt0, ni0, sc} = match(lut_gen{mode, primtype{'u',width{E}/sc}, nt0*sc, ni0*sc}) { # e.g. sc==2: {a,b,c,d}[w,x,y,z] → {a0,a1, b0,b1, c0,c1, d0,d1}[w*2,w*2+1, x*2,x*2+1, y*2,y*2+1, z*2,z*2+1]
# sizes reported by the narrow LUT are divided by sc to express them in elements of E
{{nt1, ni1, G}} => tup{nt1/sc, ni1/sc, loader{{TG} => {
def prev = G{TG}
def ni = ni1/sc
def WV = [ni]primtype{'u', 8*sc} # vector type with ni elements of 8*sc bits, passed to raw_widen_inds as the widening target
{is:([ni]u8)} => {
each{re_el{E,.}, prev{raw_widen_inds{WV, is}}}
}
}}}
{x} => x
}
# LUT for E built from two LUTs over the half-width type w_h{E}: unzip_load splits the table into two half-width tables,
# both are looked up with the same indices, and the two results are zipped back into elements of E
# if lut_gen returns anything other than a 3-element tuple (such as 0), that result is passed through unchanged
def zip_halves{mode, E, nt, ni} = match(lut_gen{mode, w_h{E}, nt, ni}) { # e.g. {a,b,c,d}[w,x,y,z] → zip({a0,b0,c0,d0}[w,x,y,z], {a1,b1,c1,d1}[w,x,y,z])
# note: nt and ni are rebound here to the sizes reported by the half-width LUT
{{nt, ni, G}} => tup{nt, ni, loader{{TG} => {
# show{E, '→', w_h{E}}
def d = unzip_load{w_h{E}, nt, TG}
# lprintf{tup{'x0',d}}
def prevs = each{G, d} # one half-width LUT per unzipped table
# look up is in both half-width LUTs, combine corresponding result vectors with zipper, and join everything into one tuple
def run_zip{zipper, is} = {
def {lo, hi} = each{{prev}=>prev{is}, prevs}
join{flip{each{zipper, lo, hi}}}
}
# general case
def me{is:([ni]u8)} = {
run_zip{mzip, is}
}
# the cases below permute the indices first and then zip with mzip128 instead of mzip
# (presumably the permutation compensates for mzip128 zipping within 128-bit lanes - TODO confirm)
def me{is:([ni]u8) if hasarch{'AARCH64'} and E==u32 and ni==16} = {
def is2 = sel{[16]u8, is, make{[16]u8, 0,1,2,3, 8,9,10,11, 4,5,6,7, 12,13,14,15}}
run_zip{mzip128, is2}
}
def me{is:([ni]u8) if hasarch{'AARCH64'} and E==u64 and ni==16} = {
def is2 = sel{[16]u8, is, make{[16]u8, 0,1,8,9,2,3,10,11,4,5,12,13,6,7,14,15}}
run_zip{mzip128, is2}
}
def me{is:([ni]u8) if hasarch{'X86_64'} and E==u64 and ni==16} = {
def is2 = sel{[16]u8, is, make{[16]u8, 0,1,8,9,2,3,10,11,4,5,12,13,6,7,14,15}}
run_zip{mzip128, is2}
}
def me{is:([ni]u8) if hasarch{'X86_64'} and E==u16 and ni==32} = {
def is2 = shuf{[4]u64, is, 4b3120} # rearranges the four 64-bit quarters of the index vector
run_zip{mzip128, is2}
}
}}}
{x} => x
}
# lut_gen order is very important!
# NOTE(review): presumably a later matching definition takes precedence over an earlier one, which would explain
# why tighter nt bounds follow looser ones below - confirm against Singeli's dispatch rules
# AVX2, u8 elements, at most 32 indices: tables of up to 128, 64 or 32 elements go through blend_halves at that fixed size
def lut_gen{mode, E==u8, nt, ni if hasarch{'AVX2'} and nt<=128 and ni<=32} = blend_halves{mode, E, 128, 32} # TODO probably don't
def lut_gen{mode, E==u8, nt, ni if hasarch{'AVX2'} and nt<=64 and ni<=32} = blend_halves{mode, E, 64, 32} # TODO probably don't
def lut_gen{mode, E==u8, nt, ni if hasarch{'AVX2'} and nt<=32 and ni<=32} = blend_halves{mode, E, 32, 32}
# generate inds to utilize top bit of pshufb zeroing to replace vpblendvb with vpor
# mode 'c' only, AVX2, u8 elements, at most 64 table elements and 32 indices
# the table is taken as vn pieces of 16 bytes; each piece is looked up separately and the results are OR-ed together
def lut_gen{'c', E==u8, nt, ni if hasarch{'AVX2'} and nt<=64 and ni<=32} = { def vn=pow2_up{nt,16}/16; tup{vn*16, 32, loader{{TG} => {
def luts = each{{i} => [32]u8**TG{[16]u8, i}, range{vn}} # piece i of the table, broadcast to a [32]u8
{is:([32]u8)} => {
def bi = range{ceil_log2{vn}} # numbers of the index bits that choose the piece
def bits = each{{o} => is <<{u16} (3-o), bi} # extract bits 0,1,2,3 (as many as needed) from 2b3210xxxx into top bit (x being bits used by pshufb)
def top = [32]u8**128
# for piece i: OR 128 into every index whose piece-choosing bits do not all match i (assuming &~ is and-not)
def isp = each{{i} => is | (top &~ tree_fold{&, each{{m, o} => if (bit{o,i}!=0) m else ~m, bits, bi}}), range{vn}}
tup{tree_fold{|, each{sel{[16]u8,.,.}, luts, isp}}}
}
}}}}
# AVX2, u8 elements, at most 16 table elements and 32 indices: a single sel{[16]u8, …} per index vector
# reports a table size of 16 and an index count of 32
def lut_gen{mode, E==u8, nt, ni if hasarch{'AVX2'} and nt<=16 and ni<=32} = tup{16, 32, loader{{TG} => {
lut:[32]u8 = [32]u8**TG{[16]u8, 0} # the 16-byte table, broadcast to a [32]u8
{is:([32]u8)} => tup{sel{[16]u8, lut, is}}
}}}
# AVX2, u32 elements, at most 16 indices: tables of up to 32 or 16 elements go through blend_halves at that fixed size
def lut_gen{mode, E==u32, nt, ni if hasarch{'AVX2'} and nt<=32 and ni<=16} = blend_halves{mode, E, 32, 16} # TODO probably don't
def lut_gen{mode, E==u32, nt, ni if hasarch{'AVX2'} and nt<=16 and ni<=16} = blend_halves{mode, E, 16, 16}
# AVX2, u32 elements, at most 8 table elements and 16 indices: the table is one [8]u32 vector
# reports a table size of 8 and an index count of 16
def lut_gen{mode, E==u32, nt, ni if hasarch{'AVX2'} and nt<=8 and ni<=16} = tup{8, 16, loader{{TG} => {
def lut = TG{[8]u32, 0}
# 'idxs' form: for each widened index vector from widen_tup, a pair of that index vector and its lookup result
def me{'idxs', is:([16]u8)} = each{{wis} => tup{wis, sel{[8]u32, lut, wis}}, widen_tup{u32,is}} # TODO inline, or properly outline
# plain form: only the lookup results of the 'idxs' form
def me{is:([16]u8)} = each{{{_,v}}=>v, me{'idxs', is}}
}}}
# def lut_gen{mode, E==u8, nt, ni if hasarch{'AARCH64'} and nt<=128 and ni<=16} = blend_halves{mode, E, 128, 16}
# AArch64, u8 elements, at most 64 table elements and 16 indices: the table is held as vn vectors of 16 bytes
# (vn*16 is nt rounded up to a power of two, at least 16) and looked up with a single sel over all of them
def lut_gen{mode, E==u8, nt, ni if hasarch{'AARCH64'} and nt<=16*4 and ni<=16} = { def vn=pow2_up{nt,16}/16; tup{vn*16, 16, loader{{TG} => { # TODO could maybe accept nt==48
def lut = each{TG{[16]u8, .}, range{vn}}
{is:([16]u8)} => tup{sel{lut, is}}
}}}}
# AVX2, mode 'i', u16 or u64 elements: two half-width LUTs zipped together
def lut_gen{mode, E, nt, ni if mode=='i' and hasarch{'AVX2'} and (E==u16 or E==u64)} = zip_halves{mode, E, nt, ni}
# def lut_gen{mode, E, nt, ni if mode=='c' and hasarch{'AVX2'} and E==u16} = widen_inds{mode, E, nt, max{ni,16}, 2}
# def lut_gen{mode, E, nt, ni if mode=='c' and hasarch{'AVX2'} and E==u64} = widen_inds{mode, E, nt, ni, 2}
# def lut_gen{mode, E, nt, ni if mode=='c' and E==u64} = zip_halves{mode, E, nt, ni} # widen_inds{mode, E, nt, max{ni,16}, 2}
# AVX2, u64 elements, more than 16 table elements: no implementation, returns 0 like the catch-all lut_gen
def lut_gen{mode, E==u64, nt, ni if nt>16 and hasarch{'AVX2'}} = 0
# AArch64: u16 and u32 via two half-width LUTs, u64 via indices widened by a factor of 2
def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and (E==u16 or E==u32)} = zip_halves{mode, E, nt, ni}
def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and E==u64} = widen_inds{mode, E, nt, ni, 2}
# AArch64 exclusions, both returning 0: mode 'c' for E>=u16, and mode 'i' for u64 with more than 16 table elements
def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and mode=='c' and E>=u16} = 0
def lut_gen{mode, E, nt, ni if hasarch{'AARCH64'} and mode=='i' and E==u64 and nt>16} = 0