Skip to content

Commit b36e91b

Browse files
peterbell10zwu-2025
authored andcommitted
[Frontend] Assert that TensorDescriptor fields have matching rank (triton-lang#6911)
Currently a rank mismatch gives an error in the driver when creating the tma descriptor, which makes it hard to track back to the original argument. This just makes the stack trace more useful. Also add some basic tensor descriptor stuff to the docs.
1 parent 67e72f8 commit b36e91b

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

docs/python-api/triton.language.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Programming Model
1212
:nosignatures:
1313

1414
tensor
15+
tensor_descriptor
1516
program_id
1617
num_programs
1718

@@ -71,6 +72,9 @@ Memory/Pointer Ops
7172

7273
load
7374
store
75+
make_tensor_descriptor
76+
load_tensor_descriptor
77+
store_tensor_descriptor
7478
make_block_ptr
7579
advance
7680

python/triton/tools/tensor_descriptor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ class TensorDescriptor:
99
strides: List[int]
1010
block_shape: List[int]
1111

12+
def __post_init__(self):
13+
rank = len(self.shape)
14+
assert len(self.strides) == rank, f"rank mismatch: {self}"
15+
assert len(self.block_shape) == rank, f"rank mismatch: {self}"
16+
1217
@staticmethod
1318
def from_tensor(tensor: Any, block_shape: List[int]):
1419
return TensorDescriptor(

0 commit comments

Comments
 (0)