Skip to content

Commit 1c2c074

Browse files
authored
[AMD][GLUON] Fix wmma_scaled instr shape verfication (#8425)
Compare instr shape against list instead of tuple. Also refresh the docs for more data types.
1 parent 2726dac commit 1c2c074

File tree

1 file changed

+8
-9
lines changed
  • python/triton/experimental/gluon/language/amd/gfx1250

1 file changed

+8
-9
lines changed

python/triton/experimental/gluon/language/amd/gfx1250/__init__.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,27 +38,26 @@ def wmma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None)
3838
Args:
3939
a (tensor): The operand A to be multiplied.
4040
a_scale (tensor): Scale factor for operand A.
41-
a_format (str): Format of the operand A. Available formats: `e2m1'.
41+
a_format (str): Format of the operand A. Available formats: `e2m1`, `e4m3`, `e5m2`.
4242
b (tensor): The operand B to be multiplied.
4343
b_scale (tensor): Scale factor for operand B.
44-
b_format (str): Format of the operand B. Available formats: `e2m1'.
44+
b_format (str): Format of the operand B. Available formats: `e2m1`, `e4m3`, `e5m2`.
4545
acc (tensor): Accumulator tensor.
4646
"""
4747
_verify_wmma(3, a, b, acc)
4848
if a_format.value == "e2m1":
4949
wmma_layout = a.type.layout.parent
50-
assert isinstance(wmma_layout, AMDWMMALayout) and wmma_layout.instr_shape == (16, 16, 64), \
51-
"e2m1 format expects instr_shape to be (16, 16, 64)"
50+
assert isinstance(wmma_layout, AMDWMMALayout) and wmma_layout.instr_shape == [16, 16, 64], \
51+
"e2m1 format expects instr_shape to be [16, 16, 64]"
5252
if b_format.value == "e2m1":
5353
wmma_layout = b.type.layout.parent
54-
assert isinstance(wmma_layout, AMDWMMALayout) and wmma_layout.instr_shape == (16, 16, 64), \
55-
"e2m1 format expects instr_shape to be (16, 16, 64)"
54+
assert isinstance(wmma_layout, AMDWMMALayout) and wmma_layout.instr_shape == [16, 16, 64], \
55+
"e2m1 format expects instr_shape to be [16, 16, 64]"
5656

5757
acc_layout = acc.type.layout
58-
assert isinstance(acc_layout, AMDWMMALayout) and acc_layout.instr_shape == (16, 16, 128), \
59-
"accumulator tensor's layout must be (16, 16, 128)"
58+
assert isinstance(acc_layout, AMDWMMALayout) and acc_layout.instr_shape == [16, 16, 128], \
59+
"accumulator tensor's layout must be [16, 16, 128]"
6060

61-
# TODO: Add more formats
6261
assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}"
6362
assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}"
6463

0 commit comments

Comments
 (0)