Skip to content

Commit 11af53c

Browse files
authored
[AMD][GLUON] Expose buffer ops to gfx1250 (#8532)
Expose `buffer_load` and `buffer_store`, inherited from CDNA3, to gfx1250.
1 parent a6e7434 commit 11af53c

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

python/test/gluon/test_frontend.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2208,11 +2208,10 @@ def buffer_load_store_kernel(x, y):
22082208
ttgl.amd.cdna4.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.cs')
22092209

22102210

2211-
@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
2212-
def test_buffer_load_store(target):
2211+
def test_buffer_load_store():
22132212
x = MockTensor(ttgl.float32)
22142213
y = MockTensor(ttgl.float32)
2215-
module = run_parser(buffer_load_store_kernel, *make_args(x, y), target=target)
2214+
module = run_parser(buffer_load_store_kernel, *make_args(x, y), target=HIP_TARGET_CDNA3)
22162215

22172216
expecttest.assert_expected_inline(
22182217
anonymize_ir(module.str_nodebug()), """\
@@ -2257,11 +2256,10 @@ def buffer_load_store_with_broadcast_kernel(x, y):
22572256
ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.cs')
22582257

22592258

2260-
@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
2261-
def test_buffer_load_store_with_broadcast(target):
2259+
def test_buffer_load_store_with_broadcast():
22622260
x = MockTensor(ttgl.float16)
22632261
y = MockTensor(ttgl.float16)
2264-
module = run_parser(buffer_load_store_with_broadcast_kernel, *make_args(x, y), target=target)
2262+
module = run_parser(buffer_load_store_with_broadcast_kernel, *make_args(x, y), target=HIP_TARGET_CDNA3)
22652263

22662264
expecttest.assert_expected_inline(
22672265
anonymize_ir(module.str_nodebug()), """\

python/triton/experimental/gluon/language/amd/gfx1250/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
11
from ..._core import builtin, _unwrap_if_constexpr
2-
from .._ops import _wmma, _verify_wmma
2+
from .._ops import _wmma, _verify_wmma, _mma_scaled
33
from .._layouts import AMDWMMALayout
4-
from .._ops import _mma_scaled
4+
from ..cdna3 import buffer_load, buffer_store
55
from . import tdm
66

7-
__all__ = ["tdm", "wmma", "wmma_scaled", "get_wmma_scale_layout"]
7+
__all__ = ["tdm", "wmma", "wmma_scaled", "buffer_load", "buffer_store", "get_wmma_scale_layout"]
88

99

1010
def _get_wmma_scale_layout(dot_operand_layout, shape, semantic):

0 commit comments

Comments
 (0)