Skip to content

Commit 4a51193

Browse files
chore: fix FloatType building and refactor related export logic
The way `FloatType`s are built was changed upstream: instead of `FloatType::getF32`, we now use `Float32Type::get`. This commit adapts the usages of the previous pattern. While touching that code, the commit also refactors both `exportFloatType` and `exportIntegerType` to use a `switch` statement on the type width instead of constructing all possible MLIR types and comparing linearly against them, which is less efficient. Signed-off-by: Ingo Müller <[email protected]>
1 parent 02f8ff3 commit 4a51193

File tree

2 files changed

+28
-35
lines changed

2 files changed

+28
-35
lines changed

lib/Target/SubstraitPB/Export.cpp

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,13 @@ std::unique_ptr<pb::Any> SubstraitExporter::exportAny(StringAttr attr) {
101101
return any;
102102
}
103103

104-
std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
104+
/// Function that export `IntegerType`'s to the corresponding Substrait types.
105+
std::unique_ptr<proto::Type> exportIntegerType(IntegerType intType,
105106
MLIRContext *context) {
106-
// Function that handles `IntegerType`'s.
107+
assert(intType.isSigned() && "only signed integer types supported");
107108

108-
// Handle SI1.
109-
auto si1 = IntegerType::get(context, 1, IntegerType::Signed);
110-
if (mlirType == si1) {
109+
switch (intType.getWidth()) {
110+
case 1: { // Handle SI1.
111111
// TODO(ingomueller): support other nullability modes.
112112
auto i1Type = std::make_unique<proto::Type::Boolean>();
113113
i1Type->set_nullability(
@@ -118,9 +118,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
118118
return type;
119119
}
120120

121-
// Handle SI8.
122-
auto si8 = IntegerType::get(context, 8, IntegerType::Signed);
123-
if (mlirType == si8) {
121+
case 8: { // Handle SI8.
124122
// TODO(ingomueller): support other nullability modes.
125123
auto i8Type = std::make_unique<proto::Type::I8>();
126124
i8Type->set_nullability(
@@ -131,9 +129,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
131129
return type;
132130
}
133131

134-
// Handle SI6.
135-
auto si16 = IntegerType::get(context, 16, IntegerType::Signed);
136-
if (mlirType == si16) {
132+
case 16: { // Handle SI16.
137133
// TODO(ingomueller): support other nullability modes.
138134
auto i16Type = std::make_unique<proto::Type::I16>();
139135
i16Type->set_nullability(
@@ -144,9 +140,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
144140
return type;
145141
}
146142

147-
// Handle SI32.
148-
auto si32 = IntegerType::get(context, 32, IntegerType::Signed);
149-
if (mlirType == si32) {
143+
case 32: { // Handle SI32.
150144
// TODO(ingomueller): support other nullability modes.
151145
auto i32Type = std::make_unique<proto::Type::I32>();
152146
i32Type->set_nullability(
@@ -157,9 +151,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
157151
return type;
158152
}
159153

160-
// Handle SI64.
161-
auto si64 = IntegerType::get(context, 64, IntegerType::Signed);
162-
if (mlirType == si64) {
154+
case 64: { // Handle SI64.
163155
// TODO(ingomueller): support other nullability modes.
164156
auto i64Type = std::make_unique<proto::Type::I64>();
165157
i64Type->set_nullability(
@@ -170,16 +162,17 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
170162
return type;
171163
}
172164

173-
llvm_unreachable("We should have handled all integer types.");
165+
default:
166+
llvm_unreachable("We should have handled all integer types.");
167+
}
174168
}
175169

176-
std::unique_ptr<proto::Type> exportFloatType(mlir::Type mlirType,
170+
/// Function that export `FloatType`'s to the corresponding Substrait types.
171+
std::unique_ptr<proto::Type> exportFloatType(FloatType floatType,
177172
MLIRContext *context) {
178-
// Function that handles `FloatType`'s.
179173

180-
// Handle FP32.
181-
auto fp32 = FloatType::getF32(context);
182-
if (mlirType == fp32) {
174+
switch (floatType.getWidth()) {
175+
case 32: { // Handle FP32.
183176
// TODO(ingomueller): support other nullability modes.
184177
auto fp32Type = std::make_unique<proto::Type::FP32>();
185178
fp32Type->set_nullability(
@@ -190,9 +183,7 @@ std::unique_ptr<proto::Type> exportFloatType(mlir::Type mlirType,
190183
return type;
191184
}
192185

193-
// Handle FP64.
194-
auto fp64 = FloatType::getF64(context);
195-
if (mlirType == fp64) {
186+
case 64: { // Handle FP64.
196187
// TODO(ingomueller): support other nullability modes.
197188
auto fp64Type = std::make_unique<proto::Type::FP64>();
198189
fp64Type->set_nullability(
@@ -203,21 +194,23 @@ std::unique_ptr<proto::Type> exportFloatType(mlir::Type mlirType,
203194
return type;
204195
}
205196

206-
llvm_unreachable("We should have handled all float types.");
197+
default:
198+
llvm_unreachable("We should have handled all float types.");
199+
}
207200
}
208201

209202
FailureOr<std::unique_ptr<proto::Type>>
210203
SubstraitExporter::exportType(Location loc, mlir::Type mlirType) {
211204
MLIRContext *context = mlirType.getContext();
212205

213206
// Handle `IntegerType`'s.
214-
if (mlir::isa<IntegerType>(mlirType)) {
215-
return exportIntegerType(mlirType, context);
207+
if (auto intType = mlir::dyn_cast<IntegerType>(mlirType)) {
208+
return exportIntegerType(intType, context);
216209
}
217210

218211
// Handle `FloatType`'s.
219-
if (mlir::isa<FloatType>(mlirType)) {
220-
return exportFloatType(mlirType, context);
212+
if (auto floatType = mlir::dyn_cast<FloatType>(mlirType)) {
213+
return exportFloatType(floatType, context);
221214
}
222215

223216
// Handle String.

lib/Target/SubstraitPB/Import.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ static mlir::FailureOr<mlir::Type> importType(MLIRContext *context,
114114
case proto::Type::kI64:
115115
return IntegerType::get(context, 64, IntegerType::Signed);
116116
case proto::Type::kFp32:
117-
return FloatType::getF32(context);
117+
return Float32Type::get(context);
118118
case proto::Type::kFp64:
119-
return FloatType::getF64(context);
119+
return Float64Type::get(context);
120120
case proto::Type::kString:
121121
return StringType::get(context);
122122
case proto::Type::kBinary:
@@ -344,11 +344,11 @@ importLiteral(ImplicitLocOpBuilder builder,
344344
return builder.create<LiteralOp>(attr);
345345
}
346346
case Expression::Literal::LiteralTypeCase::kFp32: {
347-
auto attr = FloatAttr::get(FloatType::getF32(context), message.fp32());
347+
auto attr = FloatAttr::get(Float32Type::get(context), message.fp32());
348348
return builder.create<LiteralOp>(attr);
349349
}
350350
case Expression::Literal::LiteralTypeCase::kFp64: {
351-
auto attr = FloatAttr::get(FloatType::getF64(context), message.fp64());
351+
auto attr = FloatAttr::get(Float64Type::get(context), message.fp64());
352352
return builder.create<LiteralOp>(attr);
353353
}
354354
case Expression::Literal::LiteralTypeCase::kString: {

0 commit comments

Comments
 (0)