Skip to content

Commit 340cbc6

Browse files
authored
[BACKEND] Fix a missed transpose optimization during refactor (#5236)
1 parent 16ce143 commit 340cbc6

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,16 +148,16 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
148148
LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp,
149149
PatternRewriter &rewriter) const override {
150150
// Match outerCvt(trans(innerCvt(x))).
151-
auto trans = cvtOp.getSrc().getDefiningOp<MemDescTransOp>();
151+
auto trans = cvtOp.getSrc().getDefiningOp<TransOp>();
152152
if (!trans || trans.getOrder() != ArrayRef<int32_t>{1, 0})
153153
return failure();
154154

155-
auto srcTy = dyn_cast<RankedTensorType>(trans.getSrc().getType());
155+
RankedTensorType srcTy = trans.getSrc().getType();
156156

157157
if (auto srcCvt = trans.getSrc().getDefiningOp<ConvertLayoutOp>()) {
158158
srcTy = srcCvt.getSrc().getType();
159159
}
160-
auto sharedLoadTy = cast<RankedTensorType>(cvtOp.getType());
160+
RankedTensorType sharedLoadTy = cvtOp.getType();
161161
auto cvtEncoding =
162162
dyn_cast<DotOperandEncodingAttr>(sharedLoadTy.getEncoding());
163163
if (!cvtEncoding)

test/TritonGPU/dot-operands.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,3 +276,22 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
276276
tt.return %r : tensor<128x64xf32, #mma>
277277
}
278278
}
279+
280+
// -----
281+
282+
#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
283+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
284+
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
285+
module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
286+
// CHECK-LABEL: mmav2_reorder_transpose
287+
// CHECK: triton_gpu.local_alloc
288+
// CHECK: triton_gpu.memdesc_trans
289+
// CHECK: triton_gpu.local_load
290+
// CHECK: tt.dot
291+
tt.func @mmav2_reorder_transpose(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
292+
%a = tt.trans %t {order = array<i32: 1, 0>} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked>
293+
%cv = triton_gpu.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
294+
%r = tt.dot %cv, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
295+
tt.return %r : tensor<128x64xf32, #mma>
296+
}
297+
}

0 commit comments

Comments (0)