@@ -150,13 +150,13 @@ std::unique_ptr<pb::Any> SubstraitExporter::exportAny(StringAttr attr) {
150150 return any;
151151}
152152
153- std::unique_ptr<proto::Type> exportIntegerType (mlir::Type mlirType,
153+ // / Function that export `IntegerType`'s to the corresponding Substrait types.
154+ std::unique_ptr<proto::Type> exportIntegerType (IntegerType intType,
154155 MLIRContext *context) {
155- // Function that handles `IntegerType`'s.
156+ assert (intType. isSigned () && " only signed integer types supported " );
156157
157- // Handle SI1.
158- auto si1 = IntegerType::get (context, 1 , IntegerType::Signed);
159- if (mlirType == si1) {
158+ switch (intType.getWidth ()) {
159+ case 1 : { // Handle SI1.
160160 // TODO(ingomueller): support other nullability modes.
161161 auto i1Type = std::make_unique<proto::Type::Boolean>();
162162 i1Type->set_nullability (
@@ -167,9 +167,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
167167 return type;
168168 }
169169
170- // Handle SI8.
171- auto si8 = IntegerType::get (context, 8 , IntegerType::Signed);
172- if (mlirType == si8) {
170+ case 8 : { // Handle SI8.
173171 // TODO(ingomueller): support other nullability modes.
174172 auto i8Type = std::make_unique<proto::Type::I8>();
175173 i8Type->set_nullability (
@@ -180,9 +178,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
180178 return type;
181179 }
182180
183- // Handle SI6.
184- auto si16 = IntegerType::get (context, 16 , IntegerType::Signed);
185- if (mlirType == si16) {
181+ case 16 : { // Handle SI16.
186182 // TODO(ingomueller): support other nullability modes.
187183 auto i16Type = std::make_unique<proto::Type::I16>();
188184 i16Type->set_nullability (
@@ -193,9 +189,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
193189 return type;
194190 }
195191
196- // Handle SI32.
197- auto si32 = IntegerType::get (context, 32 , IntegerType::Signed);
198- if (mlirType == si32) {
192+ case 32 : { // Handle SI32.
199193 // TODO(ingomueller): support other nullability modes.
200194 auto i32Type = std::make_unique<proto::Type::I32>();
201195 i32Type->set_nullability (
@@ -206,9 +200,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
206200 return type;
207201 }
208202
209- // Handle SI64.
210- auto si64 = IntegerType::get (context, 64 , IntegerType::Signed);
211- if (mlirType == si64) {
203+ case 64 : { // Handle SI64.
212204 // TODO(ingomueller): support other nullability modes.
213205 auto i64Type = std::make_unique<proto::Type::I64>();
214206 i64Type->set_nullability (
@@ -219,16 +211,17 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
219211 return type;
220212 }
221213
222- llvm_unreachable (" We should have handled all integer types." );
214+ default :
215+ llvm_unreachable (" We should have handled all integer types." );
216+ }
223217}
224218
225- std::unique_ptr<proto::Type> exportFloatType (mlir::Type mlirType,
219+ // / Function that export `FloatType`'s to the corresponding Substrait types.
220+ std::unique_ptr<proto::Type> exportFloatType (FloatType floatType,
226221 MLIRContext *context) {
227- // Function that handles `FloatType`'s.
228222
229- // Handle FP32.
230- auto fp32 = FloatType::getF32 (context);
231- if (mlirType == fp32) {
223+ switch (floatType.getWidth ()) {
224+ case 32 : { // Handle FP32.
232225 // TODO(ingomueller): support other nullability modes.
233226 auto fp32Type = std::make_unique<proto::Type::FP32>();
234227 fp32Type->set_nullability (
@@ -239,9 +232,7 @@ std::unique_ptr<proto::Type> exportFloatType(mlir::Type mlirType,
239232 return type;
240233 }
241234
242- // Handle FP64.
243- auto fp64 = FloatType::getF64 (context);
244- if (mlirType == fp64) {
235+ case 64 : { // Handle FP64.
245236 // TODO(ingomueller): support other nullability modes.
246237 auto fp64Type = std::make_unique<proto::Type::FP64>();
247238 fp64Type->set_nullability (
@@ -252,21 +243,23 @@ std::unique_ptr<proto::Type> exportFloatType(mlir::Type mlirType,
252243 return type;
253244 }
254245
255- llvm_unreachable (" We should have handled all float types." );
246+ default :
247+ llvm_unreachable (" We should have handled all float types." );
248+ }
256249}
257250
258251FailureOr<std::unique_ptr<proto::Type>>
259252SubstraitExporter::exportType (Location loc, mlir::Type mlirType) {
260253 MLIRContext *context = mlirType.getContext ();
261254
262255 // Handle `IntegerType`'s.
263- if (mlir::isa <IntegerType>(mlirType)) {
264- return exportIntegerType (mlirType , context);
256+ if (auto intType = mlir::dyn_cast <IntegerType>(mlirType)) {
257+ return exportIntegerType (intType , context);
265258 }
266259
267260 // Handle `FloatType`'s.
268- if (mlir::isa <FloatType>(mlirType)) {
269- return exportFloatType (mlirType , context);
261+ if (auto floatType = mlir::dyn_cast <FloatType>(mlirType)) {
262+ return exportFloatType (floatType , context);
270263 }
271264
272265 // Handle String.
0 commit comments