Commit 7c59c1d
[AMD][GLUON] Expose get wmma/mfma scale layout (#8496)
This PR exposes the internal layout utilities `chooseScaledMfmaScaleLayout` and `chooseScaledWmmaScaleLayout` to Gluon, to help generate the linear layouts for the scales used in `mfma_scaled`/`wmma_scaled`. It also allows Gluon kernels to pass a scalar scale value or leave the scale as None.
1 parent 4d85824 commit 7c59c1d
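
As a quick illustration of what this enables (a minimal sketch drawn from the updated test below, not exhaustive API documentation), a Gluon kernel can now derive the scale layout directly from its dot-operand layout instead of hand-writing a DistributedLinearLayout:

# Sketch only; mirrors the usage in python/test/gluon/test_core.py below,
# inside a @gluon.jit kernel with M and K as constexpr dimensions.
mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=4, instr_shape=[16, 16, 128],
                                                     transposed=True, warps_per_cta=[2, 2])
a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=16)
# New helper: returns the linear layout for the A-operand scales of shape [M, K // 32].
a_scale_layout: ttgl.constexpr = ttgl.amd.cdna4.get_mfma_scale_layout(a_layout, [M, K // 32])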

6 files changed: +352 −189 lines changed


python/src/gluon_ir.cc

Lines changed: 19 additions & 0 deletions
@@ -372,6 +372,16 @@ void init_gluon_ir(py::module &&m) {
                  ctx, version, warpsPerCta, instrShape, transposed, ctaLayout,
                  tilesPerWarp, elementBitWidth);
            })
+      .def("get_amd_mfma_scale_layout",
+           [](GluonOpBuilder &self, unsigned opIdx, std::vector<int64_t> &shape,
+              unsigned mfmaMDim, std::vector<unsigned> &tilesPerWarp,
+              std::vector<unsigned> &warpsPerCTA) -> py::object {
+             auto ctx = self.getContext();
+             auto ll = ttg::chooseScaledMfmaScaleLayout(
+                 ctx, opIdx, shape, mfmaMDim, tilesPerWarp, warpsPerCTA);
+             auto attr = ttg::LinearEncodingAttr::get(ctx, ll);
+             return layoutToGluon(attr);
+           })
       .def("get_amd_wmma_layout",
            [](GluonOpBuilder &self, unsigned version, bool transposed,
               std::vector<unsigned> &warpsPerCta,
@@ -385,6 +395,15 @@ void init_gluon_ir(py::module &&m) {
             return ttg::AMDWmmaEncodingAttr::get(
                 ctx, version, transposed, warpsPerCta, ctaLayout, instrShape);
           })
+      .def("get_amd_wmma_scale_layout",
+           [](GluonOpBuilder &self, unsigned opIdx, std::vector<int64_t> &shape,
+              std::vector<unsigned> &warpsPerCTA) -> py::object {
+             auto ctx = self.getContext();
+             auto ll = ttg::chooseScaledWmmaScaleLayout(ctx, opIdx, warpsPerCTA,
+                                                        shape);
+             auto attr = ttg::LinearEncodingAttr::get(ctx, ll);
+             return layoutToGluon(attr);
+           })
       .def("get_padded_shared_layout",
            [](GluonOpBuilder &self, std::vector<unsigned> &intervals,
               std::vector<unsigned> &paddings,
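
The wmma binding mirrors the mfma one, forwarding warpsPerCTA and shape to chooseScaledWmmaScaleLayout. The Gluon-side wrapper for WMMA is not visible in this excerpt of the commit; by analogy with the cdna4 helper used in the test below, usage would presumably look like this sketch (helper name and module path assumed, not confirmed by this diff):

# Hypothetical, by analogy with ttgl.amd.cdna4.get_mfma_scale_layout; the
# actual WMMA wrapper location is not shown in this excerpt.
b_scale_layout: ttgl.constexpr = get_wmma_scale_layout(b_layout, [N, K // 32])  # assumed name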

python/test/gluon/test_core.py

Lines changed: 90 additions & 126 deletions
@@ -17,6 +17,7 @@
     is_hopper_or_newer,
     is_hopper,
 )
+from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia.ampere import async_copy, mma_v2
@@ -629,145 +630,108 @@ def kernel(a_ptr, b_ptr, c_ptr, #
 
 
 @pytest.mark.skipif(not is_hip_cdna4(), reason="Requires CDNA4")
-@pytest.mark.parametrize("M, N, K, rhs_scale, mxfp_type, normal_type", [(32, 32, 128, rhs_scale, mxfp_type, normal_type)
-                                                                        for rhs_scale in [True, False]
-                                                                        for mxfp_type in ["e2m1"]
-                                                                        for normal_type in ["e4m3", "e5m2"]])
-def test_amd_mfma_scaled(M, N, K, rhs_scale, mxfp_type, normal_type):
-    device = 'cuda'
-
-    @triton.jit
-    def triton_kernel(a_base, stride_am, stride_ak, a_scale,  #
-                      b_base, stride_bk, stride_bn, b_scale,  #
-                      out,  #
-                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
-                      type_a: tl.constexpr, type_b: tl.constexpr):
-        DIV_FACTOR_A: tl.constexpr = 2 if type_a == "e2m1" else 1
-        DIV_FACTOR_B: tl.constexpr = 2 if type_b == "e2m1" else 1
-        PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A
-        PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B
-        a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_am + \
-            tl.arange(0, PACKED_BLOCK_K_A)[None, :] * stride_ak
-        b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_bk + \
-            tl.arange(0, BLOCK_N)[None, :] * stride_bn
-
-        a = tl.load(a_ptr)
-        b = tl.load(b_ptr)
-        SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32
-        if a_scale is not None:
-            scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0,
-                                                                                               SCALE_BLOCK_K)[None, :]
-            a_scale = tl.load(scale_a_ptr)
-        if b_scale is not None:
-            scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0,
-                                                                                               SCALE_BLOCK_K)[None, :]
-            b_scale = tl.load(scale_b_ptr)
-        c = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b)
-        out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
-        tl.store(out_ptr, c.to(tl.bfloat16))
+@pytest.mark.parametrize("M, N, K", [(32, 32, 128)])
+@pytest.mark.parametrize("a_type, b_type", [(a_type, b_type)
+                                            for a_type in ["e2m1", "e4m3", "e5m2"]
+                                            for b_type in ["e2m1", "e4m3", "e5m2"]])
+@pytest.mark.parametrize("has_scale", [True, False])
+def test_amd_mfma_scaled(M, N, K, a_type, b_type, has_scale, device='cuda'):
 
     @gluon.jit
-    def gluon_kernel(a_base, stride_am, stride_ak, a_scale,  #
-                     b_base, stride_bk, stride_bn, b_scale,  #
-                     out,  #
-                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
-                     type_a: tl.constexpr, type_b: tl.constexpr):
-        DIV_FACTOR_A: tl.constexpr = 2 if type_a == "e2m1" else 1
-        DIV_FACTOR_B: tl.constexpr = 2 if type_b == "e2m1" else 1
-        PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A
-        PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B
-        SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32
+    def kernel(out_ptr, a_ptr, b_ptr, a_scale_ptr, b_scale_ptr,  #
+               M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr,  #
+               a_type: tl.constexpr, b_type: tl.constexpr):
+        DIV_FACTOR_A: tl.constexpr = 2 if a_type == "e2m1" else 1
+        DIV_FACTOR_B: tl.constexpr = 2 if b_type == "e2m1" else 1
+        K_A: tl.constexpr = K // DIV_FACTOR_A
+        K_B: tl.constexpr = K // DIV_FACTOR_B
+
+        mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=4, instr_shape=[16, 16, 128], transposed=True,
+                                                             warps_per_cta=[2, 2])
 
         a_unpacked_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [8, 8], [4, 1], [1, 0])
         a_packed_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [8, 8], [4, 1], [1, 0])
-        a_layout: ttgl.constexpr = a_packed_layout if type_a == "e2m1" else a_unpacked_layout
-
-        a_scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout(
-            reg_bases=[], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp_bases=[[0, 0], [16, 0]],
-            block_bases=[], shape=[32, 4])
+        a_load_layout: ttgl.constexpr = a_packed_layout if a_type == "e2m1" else a_unpacked_layout
+        a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=16)
+        a_scale_layout: ttgl.constexpr = ttgl.amd.cdna4.get_mfma_scale_layout(a_layout, [M, K // 32])
 
         b_unpacked_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [32, 2], [4, 1], [1, 0])
         b_packed_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0])
-        b_layout: ttgl.constexpr = b_packed_layout if type_b == "e2m1" else b_unpacked_layout
-
-        b_scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout(
-            reg_bases=[], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp_bases=[[16, 0], [0, 0]],
-            block_bases=[], shape=[32, 4])
-
-        mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=4, instr_shape=[16, 16, 128], transposed=True,
-                                                             warps_per_cta=[2, 2])
-
-        zero = ttgl.zeros([BLOCK_M, BLOCK_N], dtype=ttgl.float32, layout=mfma_layout)
-
-        a_offsets = ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, a_layout))[:, None] * stride_am + \
-            ttgl.arange(0, PACKED_BLOCK_K_A, layout=ttgl.SliceLayout(0, a_layout))[None, :] * stride_ak
-        a = ttgl.amd.cdna4.buffer_load(a_base, a_offsets)
-        a = ttgl.convert_layout(a, ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=16))
-
-        b_offsets = ttgl.arange(0, PACKED_BLOCK_K_B, layout=ttgl.SliceLayout(1, b_layout))[:, None] * stride_bk + \
-            ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, b_layout))[None, :] * stride_bn
-        b = ttgl.amd.cdna4.buffer_load(b_base, b_offsets)
-        b = ttgl.convert_layout(b, ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=16))
-
-        if a_scale is not None:
-            a_scale_offsets = ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, a_scale_layout))[:, None] * SCALE_BLOCK_K + \
-                ttgl.arange(0, SCALE_BLOCK_K, layout=ttgl.SliceLayout(0, a_scale_layout))[None, :]
-            a_scale = ttgl.amd.cdna4.buffer_load(a_scale, a_scale_offsets)
+        b_load_layout: ttgl.constexpr = b_packed_layout if b_type == "e2m1" else b_unpacked_layout
+        b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=16)
+        b_scale_layout: ttgl.constexpr = ttgl.amd.cdna4.get_mfma_scale_layout(b_layout, [N, K // 32])
+
+        a_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, a_load_layout))[:, None]
+        a_offs_k = ttgl.arange(0, K_A, layout=ttgl.SliceLayout(0, a_load_layout))[None, :]
+        a = ttgl.amd.cdna4.buffer_load(a_ptr, a_offs_m * K_A + a_offs_k)
+        a = ttgl.convert_layout(a, a_layout)
+
+        b_offs_k = ttgl.arange(0, K_B, layout=ttgl.SliceLayout(1, b_load_layout))[:, None]
+        b_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, b_load_layout))[None, :]
+        b = ttgl.amd.cdna4.buffer_load(b_ptr, b_offs_k * N + b_offs_n)
+        b = ttgl.convert_layout(b, b_layout)
+
+        a_scale = None
+        if a_scale_ptr is not None:
+            a_scale_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, a_scale_layout))[:, None]
+            a_scale_offs_k = ttgl.arange(0, K // 32, layout=ttgl.SliceLayout(0, a_scale_layout))[None, :]
+            a_scale = ttgl.amd.cdna4.buffer_load(a_scale_ptr, a_scale_offs_m * (K // 32) + a_scale_offs_k)
+
+        b_scale = None
+        if b_scale_ptr is not None:
+            b_scale_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(1, b_scale_layout))[:, None]
+            b_scale_offs_k = ttgl.arange(0, K // 32, layout=ttgl.SliceLayout(0, b_scale_layout))[None, :]
+            b_scale = ttgl.amd.cdna4.buffer_load(b_scale_ptr, b_scale_offs_n * (K // 32) + b_scale_offs_k)
+
+        zero = ttgl.zeros([M, N], dtype=ttgl.float32, layout=mfma_layout)
+        c = ttgl.amd.cdna4.mfma_scaled(a, a_scale, a_type, b, b_scale, b_type, zero)
+        c = c.to(out_ptr.dtype.element_ty)
+
+        out_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, mfma_layout))[:, None]
+        out_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, mfma_layout))[None, :]
+        ttgl.amd.cdna4.buffer_store(c, out_ptr, out_offs_m * N + out_offs_n)
+
+    def _create_mxfp_operand(operand: int, m: int, n: int, dtype: str):
+        size = (m, n)
+        if dtype == 'e4m3':
+            v = torch.randint(20, 40, size, dtype=torch.uint8)
+            v_ref = v.view(torch.float8_e4m3fn).to(torch.float32)
+        elif dtype == 'e5m2':
+            v = torch.randint(20, 40, size, dtype=torch.uint8)
+            v_ref = v.view(torch.float8_e5m2).to(torch.float32)
         else:
-            a_scale = ttgl.full([BLOCK_M, SCALE_BLOCK_K], 127, dtype=ttgl.int8, layout=a_scale_layout)
-
-        if b_scale is not None:
-            b_scale_offsets = ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(1, b_scale_layout))[:, None] * SCALE_BLOCK_K + \
-                ttgl.arange(0, SCALE_BLOCK_K, layout=ttgl.SliceLayout(0, b_scale_layout))[None, :]
-            b_scale = ttgl.amd.cdna4.buffer_load(b_scale, b_scale_offsets)
-        else:
-            b_scale = ttgl.full([BLOCK_M, SCALE_BLOCK_K], 127, dtype=ttgl.int8, layout=b_scale_layout)
-
-        c = ttgl.amd.cdna4.mfma_scaled(a, a_scale, type_a, b, b_scale, type_b, zero)
-        c = c.to(out.dtype.element_ty)
-
-        out_offsets = ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, mfma_layout))[:, None] * BLOCK_N + \
-            ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, mfma_layout))[None, :]
-        ttgl.amd.cdna4.buffer_store(c, out, out_offsets)
+            assert dtype == 'e2m1'
+            pack_dim = 1 if operand == 0 else 0
+            v_mxfp4 = MXFP4Tensor(size=size).random()
+            v = v_mxfp4.to_packed_tensor(pack_dim)
+            v_ref = v_mxfp4.to(torch.float32)
+        return v.to(device), v_ref.to(device)
+
+    def _create_mxfp_scale(operand: int, m: int, n: int):
+        size = (m, n // 32)
+        scale = MXScaleTensor(size=tuple(size)).random(1 / 32, 32)
+        scale_ref = scale.to(torch.float32).repeat_interleave(32, dim=1)
+        scale_ref = scale_ref.T.contiguous() if operand == 1 else scale_ref
+        return scale.data.to(device), scale_ref.to(device)
 
     torch.manual_seed(0)
-
-    type_a = normal_type if rhs_scale else mxfp_type
-    type_b = mxfp_type if rhs_scale else normal_type
-
-    DIV_FACTOR_A = 2 if type_a == "e2m1" else 1
-    DIV_FACTOR_B = 2 if type_b == "e2m1" else 1
-    x = torch.randint(20, 40, (M, K // DIV_FACTOR_A), dtype=torch.uint8, device=device)
-    y = torch.randint(20, 40, (K // DIV_FACTOR_B, N), dtype=torch.uint8, device=device)
-
-    min_scale, max_scale = (0, 142)
-    scale_x = torch.randint(min_scale, max_scale + 1, (M, K // 32), dtype=torch.uint8, device=device)
-    scale_y = torch.randint(min_scale, max_scale + 1, (N, K // 32), dtype=torch.uint8, device=device)
-    if rhs_scale:
-        scale_x = None
+    a, a_ref = _create_mxfp_operand(0, M, K, a_type)
+    b, b_ref = _create_mxfp_operand(1, K, N, b_type)
+
+    if has_scale:
+        a_scale, a_scale_ref = _create_mxfp_scale(0, M, K)
+        b_scale, b_scale_ref = _create_mxfp_scale(1, N, K)
+        out = torch.empty((M, N), dtype=torch.float32, device=device)
+        compiled = kernel[(1, )](out, a, b, a_scale, b_scale, M, N, K, a_type, b_type, num_warps=4)
+        out_ref = torch.matmul(a_ref * a_scale_ref, b_ref * b_scale_ref)
+        torch.testing.assert_close(out, out_ref)
     else:
-        scale_y = None
-
-    def make_finite(x, dtype):
-        if dtype not in ("e5m2", "e4m3"):
-            return x
-        mask = 0x7C if dtype == "e5m2" else 0x7F
-        finite = torch.arange(x.numel(), device=device, dtype=torch.uint8).reshape_as(x) % mask
-        x_finite = torch.where(x & mask == mask, finite | (0x80 & x), x)
-        x.copy_(x_finite)
-        return x
-
-    x = make_finite(x, type_a)
-    y = make_finite(y, type_b)
-
-    z = torch.zeros((M, N), dtype=torch.bfloat16, device=device)
-    pgm = gluon_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b)
-    assert "v_mfma_scale_f32_16x16x128_f8f6f4" in pgm.asm["amdgcn"]
-
-    z_ref = torch.zeros((M, N), dtype=torch.bfloat16, device=device)
-    triton_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z_ref, M, N, K, type_a, type_b)
+        out = torch.empty((M, N), dtype=torch.float32, device=device)
+        compiled = kernel[(1, )](out, a, b, None, None, M, N, K, a_type, b_type, num_warps=4)
+        out_ref = torch.matmul(a_ref, b_ref)
+        torch.testing.assert_close(out, out_ref)
 
-    torch.testing.assert_close(z, z_ref, rtol=1e-5, atol=1e-5)
+    assert 'v_mfma_scale_f32_16x16x128_f8f6f4' in compiled.asm['amdgcn']
 
 
 def test_math_fast_expf():
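
For readers new to MX-style scaling, the reference math the test checks is straightforward: each scale entry applies to a group of 32 contiguous K elements of its operand, and the B scales are stored N-major. A standalone sketch of that semantics (illustrative code only, not part of this commit):

import torch

# Each scale covers 32 contiguous K elements; B scales are stored as (N, K // 32).
M, N, K = 32, 32, 128
a, b = torch.randn(M, K), torch.randn(K, N)
a_scale = torch.rand(M, K // 32) + 0.5
b_scale = torch.rand(N, K // 32) + 0.5
a_ref = a * a_scale.repeat_interleave(32, dim=1)
b_ref = b * b_scale.repeat_interleave(32, dim=1).T
out_ref = torch.matmul(a_ref, b_ref)  # mirrors the test's reference computation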
