Skip to content

Commit 7416ffc

Browse files
authored
[pick][release/3.5.x] Fold layout conversion for TMEM Store to fix perf drop for flex attn (#8366)
Picking #8353 from main branch. This is to fix perf drop for flex attention. See more details in original PR and #8328
1 parent 2766408 commit 7416ffc

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 26 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -72,6 +72,31 @@ bool isConvertTrivial(ConvertLayoutOp op) {
7272
// Canonicalizer
7373
//===----------------------------------------------------------------------===//
7474

75+
// tmem_store(cvt) -> tmem_store
76+
struct CanonicalizeConvertFromTMEMStore
77+
: public mlir::OpRewritePattern<nvidia_gpu::TMEMStoreOp> {
78+
using OpRewritePattern::OpRewritePattern;
79+
80+
mlir::LogicalResult
81+
matchAndRewrite(nvidia_gpu::TMEMStoreOp op,
82+
PatternRewriter &rewriter) const override {
83+
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
84+
if (!convert)
85+
return failure();
86+
87+
// bail for incompatible layouts
88+
auto cvtSrcType = convert.getSrc().getType();
89+
if (!nvidia_gpu::isDistributedLayoutTMemCompatible(
90+
op.getOperation(), cvtSrcType, op.getDst().getType())) {
91+
return failure();
92+
}
93+
94+
rewriter.modifyOpInPlace(
95+
op, [&]() { op.getSrcMutable().assign(convert.getSrc()); });
96+
return mlir::success();
97+
}
98+
};
99+
75100
// reshape(cvt) -> reshape
76101
struct CanonicalizeConvertFromReshape
77102
: public mlir::OpRewritePattern<triton::ReshapeOp> {
@@ -371,6 +396,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
371396
patterns.add<CanonicalizeConvertFromAlloc>(context);
372397
patterns.add<CanonicalizeConvertFromLocalStore>(context);
373398
patterns.add<CanonicalizeConvertFromSplit>(context);
399+
patterns.add<CanonicalizeConvertFromTMEMStore>(context);
374400
}
375401

376402
LogicalResult Fp4ToFpOp::verify() {

test/TritonGPU/canonicalize.mlir

Lines changed: 20 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -124,6 +124,26 @@ tt.func @test_canonicalize_convert_local_load(%arg0: !ttg.async.token) -> tensor
124124

125125
// -----
126126

127+
// Regression test for the tmem_store(convert_layout) fold: the CHECK lines
// below expect the #linear -> #blocked conversion to be folded away, with the
// store consuming the #linear value directly.
#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}>
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]], warp = [[32, 0], [64, 0], [16, 0]], block = []}>
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = false>
// CHECK-LABEL: test_canonicalize_convert_tmem_store
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
tt.func @test_canonicalize_convert_tmem_store(
  %arg0: tensor<128x64xbf16, #linear>,
  %arg1: !ttg.memdesc<128x64xbf16, #tmem, #ttng.tensor_memory, mutable>
) {
  %true = arith.constant true
  // After canonicalization no convert_layout should remain ...
  // CHECK-NOT: ttg.convert_layout
  %1 = ttg.convert_layout %arg0 : tensor<128x64xbf16, #linear> -> tensor<128x64xbf16, #blocked>
  // ... and the store should take its source in the original #linear layout.
  // CHECK: ttng.tmem_store %{{.*}} : tensor<128x64xbf16, #linear> ->
  ttng.tmem_store %1, %arg1, %true : tensor<128x64xbf16, #blocked> -> !ttg.memdesc<128x64xbf16, #tmem, #ttng.tensor_memory, mutable>
  tt.return
}
}
144+
145+
// -----
146+
127147
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
128148
#shared = #ttg.swizzled_shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
129149
#smem = #ttg.shared_memory

0 commit comments

Comments
 (0)