Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/mlir/cxx/mlir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ target_include_directories(cxx-mlir

target_link_libraries(cxx-mlir PUBLIC
cxx-parser
MLIRIR
MLIRFuncDialect
MLIRControlFlowDialect
MLIRSCFDialect
MLIRControlFlowToLLVM
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRPass
MLIRTransforms
)
Expand Down
5 changes: 1 addition & 4 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ def Cxx_Dialect : Dialect {
let name = "cxx";
let cppNamespace = "mlir::cxx";
let useDefaultTypePrinterParser = 1;
let dependentDialects = ["mlir::func::FuncDialect",
"mlir::cf::ControlFlowDialect",
"mlir::scf::SCFDialect",
];
let dependentDialects = ["mlir::cf::ControlFlowDialect" ];
}

class Cxx_Type<string name, string typeMnemonic, list<Trait> traits = []>
Expand Down
204 changes: 179 additions & 25 deletions src/mlir/cxx/mlir/cxx_dialect_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
#include <cxx/mlir/cxx_dialect.h>

// mlir
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h>
#include <mlir/Conversion/LLVMCommon/TypeConverter.h>
#include <mlir/Dialect/ControlFlow/IR/ControlFlow.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Transforms/DialectConversion.h>
Expand All @@ -40,11 +43,47 @@ class FuncOpLowering : public OpConversionPattern<cxx::FuncOp> {
auto matchAndRewrite(cxx::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto func = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(),
op.getFunctionType());
auto typeConverter = getTypeConverter();

auto funcType = op.getFunctionType();

SmallVector<Type> argumentTypes;
for (auto argType : funcType.getInputs()) {
auto convertedType = typeConverter->convertType(argType);
if (!convertedType) {
return rewriter.notifyMatchFailure(
op, "failed to convert function argument type");
}
argumentTypes.push_back(convertedType);
}

SmallVector<Type> resultTypes;
for (auto resultType : funcType.getResults()) {
auto convertedType = typeConverter->convertType(resultType);
if (!convertedType) {
return rewriter.notifyMatchFailure(
op, "failed to convert function result type");
}

resultTypes.push_back(convertedType);
}

const auto returnType = resultTypes.empty()
? LLVM::LLVMVoidType::get(getContext())
: resultTypes.front();

const auto isVarArg = false;

auto llvmFuncType =
LLVM::LLVMFunctionType::get(returnType, argumentTypes, isVarArg);

auto func = rewriter.create<LLVM::LLVMFuncOp>(op.getLoc(), op.getSymName(),
llvmFuncType);

rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());

rewriter.eraseOp(op);

return success();
}
};
Expand All @@ -56,43 +95,152 @@ class ReturnOpLowering : public OpConversionPattern<cxx::ReturnOp> {
auto matchAndRewrite(cxx::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op,
adaptor.getOperands());
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
return success();
}
};

class AllocaOpLowering : public OpConversionPattern<cxx::AllocaOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::AllocaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto ptrTy = dyn_cast<cxx::PointerType>(op.getType());
if (!ptrTy) {
return rewriter.notifyMatchFailure(
op, "expected result type to be a pointer type");
}

auto resultType = LLVM::LLVMPointerType::get(context);
auto elementType = typeConverter->convertType(ptrTy.getElementType());

auto size = rewriter.create<LLVM::ConstantOp>(
op.getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(1));

auto x = rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(op, resultType,
elementType, size);
return success();
}
};

class LoadOpLowering : public OpConversionPattern<cxx::LoadOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());

rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, resultType,
adaptor.getAddr());

return success();
}
};

class StoreOpLowering : public OpConversionPattern<cxx::StoreOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto valueType = typeConverter->convertType(op.getValue().getType());
if (!valueType) {
return rewriter.notifyMatchFailure(op,
"failed to convert store value type");
}

rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(),
adaptor.getAddr());

return success();
}
};

class CxxToFuncLoweringPass
: public PassWrapper<CxxToFuncLoweringPass, OperationPass<ModuleOp>> {
class IntConstantOpLowering : public OpConversionPattern<cxx::IntConstantOp> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CxxToFuncLoweringPass)
using OpConversionPattern::OpConversionPattern;

auto getArgument() const -> StringRef override { return "cxx-to-func"; }
auto matchAndRewrite(cxx::IntConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(
op, "failed to convert integer constant type");
}

auto valueAttr = adaptor.getValueAttr();
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, resultType, valueAttr);

return success();
}
};

class CxxToLLVMLoweringPass
: public PassWrapper<CxxToLLVMLoweringPass, OperationPass<ModuleOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CxxToLLVMLoweringPass)

auto getArgument() const -> StringRef override { return "cxx-to-llvm"; }

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect>();
registry.insert<LLVM::LLVMDialect>();
}

void runOnOperation() final;
};

} // namespace

void CxxToFuncLoweringPass::runOnOperation() {
ConversionTarget target(getContext());
void CxxToLLVMLoweringPass::runOnOperation() {
auto context = &getContext();
auto module = getOperation();

// set up the data layout
mlir::DataLayout dataLayout(module);

// good
target.addLegalDialect<func::FuncDialect>();
// set up the type converter
LLVMTypeConverter typeConverter{context};
typeConverter.addConversion([](cxx::IntegerType type) {
return IntegerType::get(type.getContext(), type.getWidth());
});

// illegal
target.addIllegalOp<cxx::FuncOp>();
target.addIllegalOp<cxx::ReturnOp>();
typeConverter.addConversion([](cxx::PointerType type) {
return LLVM::LLVMPointerType::get(type.getContext());
});

RewritePatternSet patterns(&getContext());
patterns.insert<FuncOpLowering>(&getContext());
patterns.insert<ReturnOpLowering>(&getContext());
// set up the conversion patterns
ConversionTarget target(*context);

auto module = getOperation();
target.addLegalDialect<LLVM::LLVMDialect>();
target.addIllegalDialect<cxx::CxxDialect>();

RewritePatternSet patterns(context);
patterns.insert<FuncOpLowering, ReturnOpLowering, AllocaOpLowering,
LoadOpLowering, StoreOpLowering, IntConstantOpLowering>(
typeConverter, context);

populateFunctionOpInterfaceTypeConversionPattern<cxx::FuncOp>(patterns,
typeConverter);

cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
Expand All @@ -101,16 +249,22 @@ void CxxToFuncLoweringPass::runOnOperation() {

} // namespace mlir

auto cxx::createLowerToFuncPass() -> std::unique_ptr<mlir::Pass> {
return std::make_unique<mlir::CxxToFuncLoweringPass>();
auto cxx::createLowerToLLVMPass() -> std::unique_ptr<mlir::Pass> {
return std::make_unique<mlir::CxxToLLVMLoweringPass>();
}

auto cxx::lowerToMLIR(mlir::ModuleOp module) -> mlir::LogicalResult {
mlir::PassManager pm(module->getName());

pm.addPass(cxx::createLowerToFuncPass());
// debug dialect conversions
#if false
module->getContext()->disableMultithreading();
pm.enableIRPrinting();
#endif

pm.addPass(cxx::createLowerToLLVMPass());

if (mlir::failed(pm.run(module))) {
if (failed(pm.run(module))) {
return mlir::failure();
}

Expand Down
2 changes: 1 addition & 1 deletion src/mlir/cxx/mlir/cxx_dialect_conversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

namespace cxx {

[[nodiscard]] auto createLowerToFuncPass() -> std::unique_ptr<mlir::Pass>;
[[nodiscard]] auto createLowerToLLVMPass() -> std::unique_ptr<mlir::Pass>;

[[nodiscard]] auto lowerToMLIR(mlir::ModuleOp module) -> mlir::LogicalResult;

Expand Down