
Commit 152ef2d

[AMD] Enable shared->MFMA dot operand conversion through LinearLayout (#4983)
This PR:
- Introduces a fallback from the normal TTG->LLVM converter for cases where it does not support a given local_load.
- Enables conversion of the MFMA dot operand layout to Linear Layout in the local_load pattern.
1 parent 258a5bc commit 152ef2d
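
The fallback works through the standard pattern-application flow: when the AMD-specific local_load lowering cannot handle an op, it returns failure(), and the dialect-conversion driver tries the remaining applicable patterns, so the LinearLayout-based lowering in MemoryOpToLLVM.cpp can pick the op up. Below is a minimal, self-contained C++ model of that priority-ordered fallback (illustrative only; the struct, the converter list, and the simplified signatures are stand-ins, not the Triton or MLIR API):

#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Stand-in for the op being lowered; not a Triton/MLIR type.
struct LocalLoadSketch {
  bool fitsOneMfmaInstr;
};

// Each converter returns a lowered result, or std::nullopt to decline so the
// next one is tried (modeling a pattern that returns failure()).
using Converter =
    std::function<std::optional<std::string>(const LocalLoadSketch &)>;

int main() {
  // Priority-ordered converters: AMD-specific lowering first, then the
  // generic LinearLayout-based lowering as the fallback.
  std::vector<Converter> converters = {
      [](const LocalLoadSketch &op) -> std::optional<std::string> {
        if (!op.fitsOneMfmaInstr)
          return std::nullopt; // decline: tensor smaller than one MFMA instr
        return "AMD-specific shared->dot-operand lowering";
      },
      [](const LocalLoadSketch &) -> std::optional<std::string> {
        return "generic LinearLayout lowering"; // always-applicable fallback
      },
  };

  LocalLoadSketch smallTensorLoad{/*fitsOneMfmaInstr=*/false};
  for (const auto &convert : converters) {
    if (auto lowered = convert(smallTensorLoad)) {
      std::cout << *lowered << "\n"; // prints the fallback lowering
      break;
    }
  }
}

Running the model with a load that does not fit one MFMA instruction prints the fallback result, mirroring how the small-tensor cases in this commit end up in the LinearLayout path.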

File tree

6 files changed: +129 −44 lines changed


lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 44 additions & 36 deletions
@@ -109,27 +109,30 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
       : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
   }

+  // FIXME [Dot LL]
+  // Do for all DotOperandEncodingAttr once we have LLs for all of them
+  static bool isSupportedDotOpLayout(Attribute layout) {
+    if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
+      if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
+        return mma.isAmpere() && dot.getKWidth() == 8;
+      }
+      if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
+        return true;
+    }
+    return false;
+  };
+
   LogicalResult
   matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MemDescType srcTy = op.getSrc().getType();
     RankedTensorType dstTy = op.getType();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    // FIXME [Dot LL]
-    // Do for all DotOperandEncodingAttr once we have LLs for all of them
-    auto isAmpereLargeKWidth = [](Attribute layout) {
-      if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
-        if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
-          return mma.isAmpere() && dot.getKWidth() == 8;
-        }
-      }
-      return false;
-    };
     if (isa<SharedEncodingAttr>(srcLayout) &&
         (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
             dstLayout) ||
-         isAmpereLargeKWidth(dstLayout))) {
+         isSupportedDotOpLayout(dstLayout))) {
       return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
     }
@@ -167,10 +170,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getResult().getType();
     auto dstShape = dstTy.getShape();
-    assert(dstShape.size() <= 2 &&
-           "Unexpected rank of ConvertLayout(shared->blocked)");
     auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
     auto dstLayout = dstTy.getEncoding();
+    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) &&
+           "Unexpected rank of ConvertLayout(shared->distributed)");
     auto inOrd = getOrder(srcSharedLayout);

     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
@@ -184,31 +187,36 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     // FIXME [Dot LL]
     // Ampere case
     // In this case, we need to pack the outputs into i32
-    if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
-      if (elemLlvmTy.isInteger(8)) {
-        auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
-          return or_(or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))),
-                     or_(shl(zext(i32_ty, a3), i32_val(16)),
-                         shl(zext(i32_ty, a4), i32_val(24))));
-        };
-        SmallVector<Value> outVals32(outVals.size() / 4);
-        for (int i = 0; i < outVals32.size(); ++i) {
-          outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1],
-                                outVals[4 * i + 2], outVals[4 * i + 3]);
-        }
-        outVals = outVals32;
-      } else {
-        assert(elemLlvmTy.isBF16() && "Unexpected element type");
-        auto concat = [&](Value a, Value b) {
-          return or_(zext(i32_ty, bitcast(a, i16_ty)),
-                     shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
-        };
+    if (auto dotOp = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding())) {
+      if (auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOp.getParent())) {
+        if (parent.isAmpere()) {
+          if (elemLlvmTy.isInteger(8)) {
+            auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
+              return or_(
+                  or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))),
+                  or_(shl(zext(i32_ty, a3), i32_val(16)),
+                      shl(zext(i32_ty, a4), i32_val(24))));
+            };
+            SmallVector<Value> outVals32(outVals.size() / 4);
+            for (int i = 0; i < outVals32.size(); ++i) {
+              outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1],
+                                    outVals[4 * i + 2], outVals[4 * i + 3]);
+            }
+            outVals = outVals32;
+          } else {
+            assert(elemLlvmTy.isBF16() && "Unexpected element type");
+            auto concat = [&](Value a, Value b) {
+              return or_(zext(i32_ty, bitcast(a, i16_ty)),
+                         shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
            };

-        SmallVector<Value> outVals32(outVals.size() / 2);
-        for (int i = 0; i < outVals32.size(); ++i) {
-          outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
+            SmallVector<Value> outVals32(outVals.size() / 2);
+            for (int i = 0; i < outVals32.size(); ++i) {
+              outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
+            }
+            outVals = outVals32;
+          }
         }
-        outVals = outVals32;
       }
     }
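
The concat lambdas above build one i32 from four i8 values (or from two bf16 values bitcast to i16). The same bit layout in plain, runnable C++, kept separate from the IR-builder calls (a sketch of the arithmetic only; the helper names are illustrative):

#include <cstdint>
#include <cstdio>

// Pack four 8-bit values into one 32-bit word, least-significant byte first,
// mirroring the i8 concat lambda: a1 | a2 << 8 | a3 << 16 | a4 << 24.
uint32_t concatI8(uint8_t a1, uint8_t a2, uint8_t a3, uint8_t a4) {
  return uint32_t(a1) | (uint32_t(a2) << 8) | (uint32_t(a3) << 16) |
         (uint32_t(a4) << 24);
}

// Pack two 16-bit patterns (e.g. bf16 values bitcast to i16) into one 32-bit
// word, mirroring the bf16 concat lambda: a | b << 16.
uint32_t concatBf16Bits(uint16_t a, uint16_t b) {
  return uint32_t(a) | (uint32_t(b) << 16);
}

int main() {
  std::printf("%08x\n", (unsigned)concatI8(0x11, 0x22, 0x33, 0x44)); // 44332211
  std::printf("%08x\n", (unsigned)concatBf16Bits(0x3f80, 0x4000));   // 40003f80
}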

Lines changed: 8 additions & 7 deletions
@@ -1,18 +1,19 @@
-// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s
+// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --mlir-print-debuginfo --mlir-pretty-debuginfo| FileCheck %s

 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
 #mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}>
 #shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 544 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: @local_load_offset
   tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) {
-    %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked>
-    %1 = triton_gpu.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory>
+    %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1)
+    %1 = triton_gpu.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> loc(#loc2)
     // This catches base ptr calculation in the computeBasePtr, checks if the gep has correct element type.
-    // CHECK: llvm.sub
-    // CHECK-NEXT: llvm.getelementptr
-    // CHECK-SAME: (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
-    %2 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
+    // CHECK: llvm.getelementptr {{.*}} (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 local_load:3:0
+    %2 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3)
     tt.return
   }
 }
+#loc1 = loc("conert_layout":1:0)
+#loc2 = loc("local_alloc":2:0)
+#loc3 = loc("local_load":3:0)

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 28 additions & 0 deletions
@@ -34,3 +34,31 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+// -----
+
+// Smoke test to check that mfma 32 and dot operand layouts can work with small tensors, for example with shape 16x16
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
+#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>
+#dotop1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: small_mfma_tensor_conversions
+  tt.func public @small_mfma_tensor_conversions(%arg0: tensor<16x16xf16, #mfma>, %arg1: tensor<16x16x!tt.ptr<f32>, #mfma>) {
+    // CHECK-NOT: triton_gpu.convert_layout
+    %0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #mfma>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory>
+    // CHECK-4: store {{.*}} vector<4xf16>
+    %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dotop0>
+    // CHECK-2: load {{.*}} vector<4xf16>
+    %2 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dotop1>
+    // CHECK-8: load {{.*}} vector<1xf16>
+    %3 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #mfma>
+    // CHECK-4: load {{.*}} vector<4xf16>
+    %4 = tt.fp_to_fp %3 : tensor<16x16xf16, #mfma> -> tensor<16x16xf32, #mfma>
+
+    %5 = tt.dot %1, %2, %4 : tensor<16x16xf16, #dotop0> * tensor<16x16xf16, #dotop1> -> tensor<16x16xf32, #mfma>
+    // Store result to prevent DCE from removing all conversion related code
+    %6 = triton_gpu.local_alloc %5 : (tensor<16x16xf32, #mfma>) -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory>
+    tt.return
+  }
+}
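
For scale, the tile in this smoke test is tiny: 16x16 f16 is only 256 elements, so even if a single 64-lane wavefront covered the whole tile, each lane would hold just 4 elements — which is exactly the small-tensor regime this PR routes through the LinearLayout path. A trivial back-of-envelope check of that count (an illustrative sketch, not part of the commit):

#include <cstdio>

int main() {
  // Hypothetical per-lane element count for the 16x16 tile in the test above:
  // 256 f16 elements spread across one 64-lane wavefront -> 4 per lane.
  const int rows = 16, cols = 16, lanesPerWave = 64;
  const int elemsPerLane = rows * cols / lanesPerWave;
  std::printf("f16 elements per lane: %d\n", elemsPerLane); // prints 4
}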

test/TritonGPU/combine.mlir

Lines changed: 36 additions & 0 deletions
@@ -2649,3 +2649,39 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 :
     tt.return
   }
 }
+
+// -----
+
+// Minimized reproducer for compiler crash during remove layouts conversions pass:
+// If dot result transformed into tensor with shape smaller than one MFMA instruction size, it triggers various asserts.
+// This is a smoke test that checks that compiler do not crash.
+//
+// CHECK-LABEL: small_tensor_mfma
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}>
+#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
+#mma1 = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
+module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @small_tensor_mfma(%arg0: !tt.ptr<f32>) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
+    %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+    %cst_2 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
+    %cst_3 = arith.constant dense<1.230000e+02> : tensor<32x16xf32, #mma1>
+    %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
+    %1 = triton_gpu.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
+    %2 = "tt.reduce" (%1) ({
+    ^bb0(%arg1: f32, %arg2: f32):
+      %3 = arith.addf %arg1, %arg2 : f32
+      tt.reduce.return %3 : f32
+    }) {axis = 1 : i32} : (tensor<32x32xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %4 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xf32, #blocked>
+    %5 = tt.broadcast %4 : tensor<32x1xf32, #blocked> -> tensor<32x16xf32, #blocked>
+    %6 = triton_gpu.convert_layout %5 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>>
+    %7 = tt.dot %cst_2, %6, %cst_3 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<32x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<32x16xf32, #mma1>
+    %addr = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x16x!tt.ptr<f32>, #blocked>
+    %8 = triton_gpu.convert_layout %7 : tensor<32x16xf32, #mma1> -> tensor<32x16xf32, #blocked>
+    tt.store %addr, %8 : tensor<32x16x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 7 additions & 1 deletion
@@ -50,7 +50,11 @@ struct LocalLoadOpConversion
   }

 private:
-  // shared -> dot_operand if the result layout is mfma
+  /// Lower ttg.local_load in dot operand layout if the operand parent layout is
+  /// MFMA or WMMA.
+  ///
+  /// \returns value with packed loaded values or empty value if this local_load
+  /// is not supproted.
   Value lowerSharedToDotOperandMMA(
       triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor,
       const LLVMTypeConverter *typeConverter,
@@ -104,6 +108,8 @@ struct LocalLoadOpConversion
     isOuter = K == 1;
     Value res = lowerSharedToDotOperandMMA(op, adaptor, typeConverter, rewriter,
                                            dotOperandLayout, isOuter);
+    if (!res)
+      return failure();
     rewriter.replaceOp(op, res);
     return success();
   }

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp

Lines changed: 6 additions & 0 deletions
@@ -231,6 +231,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
     mfmaInstrK = elemsPerInstr[kDimIdx];
   }

+  if (mfmaInstrNonK > shape[nonKDimIdx] || mfmaInstrK > shape[kDimIdx]) {
+    // This pattern does not support cases tensor shape is smaller than
+    // one instruction size, it will be processed by LinearLayout converter
+    return Value();
+  }
+
   auto numReps = mfmaLayout.getRepForOperand(shape, kWidth, opIdx);
   auto numRepNonK = numReps[nonKDimIdx];
   auto numRepK = numReps[kDimIdx];
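
The early exit added above compares the per-operand MFMA instruction tile against the tensor shape and returns an empty Value when the tensor is smaller than one instruction; the calling pattern in ConvertLayoutOpToLLVM.cpp then returns failure(), and the LinearLayout-based converter handles the op instead. A small standalone C++ sketch of that guard-and-decline contract (the names and std::optional stand in for the real MLIR Value/LogicalResult types):

#include <array>
#include <cstdint>
#include <cstdio>
#include <optional>

struct Lowered { int64_t numTiles; }; // stand-in for the packed result Value

// Decline (analogous to `return Value();`) when the tensor is smaller than one
// MFMA instruction along either dimension; otherwise report how many
// instruction tiles cover the tensor.
std::optional<Lowered> convertLayoutSketch(std::array<int64_t, 2> instrShape,
                                           std::array<int64_t, 2> tensorShape) {
  if (instrShape[0] > tensorShape[0] || instrShape[1] > tensorShape[1])
    return std::nullopt;
  return Lowered{(tensorShape[0] / instrShape[0]) *
                 (tensorShape[1] / instrShape[1])};
}

int main() {
  // 32x32 MFMA instruction, 16x16 tensor: decline, so the calling pattern
  // returns failure() and the LinearLayout converter takes over.
  if (!convertLayoutSketch({32, 32}, {16, 16}))
    std::puts("unsupported here -> pattern falls back");
  // 32x32 instruction, 64x64 tensor: handled by this converter.
  if (auto res = convertLayoutSketch({32, 32}, {64, 64}))
    std::printf("lowered, %lld instruction tiles\n", (long long)res->numTiles);
}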

0 commit comments
