diff --git a/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp b/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp index b7e3d5b6..1396df7b 100644 --- a/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp +++ b/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp @@ -37,8 +37,8 @@ struct SubstraitEmitDeduplicationPass void SubstraitEmitDeduplicationPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); populateEmitDeduplicationPatterns(patterns); - if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed( + mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { Location loc = getOperation()->getLoc(); emitError(loc) << "emit deduplication: pattern application failed"; signalPassFailure(); diff --git a/lib/Target/SubstraitPB/Export.cpp b/lib/Target/SubstraitPB/Export.cpp index 7dfbf308..c4c69673 100644 --- a/lib/Target/SubstraitPB/Export.cpp +++ b/lib/Target/SubstraitPB/Export.cpp @@ -150,13 +150,13 @@ std::unique_ptr SubstraitExporter::exportAny(StringAttr attr) { return any; } -std::unique_ptr exportIntegerType(mlir::Type mlirType, +/// Function that export `IntegerType`'s to the corresponding Substrait types. +std::unique_ptr exportIntegerType(IntegerType intType, MLIRContext *context) { - // Function that handles `IntegerType`'s. + assert(intType.isSigned() && "only signed integer types supported"); - // Handle SI1. - auto si1 = IntegerType::get(context, 1, IntegerType::Signed); - if (mlirType == si1) { + switch (intType.getWidth()) { + case 1: { // Handle SI1. // TODO(ingomueller): support other nullability modes. auto i1Type = std::make_unique(); i1Type->set_nullability( @@ -167,9 +167,7 @@ std::unique_ptr exportIntegerType(mlir::Type mlirType, return type; } - // Handle SI8. - auto si8 = IntegerType::get(context, 8, IntegerType::Signed); - if (mlirType == si8) { + case 8: { // Handle SI8. // TODO(ingomueller): support other nullability modes. auto i8Type = std::make_unique(); i8Type->set_nullability( @@ -180,9 +178,7 @@ std::unique_ptr exportIntegerType(mlir::Type mlirType, return type; } - // Handle SI6. - auto si16 = IntegerType::get(context, 16, IntegerType::Signed); - if (mlirType == si16) { + case 16: { // Handle SI16. // TODO(ingomueller): support other nullability modes. auto i16Type = std::make_unique(); i16Type->set_nullability( @@ -193,9 +189,7 @@ std::unique_ptr exportIntegerType(mlir::Type mlirType, return type; } - // Handle SI32. - auto si32 = IntegerType::get(context, 32, IntegerType::Signed); - if (mlirType == si32) { + case 32: { // Handle SI32. // TODO(ingomueller): support other nullability modes. auto i32Type = std::make_unique(); i32Type->set_nullability( @@ -206,9 +200,7 @@ std::unique_ptr exportIntegerType(mlir::Type mlirType, return type; } - // Handle SI64. - auto si64 = IntegerType::get(context, 64, IntegerType::Signed); - if (mlirType == si64) { + case 64: { // Handle SI64. // TODO(ingomueller): support other nullability modes. auto i64Type = std::make_unique(); i64Type->set_nullability( @@ -219,16 +211,17 @@ std::unique_ptr exportIntegerType(mlir::Type mlirType, return type; } - llvm_unreachable("We should have handled all integer types."); + default: + llvm_unreachable("We should have handled all integer types."); + } } -std::unique_ptr exportFloatType(mlir::Type mlirType, +/// Function that export `FloatType`'s to the corresponding Substrait types. +std::unique_ptr exportFloatType(FloatType floatType, MLIRContext *context) { - // Function that handles `FloatType`'s. - // Handle FP32. - auto fp32 = FloatType::getF32(context); - if (mlirType == fp32) { + switch (floatType.getWidth()) { + case 32: { // Handle FP32. // TODO(ingomueller): support other nullability modes. auto fp32Type = std::make_unique(); fp32Type->set_nullability( @@ -239,9 +232,7 @@ std::unique_ptr exportFloatType(mlir::Type mlirType, return type; } - // Handle FP64. - auto fp64 = FloatType::getF64(context); - if (mlirType == fp64) { + case 64: { // Handle FP64. // TODO(ingomueller): support other nullability modes. auto fp64Type = std::make_unique(); fp64Type->set_nullability( @@ -252,7 +243,9 @@ std::unique_ptr exportFloatType(mlir::Type mlirType, return type; } - llvm_unreachable("We should have handled all float types."); + default: + llvm_unreachable("We should have handled all float types."); + } } FailureOr> @@ -260,13 +253,13 @@ SubstraitExporter::exportType(Location loc, mlir::Type mlirType) { MLIRContext *context = mlirType.getContext(); // Handle `IntegerType`'s. - if (mlir::isa(mlirType)) { - return exportIntegerType(mlirType, context); + if (auto intType = mlir::dyn_cast(mlirType)) { + return exportIntegerType(intType, context); } // Handle `FloatType`'s. - if (mlir::isa(mlirType)) { - return exportFloatType(mlirType, context); + if (auto floatType = mlir::dyn_cast(mlirType)) { + return exportFloatType(floatType, context); } // Handle String. diff --git a/lib/Target/SubstraitPB/Import.cpp b/lib/Target/SubstraitPB/Import.cpp index 314507fa..ae39d536 100644 --- a/lib/Target/SubstraitPB/Import.cpp +++ b/lib/Target/SubstraitPB/Import.cpp @@ -186,9 +186,9 @@ static mlir::FailureOr importType(MLIRContext *context, case proto::Type::kI64: return IntegerType::get(context, 64, IntegerType::Signed); case proto::Type::kFp32: - return FloatType::getF32(context); + return Float32Type::get(context); case proto::Type::kFp64: - return FloatType::getF64(context); + return Float64Type::get(context); case proto::Type::kString: return StringType::get(context); case proto::Type::kBinary: @@ -560,11 +560,11 @@ importLiteral(ImplicitLocOpBuilder builder, return builder.create(attr); } case Expression::Literal::LiteralTypeCase::kFp32: { - auto attr = FloatAttr::get(FloatType::getF32(context), message.fp32()); + auto attr = FloatAttr::get(Float32Type::get(context), message.fp32()); return builder.create(attr); } case Expression::Literal::LiteralTypeCase::kFp64: { - auto attr = FloatAttr::get(FloatType::getF64(context), message.fp64()); + auto attr = FloatAttr::get(Float64Type::get(context), message.fp64()); return builder.create(attr); } case Expression::Literal::LiteralTypeCase::kString: { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1e92959c..65819316 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -18,7 +18,6 @@ set(SUBSTRAIT_MLIR_TEST_DEPENDS substrait-opt substrait-translate mlir_async_runtime - mlir-cpu-runner mlir_c_runner_utils mlir_runner_utils not diff --git a/third_party/llvm-project b/third_party/llvm-project index 5dd9867e..956cfa69 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 5dd9867e2d1e698fee980e31da114a37e4c7f612 +Subproject commit 956cfa69b153a0e798060f67e713790eeefebc04