Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/test/unit/language/test_tensor_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
if BLOCK_M >= 64 * num_ctas and BLOCK_N >= 64 and is_hopper():
# TODO: The use of stmatrix for Blackwell is currently not supported.
# Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4.
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm[
"ptx"] or "stmatrix.sync.aligned.x4.m8n8.shared.b16" in kernel.asm["ptx"]


@triton.jit
Expand Down Expand Up @@ -1668,4 +1669,5 @@ def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, B
if BLOCK_M >= 64 * num_ctas and BLOCK_N >= 64 and is_cuda() and is_hopper():
# TODO: The use of stmatrix for Blackwell is currently not supported.
# Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4.
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm[
"ptx"] or "stmatrix.sync.aligned.x4.m8n8.shared.b16" in kernel.asm["ptx"]
Loading