Skip to content

Commit 12aa288

Browse files
authored
[BACKEND] Extended combiner regarding dot scaled ops (#9616)
When using `tl.dot_scaled`, changing the code from an explicit accumulator to Python's `+=` causes a big change in how many registers are used. In our tests, the `+=` version uses many more registers. This leads to lower occupancy, more pressure on memory bandwidth, and register spills. ### Version A (explicit `acc=acc`) — uses fewer registers ```python acc = tl.dot_scaled( a, a_scale, A_FMT, b, b_scale, B_FMT, acc=acc, out_dtype=tl.float32, ) ``` The generated `.amdgcn` code shows: ```asm .vgpr_count: 186 .vgpr_spill_count: 0 ``` - Much better performance ### Version B (`+=`) — uses more registers ```python acc += tl.dot_scaled( a, a_scale, A_FMT, b, b_scale, B_FMT, out_dtype=tl.float32, ) ``` The generated `.amdgcn` code shows: ```asm .vgpr_count: 256 .vgpr_spill_count: 45 ``` - Much worse performance ### Expected behavior Both versions do the same thing logically, so they should produce similar compiled code and use about the same number of registers. ### Comparison with `tl.dot` This problem does not happen with `tl.dot`. In that case, the compiler correctly detects the accumulation pattern and merges it, which avoids extra temporary values and keeps register usage low.
1 parent 756afc0 commit 12aa288

File tree

2 files changed

+38
-9
lines changed

2 files changed

+38
-9
lines changed

lib/Dialect/Triton/Transforms/Combine.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -231,17 +231,18 @@ class RankedReduceDescriptorLoads : public mlir::OpRewritePattern<ReshapeOp> {
231231
}
232232
};
233233

234-
template <typename OpTy>
235-
class CombineDotAddPattern : public mlir::OpRewritePattern<OpTy> {
234+
template <typename DotOpType, typename AddOpType>
235+
class CombineDotAddPattern : public mlir::OpRewritePattern<AddOpType> {
236236
public:
237-
using OpRewritePattern<OpTy>::OpRewritePattern;
237+
using OpRewritePattern<AddOpType>::OpRewritePattern;
238238

239239
mlir::LogicalResult
240-
matchAndRewrite(OpTy addOp, mlir::PatternRewriter &rewriter) const override {
241-
auto dotOp = addOp.getRhs().template getDefiningOp<DotOp>();
240+
matchAndRewrite(AddOpType addOp,
241+
mlir::PatternRewriter &rewriter) const override {
242+
auto dotOp = addOp.getRhs().template getDefiningOp<DotOpType>();
242243
bool isDotLHS = false;
243244
if (!dotOp) {
244-
dotOp = addOp.getLhs().template getDefiningOp<DotOp>();
245+
dotOp = addOp.getLhs().template getDefiningOp<DotOpType>();
245246
if (!dotOp) {
246247
return failure();
247248
}
@@ -252,7 +253,8 @@ class CombineDotAddPattern : public mlir::OpRewritePattern<OpTy> {
252253
}
253254
if (!isZero(dotOp.getC()))
254255
return failure();
255-
if constexpr (std::is_same_v<OpTy, arith::AddFOp>) {
256+
if constexpr (std::is_same_v<DotOpType, DotOp> &&
257+
std::is_same_v<AddOpType, arith::AddFOp>) {
256258
if (dotOp.getMaxNumImpreciseAcc() != 0) {
257259
return failure();
258260
}
@@ -270,8 +272,10 @@ class CombineDotAddPattern : public mlir::OpRewritePattern<OpTy> {
270272
// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
271273
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
272274
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
273-
using CombineDotAddIPattern = CombineDotAddPattern<arith::AddIOp>;
274-
using CombineDotAddFPattern = CombineDotAddPattern<arith::AddFOp>;
275+
using CombineDotAddIPattern = CombineDotAddPattern<DotOp, arith::AddIOp>;
276+
using CombineDotAddFPattern = CombineDotAddPattern<DotOp, arith::AddFOp>;
277+
using CombineDotScaledAddFPattern =
278+
CombineDotAddPattern<DotScaledOp, arith::AddFOp>;
275279

276280
} // anonymous namespace
277281

@@ -284,6 +288,7 @@ class CombineOpsPass : public impl::TritonCombineOpsBase<CombineOpsPass> {
284288

285289
patterns.add<CombineDotAddIPattern>(context);
286290
patterns.add<CombineDotAddFPattern>(context);
291+
patterns.add<CombineDotScaledAddFPattern>(context);
287292
patterns.add<CombineSelectMaskedLoadPattern>(context);
288293
patterns.add<CombineAddPtrPattern>(context);
289294
patterns.add<CombineBroadcastMulReducePattern>(context);

test/Triton/combine.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,30 @@ tt.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>) {
4343
}
4444

4545

46+
// CHECK-LABEL: @test_combine_scale_dot_add_pattern
47+
tt.func @test_combine_scale_dot_add_pattern() -> (tensor<128x128xf32>) {
48+
// CHECK-DAG: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf8E5M2>
49+
// CHECK-DAG: %[[sa:.*]] = arith.constant dense<1> : tensor<128x4xi8>
50+
// CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf8E5M2>
51+
// CHECK-DAG: %[[sb:.*]] = arith.constant dense<2> : tensor<128x4xi8>
52+
// CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
53+
%a = arith.constant dense<1.0> : tensor<128x128xf8E5M2>
54+
%sa = arith.constant dense<1> : tensor<128x4xi8>
55+
%b = arith.constant dense<2.0> : tensor<128x128xf8E5M2>
56+
%sb = arith.constant dense<2> : tensor<128x4xi8>
57+
%zero = arith.constant dense<0.0> : tensor<128x128xf32>
58+
%d = arith.constant dense<3.0> : tensor<128x128xf32>
59+
60+
%dot_out = tt.dot_scaled %a scale %sa, %b scale %sb, %zero lhs = e5m2 rhs = e5m2 {fastMath = false}
61+
: tensor<128x128xf8E5M2>, tensor<128x4xi8> * tensor<128x128xf8E5M2>, tensor<128x4xi8> -> tensor<128x128xf32>
62+
63+
// CHECK-NEXT: %[[res:.*]] = tt.dot_scaled %[[a]] scale %[[sa]], %[[b]] scale %[[sb]], %[[d]] lhs = e5m2 rhs = e5m2 {fastMath = false} : tensor<128x128xf8E5M2>, tensor<128x4xi8> * tensor<128x128xf8E5M2>, tensor<128x4xi8> -> tensor<128x128xf32>
64+
// CHECK-NEXT: tt.return %[[res]] : tensor<128x128xf32>
65+
%res = arith.addf %dot_out, %d : tensor<128x128xf32>
66+
tt.return %res : tensor<128x128xf32>
67+
}
68+
69+
4670
// CHECK-LABEL: @test_combine_dot_add_rev_pattern
4771
tt.func @test_combine_dot_add_rev_pattern() -> (tensor<128x128xf32>) {
4872
// CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>

0 commit comments

Comments
 (0)