@@ -101,13 +101,13 @@ std::unique_ptr<pb::Any> SubstraitExporter::exportAny(StringAttr attr) {
101101 return any;
102102}
103103
104- std::unique_ptr<proto::Type> exportIntegerType (mlir::Type mlirType,
104+ // / Function that export `IntegerType`'s to the corresponding Substrait types.
105+ std::unique_ptr<proto::Type> exportIntegerType (IntegerType intType,
105106 MLIRContext *context) {
106- // Function that handles `IntegerType`'s.
107+ assert (intType. isSigned () && " only signed integer types supported " );
107108
108- // Handle SI1.
109- auto si1 = IntegerType::get (context, 1 , IntegerType::Signed);
110- if (mlirType == si1) {
109+ switch (intType.getWidth ()) {
110+ case 1 : { // Handle SI1.
111111 // TODO(ingomueller): support other nullability modes.
112112 auto i1Type = std::make_unique<proto::Type::Boolean>();
113113 i1Type->set_nullability (
@@ -118,9 +118,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
118118 return type;
119119 }
120120
121- // Handle SI8.
122- auto si8 = IntegerType::get (context, 8 , IntegerType::Signed);
123- if (mlirType == si8) {
121+ case 8 : { // Handle SI8.
124122 // TODO(ingomueller): support other nullability modes.
125123 auto i8Type = std::make_unique<proto::Type::I8>();
126124 i8Type->set_nullability (
@@ -131,9 +129,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
131129 return type;
132130 }
133131
134- // Handle SI6.
135- auto si16 = IntegerType::get (context, 16 , IntegerType::Signed);
136- if (mlirType == si16) {
132+ case 16 : { // Handle SI16.
137133 // TODO(ingomueller): support other nullability modes.
138134 auto i16Type = std::make_unique<proto::Type::I16>();
139135 i16Type->set_nullability (
@@ -144,9 +140,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
144140 return type;
145141 }
146142
147- // Handle SI32.
148- auto si32 = IntegerType::get (context, 32 , IntegerType::Signed);
149- if (mlirType == si32) {
143+ case 32 : { // Handle SI32.
150144 // TODO(ingomueller): support other nullability modes.
151145 auto i32Type = std::make_unique<proto::Type::I32>();
152146 i32Type->set_nullability (
@@ -157,9 +151,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
157151 return type;
158152 }
159153
160- // Handle SI64.
161- auto si64 = IntegerType::get (context, 64 , IntegerType::Signed);
162- if (mlirType == si64) {
154+ case 64 : { // Handle SI64.
163155 // TODO(ingomueller): support other nullability modes.
164156 auto i64Type = std::make_unique<proto::Type::I64>();
165157 i64Type->set_nullability (
@@ -170,16 +162,17 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
170162 return type;
171163 }
172164
173- llvm_unreachable (" We should have handled all integer types." );
165+ default :
166+ llvm_unreachable (" We should have handled all integer types." );
167+ }
174168}
175169
176- std::unique_ptr<proto::Type> exportFloatType (mlir::Type mlirType,
170+ // / Function that export `FloatType`'s to the corresponding Substrait types.
171+ std::unique_ptr<proto::Type> exportFloatType (FloatType floatType,
177172 MLIRContext *context) {
178- // Function that handles `FloatType`'s.
179173
180- // Handle FP32.
181- auto fp32 = FloatType::getF32 (context);
182- if (mlirType == fp32) {
174+ switch (floatType.getWidth ()) {
175+ case 32 : { // Handle FP32.
183176 // TODO(ingomueller): support other nullability modes.
184177 auto fp32Type = std::make_unique<proto::Type::FP32>();
185178 fp32Type->set_nullability (
@@ -190,9 +183,7 @@ std::unique_ptr<proto::Type> exportFloatType(mlir::Type mlirType,
190183 return type;
191184 }
192185
193- // Handle FP64.
194- auto fp64 = FloatType::getF64 (context);
195- if (mlirType == fp64) {
186+ case 64 : { // Handle FP64.
196187 // TODO(ingomueller): support other nullability modes.
197188 auto fp64Type = std::make_unique<proto::Type::FP64>();
198189 fp64Type->set_nullability (
@@ -203,21 +194,23 @@ std::unique_ptr<proto::Type> exportFloatType(mlir::Type mlirType,
203194 return type;
204195 }
205196
206- llvm_unreachable (" We should have handled all float types." );
197+ default :
198+ llvm_unreachable (" We should have handled all float types." );
199+ }
207200}
208201
209202FailureOr<std::unique_ptr<proto::Type>>
210203SubstraitExporter::exportType (Location loc, mlir::Type mlirType) {
211204 MLIRContext *context = mlirType.getContext ();
212205
213206 // Handle `IntegerType`'s.
214- if (mlir::isa <IntegerType>(mlirType)) {
215- return exportIntegerType (mlirType , context);
207+ if (auto intType = mlir::dyn_cast <IntegerType>(mlirType)) {
208+ return exportIntegerType (intType , context);
216209 }
217210
218211 // Handle `FloatType`'s.
219- if (mlir::isa <FloatType>(mlirType)) {
220- return exportFloatType (mlirType , context);
212+ if (auto floatType = mlir::dyn_cast <FloatType>(mlirType)) {
213+ return exportFloatType (floatType , context);
221214 }
222215
223216 // Handle String.
0 commit comments