Commit d5ba6ac

[BACKEND][LAYOUT] Use LL for AMDMfma related layout conversions (#5210)
1 parent 9c7a8c6 commit d5ba6ac

2 files changed: +13, -11 lines

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 11 additions & 11 deletions
@@ -374,24 +374,24 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere)
     // -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be
     // completed before we can remove the layoutIsOK check:
-    // 1. Support for AMD's MFMA and WMMA
+    // 1. Support for AMD's WMMA
     std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
-      if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-        if (useLegacyMMAConversion) {
-          return false;
-        }
-        return true;
+      if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
+        return !useLegacyMMAConversion;
       }
       if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
-        if (auto nvidiaMma =
-                dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
-          if (useLegacyMMAConversion) {
-            return false;
-          }
+        auto parent = dotOperand.getParent();
+        if (isa<MmaEncodingTrait>(parent) && useLegacyMMAConversion) {
+          return false;
+        }
+        if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
           if (nvidiaMma.isAmpere()) {
             return true;
           }
         }
+        if (isa<AMDMfmaEncodingAttr>(parent)) {
+          return true;
+        }
         return false;
       }
       if (isa<BlockedEncodingAttr>(layout)) {
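
Taken together, the hunk above rewrites the layoutIsOK predicate so that AMDMfma layouts, and dot operands whose parent is AMDMfma, take the LinearLayout-based conversion unless useLegacyMMAConversion is set. The following standalone C++ sketch models the resulting decision logic under stated assumptions: the Layout enum, the Attr struct, the main() assertions, and the simplified handling of blocked layouts are illustrative stand-ins for the real MLIR encoding attributes (NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr, DotOperandEncodingAttr, BlockedEncodingAttr), not Triton's actual types or a verbatim copy of the pass.

#include <cassert>
#include <optional>

// Hypothetical stand-ins for the real MLIR encoding attributes.
enum class Layout { NvidiaMma, AMDMfma, AMDWmma, Blocked, DotOperand };

struct Attr {
  Layout kind;
  // Parent encoding when kind == DotOperand (models dotOperand.getParent()).
  std::optional<Layout> dotParent;
  // Models NvidiaMmaEncodingAttr::isAmpere() on the parent encoding.
  bool parentIsAmpere = false;
};

// Mirrors the post-commit layoutIsOK predicate (simplified; the real check
// also handles slice and other layouts not visible in this hunk).
bool layoutIsOK(const Attr &layout, bool useLegacyMMAConversion) {
  // NVIDIA MMA and AMD MFMA tensors share the LinearLayout path unless the
  // legacy MMA conversion is explicitly requested.
  if (layout.kind == Layout::NvidiaMma || layout.kind == Layout::AMDMfma)
    return !useLegacyMMAConversion;
  if (layout.kind == Layout::DotOperand) {
    Layout parent = layout.dotParent.value();
    // The real code uses isa<MmaEncodingTrait>(parent), which also covers
    // AMD's WMMA; the end result for WMMA is the same (it falls through).
    bool parentIsMma =
        parent == Layout::NvidiaMma || parent == Layout::AMDMfma;
    if (parentIsMma && useLegacyMMAConversion)
      return false;
    if (parent == Layout::NvidiaMma)
      return layout.parentIsAmpere; // only Ampere dot operands are accepted
    if (parent == Layout::AMDMfma)
      return true; // newly accepted by this commit
    return false;
  }
  // Blocked layouts were already handled; assumed to be accepted here.
  return layout.kind == Layout::Blocked;
}

int main() {
  // A dot operand with an MFMA parent now takes the LinearLayout path...
  assert(layoutIsOK({Layout::DotOperand, Layout::AMDMfma}, false));
  // ...but not when the legacy MMA conversion is forced.
  assert(!layoutIsOK({Layout::DotOperand, Layout::AMDMfma}, true));
  // AMD WMMA is still not supported (the remaining TODO item).
  assert(!layoutIsOK({Layout::AMDWmma}, false));
  return 0;
}

The sketch only illustrates the branch structure; the authoritative behavior is the diff itself.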

test/Conversion/amd/mfma-shortcut.mlir

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) {
     // CHECK-NOT: store
     // CHECK-NOT: load
+    // CHECK: llvm.return
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop>
     tt.return
   }
@@ -21,6 +22,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) {
     // CHECK: store
     // CHECK: load
+    // CHECK: llvm.return
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop>
     tt.return
   }
