
Commit 7ec17cd (1 parent: da3d437)

[AMD][GLUON] Fix none scale value for wmma/mfma (#8427)

Unwrap None from constexpr for wmma/mfma scaled. Add a corresponding frontend test.
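Why the unwrap is needed: Gluon's JIT frontend wraps literal kernel arguments, including None, in constexpr, so a plain "a_scale is not None" check sees the wrapper object and passes even though no scale was supplied. A minimal sketch of the failure mode follows; the constexpr class below is a stand-in for Triton's real one, shown only to illustrate the unwrap idiom:

class constexpr:
    """Stand-in for Triton's constexpr wrapper (illustration only)."""

    def __init__(self, value):
        self.value = value


def _unwrap_if_constexpr(x):
    # Mirrors the helper this commit imports from ..._core: peel off the
    # constexpr wrapper if present, otherwise pass the value through.
    return x.value if isinstance(x, constexpr) else x


a_scale = constexpr(None)                     # what the frontend hands to mfma_scaled
assert a_scale is not None                    # passes: the old check was fooled
assert _unwrap_if_constexpr(a_scale) is None  # unwrapped, the fixed check catches it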

File tree

3 files changed: +55 -4 lines


python/test/gluon/test_frontend.py

Lines changed: 49 additions & 0 deletions
@@ -2445,6 +2445,29 @@ def kernel():
     """)
 
 
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
+def test_amd_mfma_scaled_none(target):
+
+    @gluon.jit
+    def kernel():
+        mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(4, [16, 16, 128], True, [1, 1])
+        scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout([],
+                                                                    [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]],
+                                                                    [], [], [16, 4])
+
+        a = ttgl.full([16, 64], 0x11, ttgl.uint8, ttgl.DotOperandLayout(0, mfma_layout, 16))
+        b = ttgl.full([64, 16], 0x22, ttgl.uint8, ttgl.DotOperandLayout(1, mfma_layout, 16))
+
+        b_scale = ttgl.full([16, 4], 0x01, ttgl.uint8, scale_layout)
+        acc = ttgl.full([16, 16], 0, ttgl.float32, mfma_layout)
+        ttgl.amd.cdna4.mfma_scaled(a, None, 'e2m1', b, b_scale, 'e2m1', acc)
+
+    with pytest.raises(CompilationError) as e:
+        run_parser(kernel, target=target)
+
+    assert "Scales must not be None" in str(e.value)
+
+
 @pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
 def test_amd_wmma_scaled(target):
 
@@ -2497,6 +2520,32 @@ def kernel():
     """)
 
 
+@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
+def test_amd_wmma_scaled_none(target):
+
+    @gluon.jit
+    def kernel():
+        wmma_layout: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [1, 1], [16, 16, 128])
+        wmma_layout_packed: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [1, 1], [16, 16, 64])
+        scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout([[0, 1], [0, 2]],
+                                                                    [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], [], [],
+                                                                    [16, 4])
+        a_layout: ttgl.constexpr = ttgl.DotOperandLayout(0, wmma_layout_packed, 16)
+        b_layout: ttgl.constexpr = ttgl.DotOperandLayout(1, wmma_layout_packed, 16)
+
+        a = ttgl.full([16, 64], 0x11, ttgl.uint8, a_layout)
+        b = ttgl.full([64, 16], 0x22, ttgl.uint8, b_layout)
+        b_scale = ttgl.full([16, 4], 0x01, ttgl.uint8, scale_layout)
+        acc = ttgl.full([16, 16], 0, ttgl.float32, wmma_layout)
+
+        ttgl.amd.gfx1250.wmma_scaled(a, None, 'e2m1', b, b_scale, 'e2m1', acc)
+
+    with pytest.raises(CompilationError) as e:
+        run_parser(kernel, target=target)
+
+    assert "Scales must not be None" in str(e.value)
+
+
 @gluon.jit
 def padded_shared_layout_kernel():
     shape: ttgl.constexpr = [64, 64]
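For contrast with the two negative tests above, here is a hedged sketch of a passing call, not part of this commit: both scale operands are concrete tensors, so the new unwrap-then-assert check is satisfied. It reuses the layouts from test_amd_wmma_scaled_none; the kernel name is illustrative, the a_scale layout is assumed to match b_scale's, and whether wmma_scaled returns the accumulator (as other Gluon dot builtins do) is an assumption, since the builtin's full body is not shown in this diff.

# Hypothetical positive counterpart to the tests above: both scale operands
# are real tensors, so the "Scales must not be None" assert added by this
# commit does not fire.
@gluon.jit
def kernel_with_scales():
    wmma_layout: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [1, 1], [16, 16, 128])
    wmma_layout_packed: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [1, 1], [16, 16, 64])
    scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout([[0, 1], [0, 2]],
                                                                [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], [], [],
                                                                [16, 4])

    a = ttgl.full([16, 64], 0x11, ttgl.uint8, ttgl.DotOperandLayout(0, wmma_layout_packed, 16))
    b = ttgl.full([64, 16], 0x22, ttgl.uint8, ttgl.DotOperandLayout(1, wmma_layout_packed, 16))
    a_scale = ttgl.full([16, 4], 0x01, ttgl.uint8, scale_layout)  # a real scale instead of None
    b_scale = ttgl.full([16, 4], 0x01, ttgl.uint8, scale_layout)
    acc = ttgl.full([16, 16], 0, ttgl.float32, wmma_layout)
    # Assumed to return the accumulated tensor; see the builtin's definition.
    acc = ttgl.amd.gfx1250.wmma_scaled(a, a_scale, 'e2m1', b, b_scale, 'e2m1', acc)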

python/triton/experimental/gluon/language/amd/cdna4/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,5 @@
 from triton.experimental.gluon.language import _core as ttgl
-from ..._core import builtin, float32
+from ..._core import builtin, float32, _unwrap_if_constexpr
 from ..._layouts import DotOperandLayout
 from .._layouts import AMDMFMALayout
 from ..cdna3 import _buffer_atomic_rmw_impl
@@ -43,6 +43,8 @@ def mfma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None)
     assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}"
     assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}"
 
+    a_scale = _unwrap_if_constexpr(a_scale)
+    b_scale = _unwrap_if_constexpr(b_scale)
     assert a_scale is not None and b_scale is not None, "Scales must not be None"
 
     tensor = _semantic.dot_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, False, True, True, float32)

python/triton/experimental/gluon/language/amd/gfx1250/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -1,8 +1,6 @@
-from ..._core import builtin
+from ..._core import builtin, _unwrap_if_constexpr
 from .._ops import _wmma, _verify_wmma
 from triton.experimental.gluon.language import _core as ttgl
-from triton.experimental.gluon.language._semantic import _check
-from ..._layouts import DotOperandLayout
 from .._layouts import AMDWMMALayout
 from . import tdm
 
@@ -61,6 +59,8 @@ def wmma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None)
     assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}"
     assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}"
 
+    a_scale = _unwrap_if_constexpr(a_scale)
+    b_scale = _unwrap_if_constexpr(b_scale)
     assert a_scale is not None and b_scale is not None, "Scales must not be None"
 
     handle = _semantic.dot_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, fast_math=False, lhs_k_pack=True,
