Skip to content

Commit 9f30721

Browse files
author
git apple-llvm automerger
committed
Merge commit 'ca1ff80a16f1' from llvm.org/main into next
2 parents 7e4fd47 + ca1ff80 commit 9f30721

File tree

4 files changed

+175
-0
lines changed

4 files changed

+175
-0
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,12 @@ def MathExpandOpsPass : Pass<"math-expand-ops"> {
6464
];
6565
}
6666

67+
def MathSincosFusionPass : Pass<"math-sincos-fusion"> {
68+
let summary = "Fuse sin and cos operations.";
69+
let description = [{
70+
Fuse sin and cos operations into a sincos operation.
71+
}];
72+
let dependentDialects = ["math::MathDialect"];
73+
}
74+
6775
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES

mlir/lib/Dialect/Math/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRMathTransforms
33
ExpandOps.cpp
44
ExtendToSupportedTypes.cpp
55
PolynomialApproximation.cpp
6+
SincosFusion.cpp
67
UpliftToFMA.cpp
78

89
ADDITIONAL_HEADER_DIRS
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
//===- SincosFusion.cpp - Fuse sin/cos into sincos -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Math/IR/Math.h"
10+
#include "mlir/Dialect/Math/Transforms/Passes.h"
11+
#include "mlir/IR/PatternMatch.h"
12+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
14+
using namespace mlir;
15+
using namespace mlir::math;
16+
17+
namespace {
18+
19+
/// Fuse a math.sin and math.cos in the same block that use the same operand and
20+
/// have identical fastmath flags into a single math.sincos.
21+
struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
22+
using Base::Base;
23+
24+
LogicalResult matchAndRewrite(math::SinOp sinOp,
25+
PatternRewriter &rewriter) const override {
26+
Value operand = sinOp.getOperand();
27+
mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath();
28+
29+
math::CosOp cosOp = nullptr;
30+
sinOp->getBlock()->walk([&](math::CosOp op) {
31+
if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) {
32+
cosOp = op;
33+
return WalkResult::interrupt();
34+
}
35+
return WalkResult::advance();
36+
});
37+
38+
if (!cosOp)
39+
return failure();
40+
41+
Operation *firstOp = sinOp->isBeforeInBlock(cosOp) ? sinOp.getOperation()
42+
: cosOp.getOperation();
43+
rewriter.setInsertionPoint(firstOp);
44+
45+
Type elemType = sinOp.getType();
46+
auto sincos = math::SincosOp::create(rewriter, firstOp->getLoc(),
47+
TypeRange{elemType, elemType}, operand,
48+
sinOp.getFastmathAttr());
49+
50+
rewriter.replaceOp(sinOp, sincos.getSin());
51+
rewriter.replaceOp(cosOp, sincos.getCos());
52+
return success();
53+
}
54+
};
55+
56+
} // namespace
57+
58+
namespace mlir::math {
59+
#define GEN_PASS_DEF_MATHSINCOSFUSIONPASS
60+
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
61+
} // namespace mlir::math
62+
63+
namespace {
64+
65+
struct MathSincosFusionPass final
66+
: math::impl::MathSincosFusionPassBase<MathSincosFusionPass> {
67+
using MathSincosFusionPassBase::MathSincosFusionPassBase;
68+
69+
void runOnOperation() override {
70+
RewritePatternSet patterns(&getContext());
71+
patterns.add<SincosFusionPattern>(&getContext());
72+
73+
GreedyRewriteConfig config;
74+
if (failed(
75+
applyPatternsGreedily(getOperation(), std::move(patterns), config)))
76+
return signalPassFailure();
77+
}
78+
};
79+
80+
} // namespace
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// RUN: mlir-opt -math-sincos-fusion %s | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @sincos_fusion(
4+
// CHECK-SAME: %[[ARG0:.*]]: f32,
5+
// CHECK-SAME: %[[ARG1:.*]]: f32) -> (f32, f32, f32, f32) {
6+
// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32
7+
// CHECK: %[[VAL_2:.*]], %[[VAL_3:.*]] = math.sincos %[[ARG1]] : f32
8+
// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_2]] : f32, f32, f32, f32
9+
// CHECK: }
10+
func.func @sincos_fusion(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) {
11+
%0 = math.sin %arg0 : f32
12+
%1 = math.cos %arg0 : f32
13+
14+
%2 = math.cos %arg1 : f32
15+
%3 = math.sin %arg1 : f32
16+
17+
func.return %0, %1, %2, %3 : f32, f32, f32, f32
18+
}
19+
20+
func.func private @sink(%arg0 : f32)
21+
22+
// CHECK: func.func private @sink(f32)
23+
// CHECK-LABEL: func.func @sincos_ensure_ssa_dominance(
24+
// CHECK-SAME: %[[ARG0:.*]]: f32,
25+
// CHECK-SAME: %[[ARG1:.*]]: f32) -> (f32, f32, f32, f32) {
26+
// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32
27+
// CHECK: call @sink(%[[VAL_0]]) : (f32) -> ()
28+
// CHECK: %[[VAL_2:.*]], %[[VAL_3:.*]] = math.sincos %[[ARG1]] : f32
29+
// CHECK: call @sink(%[[VAL_3]]) : (f32) -> ()
30+
// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_2]] : f32, f32, f32, f32
31+
// CHECK: }
32+
func.func @sincos_ensure_ssa_dominance(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) {
33+
%0 = math.sin %arg0 : f32
34+
func.call @sink(%0) : (f32) -> ()
35+
%1 = math.cos %arg0 : f32
36+
%2 = math.cos %arg1 : f32
37+
func.call @sink(%2) : (f32) -> ()
38+
%3 = math.sin %arg1 : f32
39+
func.return %0, %1, %2, %3 : f32, f32, f32, f32
40+
}
41+
42+
// CHECK-LABEL: func.func @sincos_fusion_no_match_fmf(
43+
// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
44+
// CHECK: %[[VAL_0:.*]] = math.sin %[[ARG0]] fastmath<contract> : f32
45+
// CHECK: %[[VAL_1:.*]] = math.cos %[[ARG0]] : f32
46+
// CHECK: return %[[VAL_0]], %[[VAL_1]] : f32, f32
47+
// CHECK: }
48+
func.func @sincos_fusion_no_match_fmf(%arg0 : f32) -> (f32, f32) {
49+
%0 = math.sin %arg0 fastmath<contract> : f32
50+
%1 = math.cos %arg0 : f32
51+
func.return %0, %1 : f32, f32
52+
}
53+
54+
// CHECK-LABEL: func.func @sincos_no_fusion_different_block(
55+
// CHECK-SAME: %[[ARG0:.*]]: f32,
56+
// CHECK-SAME: %[[ARG1:.*]]: i1) -> f32 {
57+
// CHECK: %[[VAL_0:.*]] = scf.if %[[ARG1]] -> (f32) {
58+
// CHECK: %[[VAL_1:.*]] = math.sin %[[ARG0]] : f32
59+
// CHECK: scf.yield %[[VAL_1]] : f32
60+
// CHECK: } else {
61+
// CHECK: %[[VAL_2:.*]] = math.cos %[[ARG0]] : f32
62+
// CHECK: scf.yield %[[VAL_2]] : f32
63+
// CHECK: }
64+
// CHECK: return %[[VAL_0]] : f32
65+
// CHECK: }
66+
func.func @sincos_no_fusion_different_block(%arg0 : f32, %flag : i1) -> f32 {
67+
%0 = scf.if %flag -> f32 {
68+
%s = math.sin %arg0 : f32
69+
scf.yield %s : f32
70+
} else {
71+
%c = math.cos %arg0 : f32
72+
scf.yield %c : f32
73+
}
74+
func.return %0 : f32
75+
}
76+
77+
// CHECK-LABEL: func.func @sincos_fusion_preserve_fastmath(
78+
// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
79+
// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] fastmath<contract> : f32
80+
// CHECK: return %[[VAL_0]], %[[VAL_1]] : f32, f32
81+
// CHECK: }
82+
func.func @sincos_fusion_preserve_fastmath(%arg0 : f32) -> (f32, f32) {
83+
%0 = math.sin %arg0 fastmath<contract> : f32
84+
%1 = math.cos %arg0 fastmath<contract> : f32
85+
func.return %0, %1 : f32, f32
86+
}

0 commit comments

Comments
 (0)