Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ struct SubstraitEmitDeduplicationPass
void SubstraitEmitDeduplicationPass::runOnOperation() {
mlir::RewritePatternSet patterns(&getContext());
populateEmitDeduplicationPatterns(patterns);
if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(
mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
Location loc = getOperation()->getLoc();
emitError(loc) << "emit deduplication: pattern application failed";
signalPassFailure();
Expand Down
55 changes: 24 additions & 31 deletions lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,13 @@ std::unique_ptr<pb::Any> SubstraitExporter::exportAny(StringAttr attr) {
return any;
}

std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
/// Function that export `IntegerType`'s to the corresponding Substrait types.
std::unique_ptr<proto::Type> exportIntegerType(IntegerType intType,
MLIRContext *context) {
// Function that handles `IntegerType`'s.
assert(intType.isSigned() && "only signed integer types supported");

// Handle SI1.
auto si1 = IntegerType::get(context, 1, IntegerType::Signed);
if (mlirType == si1) {
switch (intType.getWidth()) {
case 1: { // Handle SI1.
// TODO(ingomueller): support other nullability modes.
auto i1Type = std::make_unique<proto::Type::Boolean>();
i1Type->set_nullability(
Expand All @@ -167,9 +167,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
return type;
}

// Handle SI8.
auto si8 = IntegerType::get(context, 8, IntegerType::Signed);
if (mlirType == si8) {
case 8: { // Handle SI8.
// TODO(ingomueller): support other nullability modes.
auto i8Type = std::make_unique<proto::Type::I8>();
i8Type->set_nullability(
Expand All @@ -180,9 +178,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
return type;
}

// Handle SI6.
auto si16 = IntegerType::get(context, 16, IntegerType::Signed);
if (mlirType == si16) {
case 16: { // Handle SI16.
// TODO(ingomueller): support other nullability modes.
auto i16Type = std::make_unique<proto::Type::I16>();
i16Type->set_nullability(
Expand All @@ -193,9 +189,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
return type;
}

// Handle SI32.
auto si32 = IntegerType::get(context, 32, IntegerType::Signed);
if (mlirType == si32) {
case 32: { // Handle SI32.
// TODO(ingomueller): support other nullability modes.
auto i32Type = std::make_unique<proto::Type::I32>();
i32Type->set_nullability(
Expand All @@ -206,9 +200,7 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
return type;
}

// Handle SI64.
auto si64 = IntegerType::get(context, 64, IntegerType::Signed);
if (mlirType == si64) {
case 64: { // Handle SI64.
// TODO(ingomueller): support other nullability modes.
auto i64Type = std::make_unique<proto::Type::I64>();
i64Type->set_nullability(
Expand All @@ -219,16 +211,17 @@ std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
return type;
}

llvm_unreachable("We should have handled all integer types.");
default:
llvm_unreachable("We should have handled all integer types.");
}
}

std::unique_ptr<proto::Type> exportFloatType(mlir::Type mlirType,
/// Function that export `FloatType`'s to the corresponding Substrait types.
std::unique_ptr<proto::Type> exportFloatType(FloatType floatType,
MLIRContext *context) {
// Function that handles `FloatType`'s.

// Handle FP32.
auto fp32 = FloatType::getF32(context);
if (mlirType == fp32) {
switch (floatType.getWidth()) {
case 32: { // Handle FP32.
// TODO(ingomueller): support other nullability modes.
auto fp32Type = std::make_unique<proto::Type::FP32>();
fp32Type->set_nullability(
Expand All @@ -239,9 +232,7 @@ std::unique_ptr<proto::Type> exportFloatType(mlir::Type mlirType,
return type;
}

// Handle FP64.
auto fp64 = FloatType::getF64(context);
if (mlirType == fp64) {
case 64: { // Handle FP64.
// TODO(ingomueller): support other nullability modes.
auto fp64Type = std::make_unique<proto::Type::FP64>();
fp64Type->set_nullability(
Expand All @@ -252,21 +243,23 @@ std::unique_ptr<proto::Type> exportFloatType(mlir::Type mlirType,
return type;
}

llvm_unreachable("We should have handled all float types.");
default:
llvm_unreachable("We should have handled all float types.");
}
}

FailureOr<std::unique_ptr<proto::Type>>
SubstraitExporter::exportType(Location loc, mlir::Type mlirType) {
MLIRContext *context = mlirType.getContext();

// Handle `IntegerType`'s.
if (mlir::isa<IntegerType>(mlirType)) {
return exportIntegerType(mlirType, context);
if (auto intType = mlir::dyn_cast<IntegerType>(mlirType)) {
return exportIntegerType(intType, context);
}

// Handle `FloatType`'s.
if (mlir::isa<FloatType>(mlirType)) {
return exportFloatType(mlirType, context);
if (auto floatType = mlir::dyn_cast<FloatType>(mlirType)) {
return exportFloatType(floatType, context);
}

// Handle String.
Expand Down
8 changes: 4 additions & 4 deletions lib/Target/SubstraitPB/Import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,9 @@ static mlir::FailureOr<mlir::Type> importType(MLIRContext *context,
case proto::Type::kI64:
return IntegerType::get(context, 64, IntegerType::Signed);
case proto::Type::kFp32:
return FloatType::getF32(context);
return Float32Type::get(context);
case proto::Type::kFp64:
return FloatType::getF64(context);
return Float64Type::get(context);
case proto::Type::kString:
return StringType::get(context);
case proto::Type::kBinary:
Expand Down Expand Up @@ -560,11 +560,11 @@ importLiteral(ImplicitLocOpBuilder builder,
return builder.create<LiteralOp>(attr);
}
case Expression::Literal::LiteralTypeCase::kFp32: {
auto attr = FloatAttr::get(FloatType::getF32(context), message.fp32());
auto attr = FloatAttr::get(Float32Type::get(context), message.fp32());
return builder.create<LiteralOp>(attr);
}
case Expression::Literal::LiteralTypeCase::kFp64: {
auto attr = FloatAttr::get(FloatType::getF64(context), message.fp64());
auto attr = FloatAttr::get(Float64Type::get(context), message.fp64());
return builder.create<LiteralOp>(attr);
}
case Expression::Literal::LiteralTypeCase::kString: {
Expand Down
1 change: 0 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ set(SUBSTRAIT_MLIR_TEST_DEPENDS
substrait-opt
substrait-translate
mlir_async_runtime
mlir-cpu-runner
mlir_c_runner_utils
mlir_runner_utils
not
Expand Down
2 changes: 1 addition & 1 deletion third_party/llvm-project
Submodule llvm-project updated 23260 files