Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ def Cxx_MulIOp : Cxx_Op<"muli"> {
let results = (outs Cxx_IntegerType:$result);
}

def CondBranchOp : Cxx_Op<"cond_br", [ AttrSizedOperandSegments, Terminator ]> {
let arguments = (ins Cxx_BoolType:$condition, Variadic<AnyType>:$trueDestOperands, Variadic<AnyType>:$falseDestOperands);

let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
}

//
// todo ops
//
Expand Down
16 changes: 16 additions & 0 deletions src/mlir/cxx/mlir/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>

// mlir
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>

#include <format>

namespace cxx {
Expand All @@ -42,6 +45,19 @@ auto Codegen::currentBlockMightHaveTerminator() -> bool {
return block->mightHaveTerminator();
}

auto Codegen::newBlock() -> mlir::Block* {
auto region = builder_.getBlock()->getParent();
auto newBlock = new mlir::Block();
region->getBlocks().push_back(newBlock);
return newBlock;
}

void Codegen::branch(mlir::Location loc, mlir::Block* block,
mlir::ValueRange operands) {
if (currentBlockMightHaveTerminator()) return;
builder_.create<mlir::cf::BranchOp>(loc, block, operands);
}

auto Codegen::findOrCreateLocal(Symbol* symbol) -> std::optional<mlir::Value> {
auto var = symbol_cast<VariableSymbol>(symbol);
if (!var) return std::nullopt;
Expand Down
7 changes: 7 additions & 0 deletions src/mlir/cxx/mlir/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class Codegen {
ExpressionAST* ast, ExpressionFormat format = ExpressionFormat::kValue)
-> ExpressionResult;

void condition(ExpressionAST* ast, mlir::Block* trueBlock,
mlir::Block* falseBlock);

[[nodiscard]] auto templateParameter(TemplateParameterAST* ast)
-> TemplateParameterResult;

Expand Down Expand Up @@ -256,6 +259,10 @@ class Codegen {
[[nodiscard]] auto findOrCreateLocal(Symbol* symbol)
-> std::optional<mlir::Value>;

[[nodiscard]] auto newBlock() -> mlir::Block*;
void branch(mlir::Location loc, mlir::Block* block,
mlir::ValueRange operands = {});

struct UnitVisitor;
struct DeclarationVisitor;
struct StatementVisitor;
Expand Down
37 changes: 37 additions & 0 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
#include <cxx/translation_unit.h>
#include <cxx/types.h>

// mlir
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>

namespace cxx {

struct Codegen::ExpressionVisitor {
Expand Down Expand Up @@ -120,6 +123,40 @@ auto Codegen::expression(ExpressionAST* ast, ExpressionFormat format)
return {};
}

void Codegen::condition(ExpressionAST* ast, mlir::Block* trueBlock,
mlir::Block* falseBlock) {
if (!ast) return;

if (auto nested = ast_cast<NestedExpressionAST>(ast)) {
condition(nested->expression, trueBlock, falseBlock);
return;
}

if (auto binop = ast_cast<BinaryExpressionAST>(ast)) {
if (binop->op == TokenKind::T_AMP_AMP) {
auto nextBlock = newBlock();
condition(binop->leftExpression, nextBlock, falseBlock);
builder_.setInsertionPointToEnd(nextBlock);
condition(binop->rightExpression, trueBlock, falseBlock);
return;
}

if (binop->op == TokenKind::T_BAR_BAR) {
auto nextBlock = newBlock();
condition(binop->leftExpression, trueBlock, nextBlock);
builder_.setInsertionPointToEnd(nextBlock);
condition(binop->rightExpression, trueBlock, falseBlock);
return;
}
}

const auto loc = getLocation(ast->firstSourceLocation());
auto value = expression(ast);
builder_.create<mlir::cxx::CondBranchOp>(loc, value.value, mlir::ValueRange{},
mlir::ValueRange{}, trueBlock,
falseBlock);
}

auto Codegen::newInitializer(NewInitializerAST* ast) -> NewInitializerResult {
if (ast) return visit(NewInitializerVisitor{*this}, ast);
return {};
Expand Down
45 changes: 33 additions & 12 deletions src/mlir/cxx/mlir/codegen_statements.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,21 @@ void Codegen::StatementVisitor::operator()(CompoundStatementAST* ast) {
}

void Codegen::StatementVisitor::operator()(IfStatementAST* ast) {
(void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind()));
auto trueBlock = gen.newBlock();
auto falseBlock = gen.newBlock();
auto mergeBlock = gen.newBlock();

#if false
gen.statement(ast->initializer);
auto conditionResult = gen.expression(ast->condition);
gen.condition(ast->condition, trueBlock, falseBlock);

gen.builder_.setInsertionPointToEnd(trueBlock);
gen.statement(ast->statement);
gen.branch(gen.getLocation(ast->statement->lastSourceLocation()), mergeBlock);
gen.builder_.setInsertionPointToEnd(falseBlock);
gen.statement(ast->elseStatement);
#endif
gen.branch(gen.getLocation(ast->elseStatement->lastSourceLocation()),
mergeBlock);
gen.builder_.setInsertionPointToEnd(mergeBlock);
}

void Codegen::StatementVisitor::operator()(ConstevalIfStatementAST* ast) {
Expand All @@ -147,21 +154,35 @@ void Codegen::StatementVisitor::operator()(SwitchStatementAST* ast) {
}

void Codegen::StatementVisitor::operator()(WhileStatementAST* ast) {
(void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind()));
auto beginLoopBlock = gen.newBlock();
auto bodyLoopBlock = gen.newBlock();
auto endLoopBlock = gen.newBlock();

#if false
auto conditionResult = gen.expression(ast->condition);
gen.branch(gen.getLocation(ast->condition->firstSourceLocation()),
beginLoopBlock);

gen.builder_.setInsertionPointToEnd(beginLoopBlock);
gen.condition(ast->condition, bodyLoopBlock, endLoopBlock);

gen.builder_.setInsertionPointToEnd(bodyLoopBlock);
gen.statement(ast->statement);
#endif

gen.branch(gen.getLocation(ast->statement->lastSourceLocation()),
beginLoopBlock);
gen.builder_.setInsertionPointToEnd(endLoopBlock);
}

void Codegen::StatementVisitor::operator()(DoStatementAST* ast) {
(void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind()));
auto loopBlock = gen.newBlock();
auto endLoopBlock = gen.newBlock();

#if false
gen.branch(gen.getLocation(ast->statement->firstSourceLocation()), loopBlock);

gen.builder_.setInsertionPointToEnd(loopBlock);
gen.statement(ast->statement);
auto expressionResult = gen.expression(ast->expression);
#endif
gen.condition(ast->expression, loopBlock, endLoopBlock);

gen.builder_.setInsertionPointToEnd(endLoopBlock);
}

void Codegen::StatementVisitor::operator()(ForRangeStatementAST* ast) {
Expand Down
37 changes: 28 additions & 9 deletions src/mlir/cxx/mlir/cxx_dialect_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
// mlir
#include <mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h>
#include <mlir/Conversion/LLVMCommon/TypeConverter.h>
#include <mlir/Dialect/ControlFlow/IR/ControlFlow.h>
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
Expand Down Expand Up @@ -440,6 +440,25 @@ class MulIOpLowering : public OpConversionPattern<cxx::MulIOp> {
}
};

class CondBranchOpLowering : public OpConversionPattern<cxx::CondBranchOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
op, adaptor.getCondition(), op.getTrueDest(),
adaptor.getTrueDestOperands(), op.getFalseDest(),
adaptor.getFalseDestOperands());

return success();
}
};

class CxxToLLVMLoweringPass
: public PassWrapper<CxxToLLVMLoweringPass, OperationPass<ModuleOp>> {
public:
Expand All @@ -461,7 +480,7 @@ void CxxToLLVMLoweringPass::runOnOperation() {
auto module = getOperation();

// set up the data layout
mlir::DataLayout dataLayout(module);
DataLayout dataLayout(module);

// set up the type converter
LLVMTypeConverter typeConverter{context};
Expand Down Expand Up @@ -526,13 +545,13 @@ void CxxToLLVMLoweringPass::runOnOperation() {
target.addIllegalDialect<cxx::CxxDialect>();

RewritePatternSet patterns(context);
patterns
.insert<FuncOpLowering, ReturnOpLowering, AllocaOpLowering,
LoadOpLowering, StoreOpLowering, BoolConstantOpLowering,
IntConstantOpLowering, FloatConstantOpLowering,
IntToBoolOpLowering, BoolToIntOpLowering, IntegralCastOpLowering,
NotOpLowering, AddIOpLowering, SubIOpLowering, MulIOpLowering>(
typeConverter, context);
patterns.insert<FuncOpLowering, ReturnOpLowering, AllocaOpLowering,
LoadOpLowering, StoreOpLowering, BoolConstantOpLowering,
IntConstantOpLowering, FloatConstantOpLowering,
IntToBoolOpLowering, BoolToIntOpLowering,
IntegralCastOpLowering, NotOpLowering, AddIOpLowering,
SubIOpLowering, MulIOpLowering, CondBranchOpLowering>(
typeConverter, context);

populateFunctionOpInterfaceTypeConversionPattern<cxx::FuncOp>(patterns,
typeConverter);
Expand Down
Loading