Skip to content

Commit 1f1be74

Browse files
committed
Lower cxx to the LLVM dialect
1 parent 95bc4cd commit 1f1be74

File tree

4 files changed

+185
-33
lines changed

4 files changed

+185
-33
lines changed

src/mlir/cxx/mlir/CMakeLists.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ target_include_directories(cxx-mlir
2828

2929
target_link_libraries(cxx-mlir PUBLIC
3030
cxx-parser
31-
MLIRIR
32-
MLIRFuncDialect
3331
MLIRControlFlowDialect
34-
MLIRSCFDialect
32+
MLIRControlFlowToLLVM
33+
MLIRIR
34+
MLIRLLVMCommonConversion
35+
MLIRLLVMDialect
3536
MLIRPass
3637
MLIRTransforms
3738
)

src/mlir/cxx/mlir/CxxOps.td

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,7 @@ def Cxx_Dialect : Dialect {
2828
let name = "cxx";
2929
let cppNamespace = "mlir::cxx";
3030
let useDefaultTypePrinterParser = 1;
31-
let dependentDialects = ["mlir::func::FuncDialect",
32-
"mlir::cf::ControlFlowDialect",
33-
"mlir::scf::SCFDialect",
34-
];
31+
let dependentDialects = ["mlir::cf::ControlFlowDialect" ];
3532
}
3633

3734
class Cxx_Type<string name, string typeMnemonic, list<Trait> traits = []>

src/mlir/cxx/mlir/cxx_dialect_conversions.cc

Lines changed: 179 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
#include <cxx/mlir/cxx_dialect.h>
2525

2626
// mlir
27-
#include <mlir/Dialect/Func/IR/FuncOps.h>
27+
#include <mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h>
28+
#include <mlir/Conversion/LLVMCommon/TypeConverter.h>
29+
#include <mlir/Dialect/ControlFlow/IR/ControlFlow.h>
30+
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
2831
#include <mlir/Pass/Pass.h>
2932
#include <mlir/Pass/PassManager.h>
3033
#include <mlir/Transforms/DialectConversion.h>
@@ -40,11 +43,47 @@ class FuncOpLowering : public OpConversionPattern<cxx::FuncOp> {
4043
auto matchAndRewrite(cxx::FuncOp op, OpAdaptor adaptor,
4144
ConversionPatternRewriter &rewriter) const
4245
-> LogicalResult override {
43-
auto func = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(),
44-
op.getFunctionType());
46+
auto typeConverter = getTypeConverter();
47+
48+
auto funcType = op.getFunctionType();
49+
50+
SmallVector<Type> argumentTypes;
51+
for (auto argType : funcType.getInputs()) {
52+
auto convertedType = typeConverter->convertType(argType);
53+
if (!convertedType) {
54+
return rewriter.notifyMatchFailure(
55+
op, "failed to convert function argument type");
56+
}
57+
argumentTypes.push_back(convertedType);
58+
}
59+
60+
SmallVector<Type> resultTypes;
61+
for (auto resultType : funcType.getResults()) {
62+
auto convertedType = typeConverter->convertType(resultType);
63+
if (!convertedType) {
64+
return rewriter.notifyMatchFailure(
65+
op, "failed to convert function result type");
66+
}
67+
68+
resultTypes.push_back(convertedType);
69+
}
70+
71+
const auto returnType = resultTypes.empty()
72+
? LLVM::LLVMVoidType::get(getContext())
73+
: resultTypes.front();
74+
75+
const auto isVarArg = false;
76+
77+
auto llvmFuncType =
78+
LLVM::LLVMFunctionType::get(returnType, argumentTypes, isVarArg);
79+
80+
auto func = rewriter.create<LLVM::LLVMFuncOp>(op.getLoc(), op.getSymName(),
81+
llvmFuncType);
4582

4683
rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
84+
4785
rewriter.eraseOp(op);
86+
4887
return success();
4988
}
5089
};
@@ -56,43 +95,152 @@ class ReturnOpLowering : public OpConversionPattern<cxx::ReturnOp> {
5695
auto matchAndRewrite(cxx::ReturnOp op, OpAdaptor adaptor,
5796
ConversionPatternRewriter &rewriter) const
5897
-> LogicalResult override {
59-
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op,
60-
adaptor.getOperands());
98+
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
99+
return success();
100+
}
101+
};
102+
103+
class AllocaOpLowering : public OpConversionPattern<cxx::AllocaOp> {
104+
public:
105+
using OpConversionPattern::OpConversionPattern;
106+
107+
auto matchAndRewrite(cxx::AllocaOp op, OpAdaptor adaptor,
108+
ConversionPatternRewriter &rewriter) const
109+
-> LogicalResult override {
110+
auto typeConverter = getTypeConverter();
111+
auto context = getContext();
112+
113+
auto ptrTy = dyn_cast<cxx::PointerType>(op.getType());
114+
if (!ptrTy) {
115+
return rewriter.notifyMatchFailure(
116+
op, "expected result type to be a pointer type");
117+
}
118+
119+
auto resultType = LLVM::LLVMPointerType::get(context);
120+
auto elementType = typeConverter->convertType(ptrTy.getElementType());
121+
122+
auto size = rewriter.create<LLVM::ConstantOp>(
123+
op.getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(1));
124+
125+
auto x = rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(op, resultType,
126+
elementType, size);
127+
return success();
128+
}
129+
};
130+
131+
class LoadOpLowering : public OpConversionPattern<cxx::LoadOp> {
132+
public:
133+
using OpConversionPattern::OpConversionPattern;
134+
135+
auto matchAndRewrite(cxx::LoadOp op, OpAdaptor adaptor,
136+
ConversionPatternRewriter &rewriter) const
137+
-> LogicalResult override {
138+
auto typeConverter = getTypeConverter();
139+
auto context = getContext();
140+
141+
auto resultType = typeConverter->convertType(op.getType());
142+
143+
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, resultType,
144+
adaptor.getAddr());
145+
146+
return success();
147+
}
148+
};
149+
150+
class StoreOpLowering : public OpConversionPattern<cxx::StoreOp> {
151+
public:
152+
using OpConversionPattern::OpConversionPattern;
153+
154+
auto matchAndRewrite(cxx::StoreOp op, OpAdaptor adaptor,
155+
ConversionPatternRewriter &rewriter) const
156+
-> LogicalResult override {
157+
auto typeConverter = getTypeConverter();
158+
auto context = getContext();
159+
160+
auto valueType = typeConverter->convertType(op.getValue().getType());
161+
if (!valueType) {
162+
return rewriter.notifyMatchFailure(op,
163+
"failed to convert store value type");
164+
}
165+
166+
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(),
167+
adaptor.getAddr());
168+
61169
return success();
62170
}
63171
};
64172

65-
class CxxToFuncLoweringPass
66-
: public PassWrapper<CxxToFuncLoweringPass, OperationPass<ModuleOp>> {
173+
class IntConstantOpLowering : public OpConversionPattern<cxx::IntConstantOp> {
67174
public:
68-
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CxxToFuncLoweringPass)
175+
using OpConversionPattern::OpConversionPattern;
69176

70-
auto getArgument() const -> StringRef override { return "cxx-to-func"; }
177+
auto matchAndRewrite(cxx::IntConstantOp op, OpAdaptor adaptor,
178+
ConversionPatternRewriter &rewriter) const
179+
-> LogicalResult override {
180+
auto typeConverter = getTypeConverter();
181+
auto context = getContext();
182+
183+
auto resultType = typeConverter->convertType(op.getType());
184+
if (!resultType) {
185+
return rewriter.notifyMatchFailure(
186+
op, "failed to convert integer constant type");
187+
}
188+
189+
auto valueAttr = adaptor.getValueAttr();
190+
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, resultType, valueAttr);
191+
192+
return success();
193+
}
194+
};
195+
196+
class CxxToLLVMLoweringPass
197+
: public PassWrapper<CxxToLLVMLoweringPass, OperationPass<ModuleOp>> {
198+
public:
199+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CxxToLLVMLoweringPass)
200+
201+
auto getArgument() const -> StringRef override { return "cxx-to-llvm"; }
71202

72203
void getDependentDialects(DialectRegistry &registry) const override {
73-
registry.insert<func::FuncDialect>();
204+
registry.insert<LLVM::LLVMDialect>();
74205
}
75206

76207
void runOnOperation() final;
77208
};
78209

79210
} // namespace
80211

81-
void CxxToFuncLoweringPass::runOnOperation() {
82-
ConversionTarget target(getContext());
212+
void CxxToLLVMLoweringPass::runOnOperation() {
213+
auto context = &getContext();
214+
auto module = getOperation();
215+
216+
// set up the data layout
217+
mlir::DataLayout dataLayout(module);
83218

84-
// good
85-
target.addLegalDialect<func::FuncDialect>();
219+
// set up the type converter
220+
LLVMTypeConverter typeConverter{context};
221+
typeConverter.addConversion([](cxx::IntegerType type) {
222+
return IntegerType::get(type.getContext(), type.getWidth());
223+
});
86224

87-
// illegal
88-
target.addIllegalOp<cxx::FuncOp>();
89-
target.addIllegalOp<cxx::ReturnOp>();
225+
typeConverter.addConversion([](cxx::PointerType type) {
226+
return LLVM::LLVMPointerType::get(type.getContext());
227+
});
90228

91-
RewritePatternSet patterns(&getContext());
92-
patterns.insert<FuncOpLowering>(&getContext());
93-
patterns.insert<ReturnOpLowering>(&getContext());
229+
// set up the conversion patterns
230+
ConversionTarget target(*context);
94231

95-
auto module = getOperation();
232+
target.addLegalDialect<LLVM::LLVMDialect>();
233+
target.addIllegalDialect<cxx::CxxDialect>();
234+
235+
RewritePatternSet patterns(context);
236+
patterns.insert<FuncOpLowering, ReturnOpLowering, AllocaOpLowering,
237+
LoadOpLowering, StoreOpLowering, IntConstantOpLowering>(
238+
typeConverter, context);
239+
240+
populateFunctionOpInterfaceTypeConversionPattern<cxx::FuncOp>(patterns,
241+
typeConverter);
242+
243+
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
96244

97245
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
98246
signalPassFailure();
@@ -101,16 +249,22 @@ void CxxToFuncLoweringPass::runOnOperation() {
101249

102250
} // namespace mlir
103251

104-
auto cxx::createLowerToFuncPass() -> std::unique_ptr<mlir::Pass> {
105-
return std::make_unique<mlir::CxxToFuncLoweringPass>();
252+
auto cxx::createLowerToLLVMPass() -> std::unique_ptr<mlir::Pass> {
253+
return std::make_unique<mlir::CxxToLLVMLoweringPass>();
106254
}
107255

108256
auto cxx::lowerToMLIR(mlir::ModuleOp module) -> mlir::LogicalResult {
109257
mlir::PassManager pm(module->getName());
110258

111-
pm.addPass(cxx::createLowerToFuncPass());
259+
// debug dialect conversions
260+
#if false
261+
module->getContext()->disableMultithreading();
262+
pm.enableIRPrinting();
263+
#endif
264+
265+
pm.addPass(cxx::createLowerToLLVMPass());
112266

113-
if (mlir::failed(pm.run(module))) {
267+
if (failed(pm.run(module))) {
114268
return mlir::failure();
115269
}
116270

src/mlir/cxx/mlir/cxx_dialect_conversions.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
namespace cxx {
2727

28-
[[nodiscard]] auto createLowerToFuncPass() -> std::unique_ptr<mlir::Pass>;
28+
[[nodiscard]] auto createLowerToLLVMPass() -> std::unique_ptr<mlir::Pass>;
2929

3030
[[nodiscard]] auto lowerToMLIR(mlir::ModuleOp module) -> mlir::LogicalResult;
3131

0 commit comments

Comments
 (0)