diff --git a/src/mlir/cxx/mlir/CMakeLists.txt b/src/mlir/cxx/mlir/CMakeLists.txt index 03aff80a..f9b19b69 100644 --- a/src/mlir/cxx/mlir/CMakeLists.txt +++ b/src/mlir/cxx/mlir/CMakeLists.txt @@ -28,10 +28,11 @@ target_include_directories(cxx-mlir target_link_libraries(cxx-mlir PUBLIC cxx-parser - MLIRIR - MLIRFuncDialect MLIRControlFlowDialect - MLIRSCFDialect + MLIRControlFlowToLLVM + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect MLIRPass MLIRTransforms ) diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index eb0bbe96..8e8a37bd 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -28,10 +28,7 @@ def Cxx_Dialect : Dialect { let name = "cxx"; let cppNamespace = "mlir::cxx"; let useDefaultTypePrinterParser = 1; - let dependentDialects = ["mlir::func::FuncDialect", - "mlir::cf::ControlFlowDialect", - "mlir::scf::SCFDialect", - ]; + let dependentDialects = ["mlir::cf::ControlFlowDialect" ]; } class Cxx_Type traits = []> diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index 368ae621..14d67dc2 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -24,7 +24,10 @@ #include // mlir -#include +#include +#include +#include +#include #include #include #include @@ -40,11 +43,47 @@ class FuncOpLowering : public OpConversionPattern { auto matchAndRewrite(cxx::FuncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const -> LogicalResult override { - auto func = rewriter.create(op.getLoc(), op.getName(), - op.getFunctionType()); + auto typeConverter = getTypeConverter(); + + auto funcType = op.getFunctionType(); + + SmallVector argumentTypes; + for (auto argType : funcType.getInputs()) { + auto convertedType = typeConverter->convertType(argType); + if (!convertedType) { + return rewriter.notifyMatchFailure( + op, "failed to convert function argument type"); + } + argumentTypes.push_back(convertedType); + } + + SmallVector resultTypes; + for (auto resultType : funcType.getResults()) { + auto convertedType = typeConverter->convertType(resultType); + if (!convertedType) { + return rewriter.notifyMatchFailure( + op, "failed to convert function result type"); + } + + resultTypes.push_back(convertedType); + } + + const auto returnType = resultTypes.empty() + ? LLVM::LLVMVoidType::get(getContext()) + : resultTypes.front(); + + const auto isVarArg = false; + + auto llvmFuncType = + LLVM::LLVMFunctionType::get(returnType, argumentTypes, isVarArg); + + auto func = rewriter.create(op.getLoc(), op.getSymName(), + llvmFuncType); rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); + rewriter.eraseOp(op); + return success(); } }; @@ -56,21 +95,113 @@ class ReturnOpLowering : public OpConversionPattern { auto matchAndRewrite(cxx::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const -> LogicalResult override { - rewriter.replaceOpWithNewOp(op, - adaptor.getOperands()); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + +class AllocaOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::AllocaOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto ptrTy = dyn_cast(op.getType()); + if (!ptrTy) { + return rewriter.notifyMatchFailure( + op, "expected result type to be a pointer type"); + } + + auto resultType = LLVM::LLVMPointerType::get(context); + auto elementType = typeConverter->convertType(ptrTy.getElementType()); + + auto size = rewriter.create( + op.getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(1)); + + auto x = rewriter.replaceOpWithNewOp(op, resultType, + elementType, size); + return success(); + } +}; + +class LoadOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getAddr()); + + return success(); + } +}; + +class StoreOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto valueType = typeConverter->convertType(op.getValue().getType()); + if (!valueType) { + return rewriter.notifyMatchFailure(op, + "failed to convert store value type"); + } + + rewriter.replaceOpWithNewOp(op, adaptor.getValue(), + adaptor.getAddr()); + return success(); } }; -class CxxToFuncLoweringPass - : public PassWrapper> { +class IntConstantOpLowering : public OpConversionPattern { public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CxxToFuncLoweringPass) + using OpConversionPattern::OpConversionPattern; - auto getArgument() const -> StringRef override { return "cxx-to-func"; } + auto matchAndRewrite(cxx::IntConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert integer constant type"); + } + + auto valueAttr = adaptor.getValueAttr(); + rewriter.replaceOpWithNewOp(op, resultType, valueAttr); + + return success(); + } +}; + +class CxxToLLVMLoweringPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CxxToLLVMLoweringPass) + + auto getArgument() const -> StringRef override { return "cxx-to-llvm"; } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() final; @@ -78,21 +209,38 @@ class CxxToFuncLoweringPass } // namespace -void CxxToFuncLoweringPass::runOnOperation() { - ConversionTarget target(getContext()); +void CxxToLLVMLoweringPass::runOnOperation() { + auto context = &getContext(); + auto module = getOperation(); + + // set up the data layout + mlir::DataLayout dataLayout(module); - // good - target.addLegalDialect(); + // set up the type converter + LLVMTypeConverter typeConverter{context}; + typeConverter.addConversion([](cxx::IntegerType type) { + return IntegerType::get(type.getContext(), type.getWidth()); + }); - // illegal - target.addIllegalOp(); - target.addIllegalOp(); + typeConverter.addConversion([](cxx::PointerType type) { + return LLVM::LLVMPointerType::get(type.getContext()); + }); - RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); - patterns.insert(&getContext()); + // set up the conversion patterns + ConversionTarget target(*context); - auto module = getOperation(); + target.addLegalDialect(); + target.addIllegalDialect(); + + RewritePatternSet patterns(context); + patterns.insert( + typeConverter, context); + + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + + cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { signalPassFailure(); @@ -101,16 +249,22 @@ void CxxToFuncLoweringPass::runOnOperation() { } // namespace mlir -auto cxx::createLowerToFuncPass() -> std::unique_ptr { - return std::make_unique(); +auto cxx::createLowerToLLVMPass() -> std::unique_ptr { + return std::make_unique(); } auto cxx::lowerToMLIR(mlir::ModuleOp module) -> mlir::LogicalResult { mlir::PassManager pm(module->getName()); - pm.addPass(cxx::createLowerToFuncPass()); + // debug dialect conversions +#if false + module->getContext()->disableMultithreading(); + pm.enableIRPrinting(); +#endif + + pm.addPass(cxx::createLowerToLLVMPass()); - if (mlir::failed(pm.run(module))) { + if (failed(pm.run(module))) { return mlir::failure(); } diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.h b/src/mlir/cxx/mlir/cxx_dialect_conversions.h index 2618d9cb..8631b252 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.h +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.h @@ -25,7 +25,7 @@ namespace cxx { -[[nodiscard]] auto createLowerToFuncPass() -> std::unique_ptr; +[[nodiscard]] auto createLowerToLLVMPass() -> std::unique_ptr; [[nodiscard]] auto lowerToMLIR(mlir::ModuleOp module) -> mlir::LogicalResult;