Skip to content

Commit 3c373fa

Browse files
matthias-springervinay-deshmukh
authored andcommitted
[mlir][arith] Fix arith.select lowering after llvm#166513 (llvm#166692)
llvm#166513 broke the lowering of `arith.select` with unsupported FP4 types. For this op, it is fine to convert to `i4`.
1 parent 4688a36 commit 3c373fa

File tree

3 files changed

+57
-30
lines changed

3 files changed

+57
-30
lines changed

mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ class AttrConvertPassThrough {
8686
/// ArrayRef<NamedAttribute>.
8787
template <typename SourceOp, typename TargetOp,
8888
template <typename, typename> typename AttrConvert =
89-
AttrConvertPassThrough>
89+
AttrConvertPassThrough,
90+
bool FailOnUnsupportedFP = false>
9091
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
9192
public:
9293
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
@@ -123,11 +124,13 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
123124
"unsupported floating point type");
124125
return success();
125126
};
126-
for (Value operand : op->getOperands())
127-
if (failed(checkType(operand)))
127+
if (FailOnUnsupportedFP) {
128+
for (Value operand : op->getOperands())
129+
if (failed(checkType(operand)))
130+
return failure();
131+
if (failed(checkType(op->getResult(0))))
128132
return failure();
129-
if (failed(checkType(op->getResult(0))))
130-
return failure();
133+
}
131134

132135
// Determine attributes for the target op
133136
AttrConvert<SourceOp, TargetOp> attrConvert(op);

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,23 @@ namespace {
3636
/// attribute.
3737
template <typename SourceOp, typename TargetOp, bool Constrained,
3838
template <typename, typename> typename AttrConvert =
39-
AttrConvertPassThrough>
39+
AttrConvertPassThrough,
40+
bool FailOnUnsupportedFP = false>
4041
struct ConstrainedVectorConvertToLLVMPattern
41-
: public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
42-
using VectorConvertToLLVMPattern<SourceOp, TargetOp,
43-
AttrConvert>::VectorConvertToLLVMPattern;
42+
: public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
43+
FailOnUnsupportedFP> {
44+
using VectorConvertToLLVMPattern<
45+
SourceOp, TargetOp, AttrConvert,
46+
FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
4447

4548
LogicalResult
4649
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
4750
ConversionPatternRewriter &rewriter) const override {
4851
if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
4952
return failure();
50-
return VectorConvertToLLVMPattern<SourceOp, TargetOp,
51-
AttrConvert>::matchAndRewrite(op, adaptor,
52-
rewriter);
53+
return VectorConvertToLLVMPattern<
54+
SourceOp, TargetOp, AttrConvert,
55+
FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
5356
}
5457
};
5558

@@ -78,7 +81,8 @@ struct IdentityBitcastLowering final
7881

7982
using AddFOpLowering =
8083
VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
81-
arith::AttrConvertFastMathToLLVM>;
84+
arith::AttrConvertFastMathToLLVM,
85+
/*FailOnUnsupportedFP=*/true>;
8286
using AddIOpLowering =
8387
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
8488
arith::AttrConvertOverflowToLLVM>;
@@ -87,53 +91,67 @@ using BitcastOpLowering =
8791
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
8892
using DivFOpLowering =
8993
VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
90-
arith::AttrConvertFastMathToLLVM>;
94+
arith::AttrConvertFastMathToLLVM,
95+
/*FailOnUnsupportedFP=*/true>;
9196
using DivSIOpLowering =
9297
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
9398
using DivUIOpLowering =
9499
VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
95-
using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
100+
using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp,
101+
AttrConvertPassThrough,
102+
/*FailOnUnsupportedFP=*/true>;
96103
using ExtSIOpLowering =
97104
VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
98105
using ExtUIOpLowering =
99106
VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
100107
using FPToSIOpLowering =
101-
VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
108+
VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp,
109+
AttrConvertPassThrough,
110+
/*FailOnUnsupportedFP=*/true>;
102111
using FPToUIOpLowering =
103-
VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
112+
VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp,
113+
AttrConvertPassThrough,
114+
/*FailOnUnsupportedFP=*/true>;
104115
using MaximumFOpLowering =
105116
VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
106-
arith::AttrConvertFastMathToLLVM>;
117+
arith::AttrConvertFastMathToLLVM,
118+
/*FailOnUnsupportedFP=*/true>;
107119
using MaxNumFOpLowering =
108120
VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
109-
arith::AttrConvertFastMathToLLVM>;
121+
arith::AttrConvertFastMathToLLVM,
122+
/*FailOnUnsupportedFP=*/true>;
110123
using MaxSIOpLowering =
111124
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
112125
using MaxUIOpLowering =
113126
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
114127
using MinimumFOpLowering =
115128
VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
116-
arith::AttrConvertFastMathToLLVM>;
129+
arith::AttrConvertFastMathToLLVM,
130+
/*FailOnUnsupportedFP=*/true>;
117131
using MinNumFOpLowering =
118132
VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
119-
arith::AttrConvertFastMathToLLVM>;
133+
arith::AttrConvertFastMathToLLVM,
134+
/*FailOnUnsupportedFP=*/true>;
120135
using MinSIOpLowering =
121136
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
122137
using MinUIOpLowering =
123138
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
124139
using MulFOpLowering =
125140
VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
126-
arith::AttrConvertFastMathToLLVM>;
141+
arith::AttrConvertFastMathToLLVM,
142+
/*FailOnUnsupportedFP=*/true>;
127143
using MulIOpLowering =
128144
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
129145
arith::AttrConvertOverflowToLLVM>;
130146
using NegFOpLowering =
131147
VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
132-
arith::AttrConvertFastMathToLLVM>;
148+
arith::AttrConvertFastMathToLLVM,
149+
/*FailOnUnsupportedFP=*/true>;
133150
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
134151
using RemFOpLowering =
135152
VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
136-
arith::AttrConvertFastMathToLLVM>;
153+
arith::AttrConvertFastMathToLLVM,
154+
/*FailOnUnsupportedFP=*/true>;
137155
using RemSIOpLowering =
138156
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
139157
using RemUIOpLowering =
@@ -151,21 +169,25 @@ using SIToFPOpLowering =
151169
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
152170
using SubFOpLowering =
153171
VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
154-
arith::AttrConvertFastMathToLLVM>;
172+
arith::AttrConvertFastMathToLLVM,
173+
/*FailOnUnsupportedFP=*/true>;
155174
using SubIOpLowering =
156175
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
157176
arith::AttrConvertOverflowToLLVM>;
158177
using TruncFOpLowering =
159178
ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
160-
false>;
179+
false, AttrConvertPassThrough,
180+
/*FailOnUnsupportedFP=*/true>;
161181
using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
162182
arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
163-
arith::AttrConverterConstrainedFPToLLVM>;
183+
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
164184
using TruncIOpLowering =
165185
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
166186
arith::AttrConvertOverflowToLLVM>;
167187
using UIToFPOpLowering =
168-
VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
188+
VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp,
189+
AttrConvertPassThrough,
190+
/*FailOnUnsupportedFP=*/true>;
169191
using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
170192

171193
//===----------------------------------------------------------------------===//

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
754754
// CHECK: arith.addf {{.*}} : f4E2M1FN
755755
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
756756
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
757-
func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
757+
// CHECK: llvm.select {{.*}} : i1, i4
758+
func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
758759
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
759760
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
760761
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
761-
return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
762+
%3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
763+
return
762764
}
763765

764766
// -----

0 commit comments

Comments
 (0)