Skip to content

Commit 9c4af87

Browse files
chore: fix FloatType building and refactor related export logic
The way `FloatType`s are built was changed upstream: instead of `FloatType::getF32`, we now use `Float32Type::get`. This commit adapts the usages of the previous pattern. While touching that code, the commit also refactors both `exportFloatType` and `exportIntegerType` to use a `switch` statement on the type width instead of constructing all possible MLIR types and comparing linearly against them, which is less efficient. Signed-off-by: Ingo Müller <[email protected]>
1 parent d2d686c commit 9c4af87

File tree

2 files changed

+28
-35
lines changed

2 files changed

+28
-35
lines changed

lib/Target/SubstraitPB/Export.cpp

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,13 @@ std::unique_ptr<pb::Any> SubstraitExporter::exportAny(StringAttr attr) {
150150
return any;
151151
}
152152

153-
std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
153+
/// Function that export `IntegerType`'s to the corresponding Substrait types.
154+
std::unique_ptr<proto::Type> exportIntegerType(IntegerType intType,
154155
MLIRContext *context) {
155-
// Function that handles `IntegerType`'s.
156+
assert(intType.isSigned() && "only signed integer types supported");
156157

157-
// Handle SI1.
158-
auto si1 = IntegerType::get(context, 1, IntegerType::Signed);
159-
if (mlirType == si1) {
158+
switch (intType.getWidth()) {
159+
case 1: { // Handle SI1.
160160
// TODO(ingomueller): support other nullability modes.
161161
auto i1Type = std::make_unique<proto::Type::Boolean>();
162162
i1Type->set_nullability(
@@ -167,9 +167,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
167167
return type;
168168
}
169169

170-
// Handle SI8.
171-
auto si8 = IntegerType::get(context, 8, IntegerType::Signed);
172-
if (mlirType == si8) {
170+
case 8: { // Handle SI8.
173171
// TODO(ingomueller): support other nullability modes.
174172
auto i8Type = std::make_unique<proto::Type::I8>();
175173
i8Type->set_nullability(
@@ -180,9 +178,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
180178
return type;
181179
}
182180

183-
// Handle SI6.
184-
auto si16 = IntegerType::get(context, 16, IntegerType::Signed);
185-
if (mlirType == si16) {
181+
case 16: { // Handle SI16.
186182
// TODO(ingomueller): support other nullability modes.
187183
auto i16Type = std::make_unique<proto::Type::I16>();
188184
i16Type->set_nullability(
@@ -193,9 +189,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
193189
return type;
194190
}
195191

196-
// Handle SI32.
197-
auto si32 = IntegerType::get(context, 32, IntegerType::Signed);
198-
if (mlirType == si32) {
192+
case 32: { // Handle SI32.
199193
// TODO(ingomueller): support other nullability modes.
200194
auto i32Type = std::make_unique<proto::Type::I32>();
201195
i32Type->set_nullability(
@@ -206,9 +200,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
206200
return type;
207201
}
208202

209-
// Handle SI64.
210-
auto si64 = IntegerType::get(context, 64, IntegerType::Signed);
211-
if (mlirType == si64) {
203+
case 64: { // Handle SI64.
212204
// TODO(ingomueller): support other nullability modes.
213205
auto i64Type = std::make_unique<proto::Type::I64>();
214206
i64Type->set_nullability(
@@ -219,16 +211,17 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
219211
return type;
220212
}
221213

222-
llvm_unreachable("We should have handled all integer types.");
214+
default:
215+
llvm_unreachable("We should have handled all integer types.");
216+
}
223217
}
224218

225-
std::unique_ptr<proto::Type> exportFloatType(mlir::Type mlirType,
219+
/// Function that export `FloatType`'s to the corresponding Substrait types.
220+
std::unique_ptr<proto::Type> exportFloatType(FloatType floatType,
226221
MLIRContext *context) {
227-
// Function that handles `FloatType`'s.
228222

229-
// Handle FP32.
230-
auto fp32 = FloatType::getF32(context);
231-
if (mlirType == fp32) {
223+
switch (floatType.getWidth()) {
224+
case 32: { // Handle FP32.
232225
// TODO(ingomueller): support other nullability modes.
233226
auto fp32Type = std::make_unique<proto::Type::FP32>();
234227
fp32Type->set_nullability(
@@ -239,9 +232,7 @@ std::unique_ptr<proto::Type> exportFloatType(mlir::Type mlirType,
239232
return type;
240233
}
241234

242-
// Handle FP64.
243-
auto fp64 = FloatType::getF64(context);
244-
if (mlirType == fp64) {
235+
case 64: { // Handle FP64.
245236
// TODO(ingomueller): support other nullability modes.
246237
auto fp64Type = std::make_unique<proto::Type::FP64>();
247238
fp64Type->set_nullability(
@@ -252,21 +243,23 @@ std::unique_ptr<proto::Type> exportFloatType(mlir::Type mlirType,
252243
return type;
253244
}
254245

255-
llvm_unreachable("We should have handled all float types.");
246+
default:
247+
llvm_unreachable("We should have handled all float types.");
248+
}
256249
}
257250

258251
FailureOr<std::unique_ptr<proto::Type>>
259252
SubstraitExporter::exportType(Location loc, mlir::Type mlirType) {
260253
MLIRContext *context = mlirType.getContext();
261254

262255
// Handle `IntegerType`'s.
263-
if (mlir::isa<IntegerType>(mlirType)) {
264-
return exportIntegerType(mlirType, context);
256+
if (auto intType = mlir::dyn_cast<IntegerType>(mlirType)) {
257+
return exportIntegerType(intType, context);
265258
}
266259

267260
// Handle `FloatType`'s.
268-
if (mlir::isa<FloatType>(mlirType)) {
269-
return exportFloatType(mlirType, context);
261+
if (auto floatType = mlir::dyn_cast<FloatType>(mlirType)) {
262+
return exportFloatType(floatType, context);
270263
}
271264

272265
// Handle String.

lib/Target/SubstraitPB/Import.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,9 @@ static mlir::FailureOr<mlir::Type> importType(MLIRContext *context,
186186
case proto::Type::kI64:
187187
return IntegerType::get(context, 64, IntegerType::Signed);
188188
case proto::Type::kFp32:
189-
return FloatType::getF32(context);
189+
return Float32Type::get(context);
190190
case proto::Type::kFp64:
191-
return FloatType::getF64(context);
191+
return Float64Type::get(context);
192192
case proto::Type::kString:
193193
return StringType::get(context);
194194
case proto::Type::kBinary:
@@ -560,11 +560,11 @@ importLiteral(ImplicitLocOpBuilder builder,
560560
return builder.create<LiteralOp>(attr);
561561
}
562562
case Expression::Literal::LiteralTypeCase::kFp32: {
563-
auto attr = FloatAttr::get(FloatType::getF32(context), message.fp32());
563+
auto attr = FloatAttr::get(Float32Type::get(context), message.fp32());
564564
return builder.create<LiteralOp>(attr);
565565
}
566566
case Expression::Literal::LiteralTypeCase::kFp64: {
567-
auto attr = FloatAttr::get(FloatType::getF64(context), message.fp64());
567+
auto attr = FloatAttr::get(Float64Type::get(context), message.fp64());
568568
return builder.create<LiteralOp>(attr);
569569
}
570570
case Expression::Literal::LiteralTypeCase::kString: {

0 commit comments

Comments
 (0)