2424#include < cxx/mlir/cxx_dialect.h>
2525
2626// mlir
27- #include < mlir/Dialect/Func/IR/FuncOps.h>
27+ #include < mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h>
28+ #include < mlir/Conversion/LLVMCommon/TypeConverter.h>
29+ #include < mlir/Dialect/ControlFlow/IR/ControlFlow.h>
30+ #include < mlir/Dialect/LLVMIR/LLVMDialect.h>
2831#include < mlir/Pass/Pass.h>
2932#include < mlir/Pass/PassManager.h>
3033#include < mlir/Transforms/DialectConversion.h>
@@ -40,11 +43,47 @@ class FuncOpLowering : public OpConversionPattern<cxx::FuncOp> {
4043 auto matchAndRewrite (cxx::FuncOp op, OpAdaptor adaptor,
4144 ConversionPatternRewriter &rewriter) const
4245 -> LogicalResult override {
43- auto func = rewriter.create <mlir::func::FuncOp>(op.getLoc (), op.getName (),
44- op.getFunctionType ());
46+ auto typeConverter = getTypeConverter ();
47+
48+ auto funcType = op.getFunctionType ();
49+
50+ SmallVector<Type> argumentTypes;
51+ for (auto argType : funcType.getInputs ()) {
52+ auto convertedType = typeConverter->convertType (argType);
53+ if (!convertedType) {
54+ return rewriter.notifyMatchFailure (
55+ op, " failed to convert function argument type" );
56+ }
57+ argumentTypes.push_back (convertedType);
58+ }
59+
60+ SmallVector<Type> resultTypes;
61+ for (auto resultType : funcType.getResults ()) {
62+ auto convertedType = typeConverter->convertType (resultType);
63+ if (!convertedType) {
64+ return rewriter.notifyMatchFailure (
65+ op, " failed to convert function result type" );
66+ }
67+
68+ resultTypes.push_back (convertedType);
69+ }
70+
71+ const auto returnType = resultTypes.empty ()
72+ ? LLVM::LLVMVoidType::get (getContext ())
73+ : resultTypes.front ();
74+
75+ const auto isVarArg = false ;
76+
77+ auto llvmFuncType =
78+ LLVM::LLVMFunctionType::get (returnType, argumentTypes, isVarArg);
79+
80+ auto func = rewriter.create <LLVM::LLVMFuncOp>(op.getLoc (), op.getSymName (),
81+ llvmFuncType);
4582
4683 rewriter.inlineRegionBefore (op.getRegion (), func.getBody (), func.end ());
84+
4785 rewriter.eraseOp (op);
86+
4887 return success ();
4988 }
5089};
@@ -56,43 +95,152 @@ class ReturnOpLowering : public OpConversionPattern<cxx::ReturnOp> {
5695 auto matchAndRewrite (cxx::ReturnOp op, OpAdaptor adaptor,
5796 ConversionPatternRewriter &rewriter) const
5897 -> LogicalResult override {
59- rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(op,
60- adaptor.getOperands ());
98+ rewriter.replaceOpWithNewOp <LLVM::ReturnOp>(op, adaptor.getOperands ());
99+ return success ();
100+ }
101+ };
102+
103+ class AllocaOpLowering : public OpConversionPattern <cxx::AllocaOp> {
104+ public:
105+ using OpConversionPattern::OpConversionPattern;
106+
107+ auto matchAndRewrite (cxx::AllocaOp op, OpAdaptor adaptor,
108+ ConversionPatternRewriter &rewriter) const
109+ -> LogicalResult override {
110+ auto typeConverter = getTypeConverter ();
111+ auto context = getContext ();
112+
113+ auto ptrTy = dyn_cast<cxx::PointerType>(op.getType ());
114+ if (!ptrTy) {
115+ return rewriter.notifyMatchFailure (
116+ op, " expected result type to be a pointer type" );
117+ }
118+
119+ auto resultType = LLVM::LLVMPointerType::get (context);
120+ auto elementType = typeConverter->convertType (ptrTy.getElementType ());
121+
122+ auto size = rewriter.create <LLVM::ConstantOp>(
123+ op.getLoc (), rewriter.getI64Type (), rewriter.getI64IntegerAttr (1 ));
124+
125+ auto x = rewriter.replaceOpWithNewOp <LLVM::AllocaOp>(op, resultType,
126+ elementType, size);
127+ return success ();
128+ }
129+ };
130+
131+ class LoadOpLowering : public OpConversionPattern <cxx::LoadOp> {
132+ public:
133+ using OpConversionPattern::OpConversionPattern;
134+
135+ auto matchAndRewrite (cxx::LoadOp op, OpAdaptor adaptor,
136+ ConversionPatternRewriter &rewriter) const
137+ -> LogicalResult override {
138+ auto typeConverter = getTypeConverter ();
139+ auto context = getContext ();
140+
141+ auto resultType = typeConverter->convertType (op.getType ());
142+
143+ rewriter.replaceOpWithNewOp <LLVM::LoadOp>(op, resultType,
144+ adaptor.getAddr ());
145+
146+ return success ();
147+ }
148+ };
149+
150+ class StoreOpLowering : public OpConversionPattern <cxx::StoreOp> {
151+ public:
152+ using OpConversionPattern::OpConversionPattern;
153+
154+ auto matchAndRewrite (cxx::StoreOp op, OpAdaptor adaptor,
155+ ConversionPatternRewriter &rewriter) const
156+ -> LogicalResult override {
157+ auto typeConverter = getTypeConverter ();
158+ auto context = getContext ();
159+
160+ auto valueType = typeConverter->convertType (op.getValue ().getType ());
161+ if (!valueType) {
162+ return rewriter.notifyMatchFailure (op,
163+ " failed to convert store value type" );
164+ }
165+
166+ rewriter.replaceOpWithNewOp <LLVM::StoreOp>(op, adaptor.getValue (),
167+ adaptor.getAddr ());
168+
61169 return success ();
62170 }
63171};
64172
65- class CxxToFuncLoweringPass
66- : public PassWrapper<CxxToFuncLoweringPass, OperationPass<ModuleOp>> {
173+ class IntConstantOpLowering : public OpConversionPattern <cxx::IntConstantOp> {
67174 public:
68- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (CxxToFuncLoweringPass)
175+ using OpConversionPattern::OpConversionPattern;
69176
70- auto getArgument () const -> StringRef override { return " cxx-to-func" ; }
177+ auto matchAndRewrite (cxx::IntConstantOp op, OpAdaptor adaptor,
178+ ConversionPatternRewriter &rewriter) const
179+ -> LogicalResult override {
180+ auto typeConverter = getTypeConverter ();
181+ auto context = getContext ();
182+
183+ auto resultType = typeConverter->convertType (op.getType ());
184+ if (!resultType) {
185+ return rewriter.notifyMatchFailure (
186+ op, " failed to convert integer constant type" );
187+ }
188+
189+ auto valueAttr = adaptor.getValueAttr ();
190+ rewriter.replaceOpWithNewOp <LLVM::ConstantOp>(op, resultType, valueAttr);
191+
192+ return success ();
193+ }
194+ };
195+
196+ class CxxToLLVMLoweringPass
197+ : public PassWrapper<CxxToLLVMLoweringPass, OperationPass<ModuleOp>> {
198+ public:
199+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (CxxToLLVMLoweringPass)
200+
201+ auto getArgument () const -> StringRef override { return " cxx-to-llvm" ; }
71202
72203 void getDependentDialects (DialectRegistry ®istry) const override {
73- registry.insert <func::FuncDialect >();
204+ registry.insert <LLVM::LLVMDialect >();
74205 }
75206
76207 void runOnOperation () final ;
77208};
78209
79210} // namespace
80211
81- void CxxToFuncLoweringPass::runOnOperation () {
82- ConversionTarget target (getContext ());
212+ void CxxToLLVMLoweringPass::runOnOperation () {
213+ auto context = &getContext ();
214+ auto module = getOperation ();
215+
216+ // set up the data layout
217+ mlir::DataLayout dataLayout (module );
83218
84- // good
85- target.addLegalDialect <func::FuncDialect>();
219+ // set up the type converter
220+ LLVMTypeConverter typeConverter{context};
221+ typeConverter.addConversion ([](cxx::IntegerType type) {
222+ return IntegerType::get (type.getContext (), type.getWidth ());
223+ });
86224
87- // illegal
88- target. addIllegalOp <cxx::FuncOp>( );
89- target. addIllegalOp <cxx::ReturnOp>( );
225+ typeConverter. addConversion ([](cxx::PointerType type) {
226+ return LLVM::LLVMPointerType::get (type. getContext () );
227+ } );
90228
91- RewritePatternSet patterns (&getContext ());
92- patterns.insert <FuncOpLowering>(&getContext ());
93- patterns.insert <ReturnOpLowering>(&getContext ());
229+ // set up the conversion patterns
230+ ConversionTarget target (*context);
94231
95- auto module = getOperation ();
232+ target.addLegalDialect <LLVM::LLVMDialect>();
233+ target.addIllegalDialect <cxx::CxxDialect>();
234+
235+ RewritePatternSet patterns (context);
236+ patterns.insert <FuncOpLowering, ReturnOpLowering, AllocaOpLowering,
237+ LoadOpLowering, StoreOpLowering, IntConstantOpLowering>(
238+ typeConverter, context);
239+
240+ populateFunctionOpInterfaceTypeConversionPattern<cxx::FuncOp>(patterns,
241+ typeConverter);
242+
243+ cf::populateControlFlowToLLVMConversionPatterns (typeConverter, patterns);
96244
97245 if (failed (applyPartialConversion (module , target, std::move (patterns)))) {
98246 signalPassFailure ();
@@ -101,16 +249,22 @@ void CxxToFuncLoweringPass::runOnOperation() {
101249
102250} // namespace mlir
103251
104- auto cxx::createLowerToFuncPass () -> std::unique_ptr<mlir::Pass> {
105- return std::make_unique<mlir::CxxToFuncLoweringPass >();
252+ auto cxx::createLowerToLLVMPass () -> std::unique_ptr<mlir::Pass> {
253+ return std::make_unique<mlir::CxxToLLVMLoweringPass >();
106254}
107255
108256auto cxx::lowerToMLIR (mlir::ModuleOp module ) -> mlir::LogicalResult {
109257 mlir::PassManager pm (module ->getName ());
110258
111- pm.addPass (cxx::createLowerToFuncPass ());
259+ // debug dialect conversions
260+ #if false
261+ module ->getContext ()->disableMultithreading ();
262+ pm.enableIRPrinting ();
263+ #endif
264+
265+ pm.addPass (cxx::createLowerToLLVMPass ());
112266
113- if (mlir:: failed (pm.run (module ))) {
267+ if (failed (pm.run (module ))) {
114268 return mlir::failure ();
115269 }
116270
0 commit comments