Commit cf73993

[AMD] Support gfx950 scaled mfma instructions for mxfp4 (triton-lang#5845)
Support the new `rocdl.mfma.scale.f32.16x16x128.f8f6f4` and `rocdl.mfma.scale.f32.32x32x64.f8f6f4` instructions for gfx950. These instructions take scales for both the lhs and rhs operands. This PR adds initial support, starting with mxfp4.
1 parent 089b8bc commit cf73993
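For orientation, here is a minimal Triton-level sketch of the kind of block-scaled mxfp4 dot these instructions implement. It is illustrative only, not code from this commit: the `tl.dot_scaled` call, the "e2m1" format string, and the operand packing (two fp4 values per byte, one e8m0 scale per 32 values along K) are assumptions modeled on the `block_scale_fp4_matmul` test updated below.

```python
# Hedged sketch, not part of this commit: a single-tile kernel built around
# tl.dot_scaled, the Triton-level op that the new gfx950 scaled mfma
# instructions (16x16x128 and 32x32x64 f8f6f4 variants) are used to lower.
import triton
import triton.language as tl


@triton.jit
def mxfp4_tile_kernel(a_ptr, b_ptr, a_scale_ptr, b_scale_ptr, out_ptr,  #
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                      BLOCK_K: tl.constexpr, VEC_SIZE: tl.constexpr):
    # a_ptr/b_ptr point to uint8 data holding two packed fp4 values per byte;
    # a_scale_ptr/b_scale_ptr point to uint8 e8m0 scales, one per VEC_SIZE
    # (32 for mxfp4) values along K. Row-major layouts are assumed.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K // 2)
    a = tl.load(a_ptr + offs_m[:, None] * (BLOCK_K // 2) + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    offs_s = tl.arange(0, BLOCK_K // VEC_SIZE)
    a_scale = tl.load(a_scale_ptr + offs_m[:, None] * (BLOCK_K // VEC_SIZE) + offs_s[None, :])
    b_scale = tl.load(b_scale_ptr + offs_n[:, None] * (BLOCK_K // VEC_SIZE) + offs_s[None, :])
    # Both operands carry a scale, matching the lhs/rhs scales the new
    # instructions accept; "e2m1" selects the fp4 element format.
    acc = tl.dot_scaled(a, a_scale, "e2m1", b, b_scale, "e2m1")
    tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)
```

On gfx950 the AMD backend can lower such a dot to either scaled mfma variant; the test change below steers that choice through the matrix_instr_nonkdim launch keyword.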

File tree: 16 files changed, +627, -30 lines


include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 1 deletion

@@ -869,6 +869,7 @@ It is characterized by the following parameters:
   - 1.0: gfx908, i.e. MI100
   - 2.0: gfx90a: i.e. MI200, MI210, MI250
   - 3.0: gfx940, gfx941, gfx942: MI300
+  - 4.0: gfx950: MI350
 - `warpsPerCTA` indicates the warp layout in the block.
 - `MDim` and `NDim` indicate the dimension of the output of the mfma instruction.
 - `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout
@@ -938,7 +939,7 @@ The data will be distributed between threads as follows:

 Example 3:
 Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and nonKDim set to 4.
-The data will be distributed between threads as follows(note that each element is duploicated in 16 threads):
+The data will be distributed between threads as follows(note that each element is duplicated in 16 threads):
 Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and MDim=NDim=4.
 The data will be distributed between threads as follows(note that each element is duplicated in 16 threads):
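For quick reference, the versionMajor-to-architecture mapping documented above, including the new entry, written out as a small Python table; the dict name is illustrative and not part of the commit:

```python
# AMDMfmaEncodingAttr versionMajor -> GPU architectures, per the docs above.
MFMA_VERSION_TO_ARCH = {
    1: ("gfx908",),                      # MI100
    2: ("gfx90a",),                      # MI200, MI210, MI250
    3: ("gfx940", "gfx941", "gfx942"),   # MI300
    4: ("gfx950",),                      # MI350 (added by this commit)
}
```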

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 2 deletions

@@ -1458,8 +1458,8 @@ AMDMfmaEncodingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
                             llvm::ArrayRef<unsigned int> warpsPerCTA,
                             unsigned mDim, unsigned nDim, bool isTransposed,
                             mlir::triton::gpu::CTALayoutAttr) {
-  if (!(versionMajor >= 0 && versionMajor <= 3)) {
-    return emitError() << "major version must be in the [0, 3] range";
+  if (!(versionMajor >= 0 && versionMajor <= 4)) {
+    return emitError() << "major version must be in the [0, 4] range";
   }
   if (versionMinor != 0) {
     return emitError() << "minor version must be 0";

python/test/unit/language/test_matmul.py

Lines changed: 17 additions & 4 deletions

@@ -6,7 +6,7 @@
 import triton.tools.experimental_descriptor
 from test_mxfp import MXFP4Tensor, MXScaleTensor
 import re
-from triton._internal_testing import is_cuda, is_hip, is_hip_mi200
+from triton._internal_testing import is_cuda, is_hip, is_hip_mi200, is_hip_mi350, is_hip_cdna


 def f8_to_f16(x, dtype):
@@ -711,8 +711,18 @@ def block_scale_fp4_matmul( #
                          (128, 256, 256), (128, 128, 64), (128, 64, 128)])
 @pytest.mark.parametrize(("scale_type", "VEC_SIZE"), [("float8_e8m0fnu", 32), ("float8_e4m3fn", 16)],
                          ids=["mxfp4", "nvfp4"])
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
-def test_block_scale_fp4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, VEC_SIZE, scale_type, device):
+@pytest.mark.parametrize("nonKDim", ([0, 16, 32] if is_hip_cdna() else []))
+def test_block_scale_fp4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, VEC_SIZE, scale_type, nonKDim, device):
+    if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
+        pytest.skip("Requires compute capability >= 10")
+    elif is_hip():
+        if not is_hip_mi350():
+            pytest.skip("Scaled fp4 matmul is only natively supported on MI350")
+        if scale_type != 'float8_e8m0fnu':
+            pytest.skip("MI350 only supports E8M0 scale")
+        if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64):
+            pytest.skip(f"MI350 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants")
+
     NUM_STAGES = 1
     torch.manual_seed(42)
     a_mxfp4 = MXFP4Tensor(size=(M, K), device=device).random()
@@ -744,9 +754,12 @@ def test_block_scale_fp4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, VEC_SIZE, scale_typ

     output = a.new_empty((M, N), dtype=torch.float32)
     grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
+    kernel_kwargs = {}
+    if is_hip():
+        kernel_kwargs["matrix_instr_nonkdim"] = nonKDim
     block_scale_fp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, a_scale.stride(0), a.stride(0), a.stride(1),
                                  b.stride(0), b.stride(1), output.stride(0), output.stride(1), VEC_SIZE, BLOCK_M,
-                                 BLOCK_N, BLOCK_K, NUM_STAGES=NUM_STAGES)
+                                 BLOCK_N, BLOCK_K, NUM_STAGES=NUM_STAGES, **kernel_kwargs)

     torch.testing.assert_close(ref_out, output, atol=1e-2, rtol=1e-2)
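The BLOCK_K skip condition above mirrors the K width of the two scaled mfma variants named in the commit message (16x16x128 and 32x32x64). A hedged helper sketch, illustrative only and not part of the commit, that spells out the constraint:

```python
# Minimum BLOCK_K so one tile covers at least one gfx950 scaled mfma along K.
#   nonKDim 16 -> mfma.scale.f32.16x16x128.f8f6f4 -> K per instruction = 128
#   nonKDim 32 -> mfma.scale.f32.32x32x64.f8f6f4  -> K per instruction = 64
SCALED_MFMA_K = {16: 128, 32: 64}


def min_block_k(non_k_dim: int) -> int:
    # nonKDim == 0 leaves the instruction choice to the compiler, as in the
    # test's default parametrization.
    if non_k_dim == 0:
        return min(SCALED_MFMA_K.values())
    return SCALED_MFMA_K[non_k_dim]
```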

python/triton/_internal_testing.py

Lines changed: 8 additions & 1 deletion

@@ -60,8 +60,15 @@ def is_hip_mi300():
     return target.arch in ('gfx940', 'gfx941', 'gfx942')


+def is_hip_mi350():
+    target = get_current_target()
+    if target is None or target.backend != 'hip':
+        return False
+    return target.arch in ('gfx950')
+
+
 def is_hip_cdna():
-    return is_hip_mi200() or is_hip_mi300()
+    return is_hip_mi200() or is_hip_mi300() or is_hip_mi350()


 def is_xpu():
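A hedged usage sketch, not part of the commit, showing the new helper gating a test with a decorator-style mark, as an alternative to the imperative pytest.skip calls in test_matmul.py above:

```python
# Illustrative only: skip a test at collection time unless running on gfx950.
import pytest
from triton._internal_testing import is_hip_mi350

requires_mi350 = pytest.mark.skipif(not is_hip_mi350(),
                                    reason="requires gfx950 scaled mfma support")


@requires_mi350
def test_mxfp4_scaled_dot_smoke():
    ...  # hypothetical MI350-only test body goes here
```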

test/TritonGPU/invalid-attributes.mlir

Lines changed: 1 addition & 1 deletion

@@ -64,7 +64,7 @@

 // -----

-// expected-error@+1 {{major version must be in the [0, 3] range}}
+// expected-error@+1 {{major version must be in the [0, 4] range}}
 #mfma = #ttg.amd_mfma<{versionMajor = 10, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}>

 // -----

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion

@@ -99,7 +99,7 @@ def parse_options(self, opts) -> Any:

         if "supported_fp8_dtypes" not in opts:
             supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
-            if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
+            if self.target.arch in ('gfx940', 'gfx941', 'gfx942', 'gfx950'):
                 supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h

Lines changed: 2 additions & 1 deletion

@@ -20,7 +20,8 @@ enum class MfmaTypeId : uint32_t {
   Fp8Fp8TyId,
   Fp8Bf8TyId,
   Bf8Fp8TyId,
-  Bf8Bf8TyId
+  Bf8Bf8TyId,
+  F8F6F4TyId,
 };

 struct MfmaInsnGroupSelectKey {

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp

Lines changed: 26 additions & 1 deletion

@@ -15,14 +15,19 @@ LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
                           const LLVMTypeConverter *typeConverter,
                           ConversionPatternRewriter &rewriter);

+LogicalResult convertScaledMFMA(triton::DotScaledOp op,
+                                triton::DotScaledOp::Adaptor adaptor,
+                                const LLVMTypeConverter *typeConverter,
+                                ConversionPatternRewriter &rewriter);
+
 LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
                           const LLVMTypeConverter *typeConverter,
                           ConversionPatternRewriter &rewriter);
 } // namespace mlir::triton::AMD

 namespace {
 struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
-  using ConvertOpToLLVMPattern<triton::DotOp>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

   LogicalResult
   matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
@@ -47,6 +52,25 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
         "Unsupported DotOp found when converting TritonGPU to LLVM.");
   }
 };
+
+struct ScaledDotOpConversion
+    : public ConvertOpToLLVMPattern<triton::DotScaledOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+  int mfmaVersion;
+  int nonKDim;
+  int kPack;
+
+  ScaledDotOpConversion(LLVMTypeConverter &typeConverter, int mfmaVersion,
+                        int nonKDim, int kPack, PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern(typeConverter, benefit),
+        mfmaVersion(mfmaVersion), nonKDim(nonKDim), kPack(kPack) {}
+
+  LogicalResult
+  matchAndRewrite(triton::DotScaledOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return AMD::convertScaledMFMA(op, adaptor, getTypeConverter(), rewriter);
+  }
+};
 } // namespace

 namespace mlir::triton::AMD {
@@ -55,5 +79,6 @@ void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                  ModuleAxisInfoAnalysis &axisInfoAnalysis,
                                  PatternBenefit benefit) {
   patterns.add<DotOpConversion>(typeConverter, benefit);
+  patterns.add<ScaledDotOpConversion>(typeConverter, benefit);
 }
 } // namespace mlir::triton::AMD
