@@ -50,10 +50,10 @@ template <typename Op>
5050struct ScalarOpToLibmCall : public OpRewritePattern <Op> {
5151public:
5252 using OpRewritePattern<Op>::OpRewritePattern;
53- ScalarOpToLibmCall (MLIRContext *context, StringRef floatFunc ,
54- StringRef doubleFunc)
55- : OpRewritePattern<Op>(context), floatFunc(floatFunc),
56- doubleFunc (doubleFunc){};
53+ ScalarOpToLibmCall (MLIRContext *context, PatternBenefit benefit ,
54+ StringRef floatFunc, StringRef doubleFunc)
55+ : OpRewritePattern<Op>(context, benefit ), floatFunc(floatFunc),
56+ doubleFunc (doubleFunc) {};
5757
5858 LogicalResult matchAndRewrite (Op op, PatternRewriter &rewriter) const final ;
5959
@@ -62,10 +62,11 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
6262};
6363
6464template <typename OpTy>
65- void populatePatternsForOp (RewritePatternSet &patterns, MLIRContext *ctx,
66- StringRef floatFunc, StringRef doubleFunc) {
67- patterns.add <VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
68- patterns.add <ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
65+ void populatePatternsForOp (RewritePatternSet &patterns, PatternBenefit benefit,
66+ MLIRContext *ctx, StringRef floatFunc,
67+ StringRef doubleFunc) {
68+ patterns.add <VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
69+ patterns.add <ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
6970}
7071
7172} // namespace
@@ -159,42 +160,54 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
159160 return success ();
160161}
161162
162- void mlir::populateMathToLibmConversionPatterns (RewritePatternSet &patterns) {
163+ void mlir::populateMathToLibmConversionPatterns (RewritePatternSet &patterns,
164+ PatternBenefit benefit) {
163165 MLIRContext *ctx = patterns.getContext ();
164166
165- populatePatternsForOp<math::AbsFOp>(patterns, ctx, " fabsf" , " fabs" );
166- populatePatternsForOp<math::AcosOp>(patterns, ctx, " acosf" , " acos" );
167- populatePatternsForOp<math::AcoshOp>(patterns, ctx, " acoshf" , " acosh" );
168- populatePatternsForOp<math::AsinOp>(patterns, ctx, " asinf" , " asin" );
169- populatePatternsForOp<math::AsinhOp>(patterns, ctx, " asinhf" , " asinh" );
170- populatePatternsForOp<math::Atan2Op>(patterns, ctx, " atan2f" , " atan2" );
171- populatePatternsForOp<math::AtanOp>(patterns, ctx, " atanf" , " atan" );
172- populatePatternsForOp<math::AtanhOp>(patterns, ctx, " atanhf" , " atanh" );
173- populatePatternsForOp<math::CbrtOp>(patterns, ctx, " cbrtf" , " cbrt" );
174- populatePatternsForOp<math::CeilOp>(patterns, ctx, " ceilf" , " ceil" );
175- populatePatternsForOp<math::CosOp>(patterns, ctx, " cosf" , " cos" );
176- populatePatternsForOp<math::CoshOp>(patterns, ctx, " coshf" , " cosh" );
177- populatePatternsForOp<math::ErfOp>(patterns, ctx, " erff" , " erf" );
178- populatePatternsForOp<math::ExpOp>(patterns, ctx, " expf" , " exp" );
179- populatePatternsForOp<math::Exp2Op>(patterns, ctx, " exp2f" , " exp2" );
180- populatePatternsForOp<math::ExpM1Op>(patterns, ctx, " expm1f" , " expm1" );
181- populatePatternsForOp<math::FloorOp>(patterns, ctx, " floorf" , " floor" );
182- populatePatternsForOp<math::FmaOp>(patterns, ctx, " fmaf" , " fma" );
183- populatePatternsForOp<math::LogOp>(patterns, ctx, " logf" , " log" );
184- populatePatternsForOp<math::Log2Op>(patterns, ctx, " log2f" , " log2" );
185- populatePatternsForOp<math::Log10Op>(patterns, ctx, " log10f" , " log10" );
186- populatePatternsForOp<math::Log1pOp>(patterns, ctx, " log1pf" , " log1p" );
187- populatePatternsForOp<math::PowFOp>(patterns, ctx, " powf" , " pow" );
188- populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, " roundevenf" ,
167+ populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, " fabsf" , " fabs" );
168+ populatePatternsForOp<math::AcosOp>(patterns, benefit, ctx, " acosf" , " acos" );
169+ populatePatternsForOp<math::AcoshOp>(patterns, benefit, ctx, " acoshf" ,
170+ " acosh" );
171+ populatePatternsForOp<math::AsinOp>(patterns, benefit, ctx, " asinf" , " asin" );
172+ populatePatternsForOp<math::AsinhOp>(patterns, benefit, ctx, " asinhf" ,
173+ " asinh" );
174+ populatePatternsForOp<math::Atan2Op>(patterns, benefit, ctx, " atan2f" ,
175+ " atan2" );
176+ populatePatternsForOp<math::AtanOp>(patterns, benefit, ctx, " atanf" , " atan" );
177+ populatePatternsForOp<math::AtanhOp>(patterns, benefit, ctx, " atanhf" ,
178+ " atanh" );
179+ populatePatternsForOp<math::CbrtOp>(patterns, benefit, ctx, " cbrtf" , " cbrt" );
180+ populatePatternsForOp<math::CeilOp>(patterns, benefit, ctx, " ceilf" , " ceil" );
181+ populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, " cosf" , " cos" );
182+ populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, " coshf" , " cosh" );
183+ populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, " erff" , " erf" );
184+ populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, " expf" , " exp" );
185+ populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, " exp2f" , " exp2" );
186+ populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, " expm1f" ,
187+ " expm1" );
188+ populatePatternsForOp<math::FloorOp>(patterns, benefit, ctx, " floorf" ,
189+ " floor" );
190+ populatePatternsForOp<math::FmaOp>(patterns, benefit, ctx, " fmaf" , " fma" );
191+ populatePatternsForOp<math::LogOp>(patterns, benefit, ctx, " logf" , " log" );
192+ populatePatternsForOp<math::Log2Op>(patterns, benefit, ctx, " log2f" , " log2" );
193+ populatePatternsForOp<math::Log10Op>(patterns, benefit, ctx, " log10f" ,
194+ " log10" );
195+ populatePatternsForOp<math::Log1pOp>(patterns, benefit, ctx, " log1pf" ,
196+ " log1p" );
197+ populatePatternsForOp<math::PowFOp>(patterns, benefit, ctx, " powf" , " pow" );
198+ populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, " roundevenf" ,
189199 " roundeven" );
190- populatePatternsForOp<math::RoundOp>(patterns, ctx, " roundf" , " round" );
191- populatePatternsForOp<math::SinOp>(patterns, ctx, " sinf" , " sin" );
192- populatePatternsForOp<math::SinhOp>(patterns, ctx, " sinhf" , " sinh" );
193- populatePatternsForOp<math::SqrtOp>(patterns, ctx, " sqrtf" , " sqrt" );
194- populatePatternsForOp<math::RsqrtOp>(patterns, ctx, " rsqrtf" , " rsqrt" );
195- populatePatternsForOp<math::TanOp>(patterns, ctx, " tanf" , " tan" );
196- populatePatternsForOp<math::TanhOp>(patterns, ctx, " tanhf" , " tanh" );
197- populatePatternsForOp<math::TruncOp>(patterns, ctx, " truncf" , " trunc" );
200+ populatePatternsForOp<math::RoundOp>(patterns, benefit, ctx, " roundf" ,
201+ " round" );
202+ populatePatternsForOp<math::SinOp>(patterns, benefit, ctx, " sinf" , " sin" );
203+ populatePatternsForOp<math::SinhOp>(patterns, benefit, ctx, " sinhf" , " sinh" );
204+ populatePatternsForOp<math::SqrtOp>(patterns, benefit, ctx, " sqrtf" , " sqrt" );
205+ populatePatternsForOp<math::RsqrtOp>(patterns, benefit, ctx, " rsqrtf" ,
206+ " rsqrt" );
207+ populatePatternsForOp<math::TanOp>(patterns, benefit, ctx, " tanf" , " tan" );
208+ populatePatternsForOp<math::TanhOp>(patterns, benefit, ctx, " tanhf" , " tanh" );
209+ populatePatternsForOp<math::TruncOp>(patterns, benefit, ctx, " truncf" ,
210+ " trunc" );
198211}
199212
200213namespace {
0 commit comments