Skip to content
Discussion options

You must be logged in to vote

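Ended up solving it with the following pair of Triton helpers: the first packs the raw bits of two float32 values into one int64, and the second unpacks them again: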

@triton.jit
def bitcast_merge_triton(a, b):
    """Pack the raw bits of two float32 values into a single int64.

    The IEEE-754 bit pattern of ``a`` is placed in the high 32 bits of the
    result and the bit pattern of ``b`` in the low 32 bits. The values are
    reinterpreted (bitcast), never numerically converted, so the packing is
    lossless. ``bitcast_unmerge_triton`` performs the inverse split.

    Args:
        a: float32 value whose bits go into the high word.
        b: float32 value whose bits go into the low word.

    Returns:
        int64 equal to ``(bits(a) << 32) | bits(b)`` with ``bits(b)`` treated
        as an unsigned 32-bit quantity.
    """
    tl.static_assert(a.dtype == tl.float32)
    tl.static_assert(b.dtype == tl.float32)
    # Reinterpret the float bits as int32, then widen to int64. The widening
    # sign-extends, which is harmless here because the shift pushes the
    # extension bits out of the 64-bit word.
    hi = a.to(dtype=tl.int32, bitcast=True).to(tl.int64) << 32
    # For the low word the same sign extension would set bits 32..63 to ones
    # whenever b's sign bit is set (negative values, -0.0). OR-ing that in
    # would clobber `hi`, so keep only the low 32 bits.
    lo = b.to(dtype=tl.int32, bitcast=True).to(tl.int64) & 0xFFFFFFFF
    return hi | lo
    
    
@triton.jit
def bitcast_unmerge_triton(merged):
    """Split an int64 back into the two float32 values whose bits it holds.

    Intended as the inverse of ``bitcast_merge_triton``: the high 32 bits
    become the first float and the low 32 bits the second. Both halves are
    reinterpreted (bitcast) as float32, not numerically converted.

    Args:
        merged: int64 value carrying two packed 32-bit patterns.

    Returns:
        Tuple ``(a, b)`` of float32 values, ``a`` from the high word and
        ``b`` from the low word.
    """
    tl.static_assert(merged.dtype == tl.int64)
    # High word: move bits 32..63 down, narrow to 32 bits.
    high_bits = (merged >> 32).to(tl.int32)
    # Low word: isolate bits 0..31, narrow to 32 bits.
    low_bits = (merged & 0xFFFFFFFF).to(tl.int32)
    return high_bits.to(tl.float32, bitcast=True), low_bits.to(tl.float32, bitcast=True)

Replies: 1 comment

Comment options

You must be logged in to vote
0 replies
Answer selected by jackd
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
None yet
1 participant