
Commit 8a16d88

[AMD][Backend] Enable XF32 (TF32) Support for CDNA3 GPUs (#5637)
# Overview

AMD MI300-series GPUs support XF32 (TF32) mfma instructions in hardware, so we should use them when they are available. TF32 provides a ~1.4x speedup over FP32 for matmuls in some examples.

# BC breaking: changing the default input precision behavior

According to the [Triton docs](https://github.com/triton-lang/triton/blob/6556ec6050649e1fc42feb05a62ab9cc6908a722/python/triton/language/core.py#L1714), "For devices that do have tensor cores, the default precision is tf32". Enabling XF32 (TF32) on MI300 is therefore backward-compatibility (BC) breaking: f32 dot ops that previously executed with full FP32 precision will now execute with TF32 by default.

# Testing

I've added lit tests and enabled TF32 for MI300 in the Python unit tests.
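As a usage sketch (not part of this commit; the kernel, tile shapes, and launch below are illustrative assumptions), this is where `input_precision` enters a Triton program. On MI300, an f32 `tl.dot` now takes the xf32 path by default; passing `"ieee"` keeps the previous full-FP32 behavior.

```python
# Hypothetical example, not from this commit: a single-tile f32 dot kernel
# showing where input_precision is chosen.
import torch
import triton
import triton.language as tl


@triton.jit
def dot_kernel(a_ptr, b_ptr, c_ptr,
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # Row-major tiles: a is (BLOCK_M, BLOCK_K), b is (BLOCK_K, BLOCK_N).
    a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    # On CDNA3 (MI300) "tf32" lowers this f32 dot to xf32 mfma instructions;
    # input_precision="ieee" keeps the pre-change full-FP32 behavior.
    c = tl.dot(a, b, input_precision="tf32")
    tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], c)


a = torch.randn((32, 64), device="cuda", dtype=torch.float32)
b = torch.randn((64, 32), device="cuda", dtype=torch.float32)
c = torch.empty((32, 32), device="cuda", dtype=torch.float32)
dot_kernel[(1, )](a, b, c, BLOCK_M=32, BLOCK_N=32, BLOCK_K=64)
```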
1 parent: 4a80233

File tree

8 files changed: +96, -22 lines


python/test/unit/language/test_core.py

Lines changed: 2 additions & 1 deletion
@@ -3407,9 +3407,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
         pytest.skip("Only test out_dtype=float16 on devices with sm >=80")
     if capability[0] < 9 and in_dtype == 'float8e4nv':
         pytest.skip("float8e4nv not supported on sm <= 80")
+
     if is_hip() and (in_dtype == 'float8e4nv' or in_dtype == 'float8e5'):
         pytest.skip("float8e4nv and float8e5 not supported on HIP")
-    if is_hip() and (input_precision != "ieee"):
+    if is_hip() and not ((input_precision == "ieee") or (input_precision == "tf32" and is_hip_mi300())):
         pytest.skip(f"{input_precision} not supported on HIP")
     if is_hip() and (kpack == 2 and in_dtype == 'int8' and K < 64):
         pytest.skip("kpack too large for K")

python/triton/language/core.py

Lines changed: 1 addition & 1 deletion
@@ -1715,7 +1715,7 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
         the device does not have Tensor Cores or the inputs are not of dtype f32,
         this option is ignored. For devices that do have tensor cores, the
         default precision is tf32.
-    :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`.
+    :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`.
     :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32".
         Only one of :code:`input_precision` and :code:`allow_tf32` can be
         specified (i.e. at least one must be :code:`None`).
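Since the documented default stays `"tf32"`, kernels that rely on full FP32 input precision for f32 operands on MI300 now need to request `"ieee"` explicitly; a minimal sketch (operand names are illustrative):

```python
# Inside a @triton.jit kernel body, where a and b are f32 tiles:
c_exact = tl.dot(a, b, input_precision="ieee")  # full FP32 inputs, the pre-change behavior on MI300
c_fast = tl.dot(a, b, input_precision="tf32")   # xf32 mfma path on CDNA3, now the default for f32
```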

test/TritonGPU/amd/mfma-xf32.mlir

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s
+
+// CHECK-LABEL:mfma_xf32
+
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
+module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_xf32(
+    %arg0: tensor<64x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>,
+    %arg1: tensor<128x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>) {
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
+    // Check that we generate xf32 instructions
+    // CHECK: rocdl.mfma.f32.16x16x8.xf32
+    %dot = tt.dot %arg0, %arg1, %cst_0, inputPrecision = tf32 :
+      tensor<64x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<128x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x64xf32, #mma>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL:mfma_not_xf32
+
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
+module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_not_xf32(
+    %arg0: tensor<64x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>,
+    %arg1: tensor<128x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>) {
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
+    // Check that we don't generate xf32 instructions if the input precision is "ieee"
+    // CHECK: rocdl.mfma.f32.16x16x4f32
+    %dot = tt.dot %arg0, %arg1, %cst_0, inputPrecision = ieee :
+      tensor<64x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<128x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x64xf32, #mma>
+    tt.return
+  }
+}

third_party/amd/backend/compiler.py

Lines changed: 6 additions & 0 deletions
@@ -112,6 +112,12 @@ def __init__(self, target: GPUTarget) -> None:
     def parse_options(self, opts) -> Any:
         args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", self.target.arch)}

+        # Enable XF32 (TF32) for CDNA3 GPUs
+        if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
+            allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
+            allowed_dot_input_precisions.update({'tf32'})
+            args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))
+
         if "supported_fp8_dtypes" not in opts:
             supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
             if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
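For code that wants to choose a precision at run time, a hypothetical helper below mirrors the same arch gate. The `triton.runtime.driver.active.get_current_target()` call and its `backend`/`arch` fields are assumptions about Triton's runtime API and may differ between versions.

```python
# Hypothetical helper, not part of this commit: mirror the compiler's CDNA3 gate
# when deciding which input_precision to request from tl.dot.
import triton


def pick_dot_input_precision() -> str:
    target = triton.runtime.driver.active.get_current_target()
    # Only CDNA3 (gfx940/941/942, i.e. the MI300 series) has xf32 mfma instructions.
    if target.backend == "hip" and target.arch in ("gfx940", "gfx941", "gfx942"):
        return "tf32"
    return "ieee"
```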

third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h

Lines changed: 2 additions & 1 deletion
@@ -13,6 +13,7 @@ namespace mlir {

 enum class MfmaTypeId : uint32_t {
   Fp32TyId = 0,
+  Xf32TyId,
   Fp16TyId,
   Bf16TyId,
   I8TyId,
@@ -79,7 +80,7 @@ class MfmaInsn {
 public:
   static FailureOr<MfmaInsn> selectMfma(unsigned mDim, unsigned nDim,
                                         Type elementTypeA, Type elementTypeB,
-                                        int mfmaVersion);
+                                        int mfmaVersion, bool allowXF32);
   MfmaInsn(Type elementTypeA, Type elementTypeB, const MfmaInsnAttr &attr);
   unsigned getKDim();
   unsigned getMDim();

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 18 additions & 8 deletions
@@ -183,9 +183,11 @@ struct DotOpMFMAConversionHelper {
     auto elemTyA = aTensorTy.getElementType();
     auto elemTyB = bTensorTy.getElementType();

+    bool allowXF32 =
+        op.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3;
     StringRef mfmaInsnName;
-    auto maybeMfmaInsn =
-        MfmaInsn::selectMfma(mDim, nDim, elemTyA, elemTyB, mfmaVersion);
+    auto maybeMfmaInsn = MfmaInsn::selectMfma(mDim, nDim, elemTyA, elemTyB,
+                                              mfmaVersion, allowXF32);
     if (failed(maybeMfmaInsn))
       llvm::report_fatal_error("No match found in MFMA database\n");

@@ -195,6 +197,11 @@ struct DotOpMFMAConversionHelper {
     auto aEncoding = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
     auto bEncoding = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding());
     int kWidth = aEncoding.getKWidth();
+
+    // If we are using XF32, the kWidth (and kBase) is double that of F32.
+    if (aTensorTy.getElementType().isF32() && allowXF32)
+      kWidth *= 2;
+
     auto rank = aTensorTy.getShape().size();
     const auto kDimOperandSize = aTensorTy.getShape()[rank - 1];
     const auto kDimInstrSize = mfmaLayout.getInstrShapeForOperand(kWidth, 0)[1];
@@ -216,17 +223,17 @@ struct DotOpMFMAConversionHelper {

     auto operandA = getValuesFromDotOperandLayoutStruct(
         loadedA, numRepB, numRepM, numRepK, kWidth, kBase,
-        aTensorTy.getElementType());
+        aTensorTy.getElementType(), allowXF32);
     auto operandB = getValuesFromDotOperandLayoutStruct(
         loadedB, numRepB, numRepN, numRepK, kWidth, kBase,
-        aTensorTy.getElementType());
+        aTensorTy.getElementType(), allowXF32);

     auto dstElemTy = dTensorTy.getElementType();
     auto fc = unpackLLElements(loc, loadedC, rewriter);

     unsigned warpSize = triton::gpu::getWarpSize(mfmaLayout);
     // compute number of output elements that each thread holds for one MFMA
-    // instruction. subBlocks
+    // instruction.
     const int subBlocks =
         getNumSubmatrices(aTensorTy.getElementType(), mDim, nDim);
     auto elemsPerVec = mDim * nDim * subBlocks / warpSize;
@@ -370,7 +377,8 @@ struct DotOpMFMAConversionHelper {
   /// appropriate for mfma instructions
   SmallVector<ValueTable>
   getValuesFromDotOperandLayoutStruct(Value value, int batch, int n0, int n1,
-                                      int kWidth, int kBase, Type type) const {
+                                      int kWidth, int kBase, Type type,
+                                      bool allowXF32) const {
     auto elems = unpackLLElements(loc, value, rewriter);
     int kpack = kWidth / kBase;
     SmallVector<ValueTable> dotOpVals(kpack);
@@ -388,13 +396,15 @@ struct DotOpMFMAConversionHelper {
         }

         Value convertedElems;
-        if (type.isF32()) {
+        if (type.isF32() && !allowXF32) {
          for (int k = 0; k < kpack; ++k)
            dotOpVals[k][{b, i, j}] =
                extract_element(type, rawElems, i32_val(k));
         } else {
          SmallVector<Value> vals;
-          if (type.getIntOrFloatBitWidth() == 8) {
+          if (type.isF32() && allowXF32) {
+            vals = extractOperands(rawElems, kWidth, kBase, f32_ty);
+          } else if (type.getIntOrFloatBitWidth() == 8) {
            vals = extractOperands(rawElems, kWidth, kBase, i8_ty);
          } else if (type.isBF16()) {
            vals = extractOperands(rawElems, kWidth, kBase, bf16_ty);
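A rough sanity check on the "kWidth (and kBase) is double that of F32" comment above, using the instruction shapes added to the MfmaGroup table later in this commit (the per-lane operand count is an inference, not something stated in the diff):

$$\frac{M \cdot K}{\text{warp size}} = \frac{32 \cdot 2}{64} = 1 \ \text{A element per lane for mfma\_f32\_32x32x2f32}, \qquad \frac{32 \cdot 4}{64} = 2 \ \text{for mfma\_f32\_32x32x4\_xf32}$$

Each lane therefore supplies twice as many f32 elements per xf32 instruction, which is why kBase, and with it kWidth = kpack * kBase, doubles on the xf32 path.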

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 11 additions & 7 deletions
@@ -101,7 +101,7 @@ warpsPerTileWMMA(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps) {
 FailureOr<MfmaInsn> chooseMfmaInstruction(RankedTensorType cType,
                                           Type aElemType, Type bElemType,
                                           int inputKSize, int mfmaVersion,
-                                          int enforcedNonKDim) {
+                                          bool allowXF32, int enforcedNonKDim) {
   // number of matrix elements along k dim per one MFMA intruction
   unsigned kDim = 0;

@@ -128,8 +128,8 @@ FailureOr<MfmaInsn> chooseMfmaInstruction(RankedTensorType cType,
   if (mDim == 0 || nDim == 0)
     return failure();

-  auto maybeMfmaInsn =
-      MfmaInsn::selectMfma(mDim, nDim, aElemType, bElemType, mfmaVersion);
+  auto maybeMfmaInsn = MfmaInsn::selectMfma(mDim, nDim, aElemType, bElemType,
+                                            mfmaVersion, allowXF32);
   if (failed(maybeMfmaInsn))
     llvm::report_fatal_error("No match found in MFMA database\n");

@@ -146,19 +146,23 @@ FailureOr<MfmaInsn> chooseMfmaInstruction(RankedTensorType cType,
 FailureOr<MfmaInsn> chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion,
                                           int nonKDim) {
   RankedTensorType aType = dot.getA().getType();
+  bool allowXF32 =
+      dot.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3;
   return chooseMfmaInstruction(dot.getC().getType(), aType.getElementType(),
                                dot.getB().getType().getElementType(),
-                               aType.getShape().back(), mfmaVersion, nonKDim);
+                               aType.getShape().back(), mfmaVersion, allowXF32,
+                               nonKDim);
 }

 FailureOr<MfmaInsn> chooseMfmaInstruction(tt::DotScaledOp dot, int mfmaVersion,
                                           int nonKDim, bool useFp16) {
   // For scaled dot, we handle it with fp16 or bf16 emulation for now.
   Builder b(dot.getContext());
   Type elemType = useFp16 ? b.getF16Type() : b.getBF16Type();
-  return chooseMfmaInstruction(
-      dot.getC().getType(), /*aElemType=*/elemType, /*bElemType=*/elemType,
-      dot.getLhs().getType().getShape().back(), mfmaVersion, nonKDim);
+  return chooseMfmaInstruction(dot.getC().getType(), /*aElemType=*/elemType,
+                               /*bElemType=*/elemType,
+                               dot.getLhs().getType().getShape().back(),
+                               mfmaVersion, /*allowXF32=*/false, nonKDim);
 }

 using OperandTypesVector = SmallVector<Type, 4>;

third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp

Lines changed: 17 additions & 4 deletions
@@ -3,9 +3,13 @@
 namespace mlir {

 static MfmaTypeId chooseAppropriateMfmaId(mlir::Type dataTypeA,
-                                          mlir::Type dataTypeB) {
+                                          mlir::Type dataTypeB,
+                                          bool allowXF32) {
   if (dataTypeA.isF32() && dataTypeB.isF32()) {
-    return MfmaTypeId::Fp32TyId;
+    if (allowXF32)
+      return MfmaTypeId::Xf32TyId;
+    else
+      return MfmaTypeId::Fp32TyId;
   }
   if (dataTypeA.isF16() && dataTypeB.isF16()) {
     return MfmaTypeId::Fp16TyId;
@@ -39,6 +43,13 @@ using MfmaInsnGroupMap = llvm::DenseMap<MfmaInsnGroupSelectKey, MfmaInsnAttr,

 auto getMfmaInsnGroupAttrMap = []() -> const MfmaInsnGroupMap & {
   static MfmaInsnGroupMap MfmaInsnMap{
+      // xf32
+      // mfma.xf32.16x16x8xf32
+      {{16, 16, MfmaTypeId::Xf32TyId, 3},
+       {16, 16, 8, 2, ROCDL::mfma_f32_16x16x8_xf32::getOperationName()}},
+      // mfma.xf32.32x32x4.xf32
+      {{32, 32, MfmaTypeId::Xf32TyId, 3},
+       {32, 32, 4, 2, ROCDL::mfma_f32_32x32x4_xf32::getOperationName()}},
       // f32
       // mfma_f32_32x32x2f32
       {{32, 32, MfmaTypeId::Fp32TyId, 1},
@@ -219,6 +230,7 @@ std::pair<mlir::Type, mlir::Type> TypesFromMfmaId(mlir::MLIRContext *ctx,
   auto f32 = Float32Type::get(ctx);
   auto i8 = IntegerType::get(ctx, 8, IntegerType::Signed);
   switch (id) {
+  case MfmaTypeId::Xf32TyId:
   case MfmaTypeId::Fp32TyId:
     return {f32, f32};
   case MfmaTypeId::Fp16TyId:
@@ -242,9 +254,10 @@ std::pair<mlir::Type, mlir::Type> TypesFromMfmaId(mlir::MLIRContext *ctx,

 FailureOr<MfmaInsn> MfmaInsn::selectMfma(unsigned mDim, unsigned nDim,
                                          Type elementTypeA, Type elementTypeB,
-                                         int mfmaVersion) {
+                                         int mfmaVersion, bool allowXF32) {
   auto mfmaInsnAttrMap = getMfmaInsnGroupAttrMap();
-  MfmaTypeId mfmaId = chooseAppropriateMfmaId(elementTypeA, elementTypeB);
+  MfmaTypeId mfmaId =
+      chooseAppropriateMfmaId(elementTypeA, elementTypeB, allowXF32);
   MfmaInsnGroupSelectKey key = {mDim, nDim, mfmaId, mfmaVersion};
   auto it = mfmaInsnAttrMap.find(key);
   if (it == mfmaInsnAttrMap.end())
