I'm trying to perform an
I was thinking of getting around this by combining the bit representations of the two vectors outside the associative scan and splitting the result back into its component parts inside the scan. In NumPy it would look something like the following:

```python
import numpy as np

def merge(a, b):
    # Pack two float32 arrays of the same shape into one float64 array.
    stacked = np.stack((a, b), axis=-1)
    return np.frombuffer(stacked.tobytes(), dtype=np.float64).reshape(a.shape)

def unmerge(merged):
    # Split each float64 back into the two float32 values it carries.
    assert merged.dtype == np.float64
    stacked = np.frombuffer(merged.tobytes(), dtype=np.float32).reshape(
        *merged.shape, 2
    )
    return stacked[..., 0], stacked[..., 1]
```

Is there some equivalent way of doing this in Triton? Can I access a pointer's address for use in a different pointer with a different dtype? Or am I better off waiting for multi-arg associative scan support?
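For comparison, the same reinterpretation can probably be written in NumPy without the `tobytes`/`frombuffer` copies by using `ndarray.view`. This is only a sketch: `merge_view` and `unmerge_view` are hypothetical names, and it assumes both inputs are `float32` arrays of the same shape. It is a CPU-side illustration of the bit layout and says nothing about what Triton itself supports.

```python
import numpy as np

def merge_view(a, b):
    # Interleave the two float32 arrays along a new last axis, then
    # reinterpret every (a, b) pair of 32-bit values as one 64-bit value.
    stacked = np.stack((a, b), axis=-1)        # shape (..., 2), float32
    return stacked.view(np.float64)[..., 0]    # shape (...), float64

def unmerge_view(merged):
    # Reinterpret every 64-bit value as two 32-bit floats again.
    assert merged.dtype == np.float64
    pairs = np.ascontiguousarray(merged).reshape(*merged.shape, 1)
    pairs = pairs.view(np.float32)             # shape (..., 2), float32
    return pairs[..., 0], pairs[..., 1]

# Round trip on a small 2-D example (values chosen arbitrarily).
a = np.array([[1.5, -2.0, 0.0], [4.0, 5.5, -6.25]], dtype=np.float32)
b = np.array([[-0.25, 3.0, -7.5], [8.0, -9.0, 10.0]], dtype=np.float32)
merged = merge_view(a, b)
a_out, b_out = unmerge_view(merged)
```

The merged values are not meaningful as floating-point numbers; they are only a 64-bit container for the two 32-bit payloads.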
Replies: 1 comment
Ended up solving it:

```python
import triton
import triton.language as tl

@triton.jit
def bitcast_merge_triton(a, b):
    tl.static_assert(a.dtype == tl.float32)
    tl.static_assert(b.dtype == tl.float32)
    # Reinterpret the float32 bits as int32, then widen to int64.
    a = a.to(dtype=tl.int32, bitcast=True).to(tl.int64)
    a = a << 32  # move a into the upper 32 bits
    b = b.to(dtype=tl.int32, bitcast=True).to(tl.int64)
    b = b & 0xFFFFFFFF  # widening sign-extends, so clear the upper 32 bits
    return a | b

@triton.jit
def bitcast_unmerge_triton(merged):
    tl.static_assert(merged.dtype == tl.int64)
    # The lower 32 bits hold b, the upper 32 bits hold a.
    b = (merged & 0xFFFFFFFF).to(tl.int32).to(tl.float32, bitcast=True)
    a = (merged >> 32).to(tl.int32).to(tl.float32, bitcast=True)
    return a, b
```
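The shift-and-mask logic can be checked on the CPU with a NumPy emulation before running it in a kernel. The sketch below is an assumption-laden stand-in, not Triton code: `bitcast_merge_np` and `bitcast_unmerge_np` are hypothetical names, `.view(np.int32)` plays the role of a bitcast, and `.astype(np.int64)` plays the role of the widening cast. It also shows why the low half should be masked: widening a negative `int32` to `int64` sign-extends, which would otherwise overwrite the upper 32 bits.

```python
import numpy as np

def bitcast_merge_np(a, b):
    # Reinterpret float32 bits as int32, widen to int64, place a in the top half.
    hi = a.view(np.int32).astype(np.int64) << 32
    # Widening sign-extends, so keep only the low 32 bits of b.
    lo = b.view(np.int32).astype(np.int64) & 0xFFFFFFFF
    return hi | lo

def bitcast_unmerge_np(merged):
    assert merged.dtype == np.int64
    lo = (merged & 0xFFFFFFFF).astype(np.uint32).view(np.float32)
    hi = (merged >> 32).astype(np.int32).view(np.float32)
    return hi, lo

# Include negative values in b so the sign-extension case is exercised.
a = np.array([1.5, -2.0, 0.0], dtype=np.float32)
b = np.array([-0.25, 3.0, -7.5], dtype=np.float32)

merged = bitcast_merge_np(a, b)
a_out, b_out = bitcast_unmerge_np(merged)

# Same packing without the mask: a negative b clobbers the bits of a.
naive = (a.view(np.int32).astype(np.int64) << 32) | b.view(np.int32).astype(np.int64)
a_naive, b_naive = bitcast_unmerge_np(naive)
```

With the mask, `a_out` and `b_out` match the inputs bit for bit. Without it, the entries of `a_naive` that sit next to a negative `b` no longer match `a`, while `b_naive` is still recovered correctly.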