diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 64f181d2..048d1208 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -212,6 +212,12 @@ def Cxx_MulIOp : Cxx_Op<"muli"> { let results = (outs Cxx_IntegerType:$result); } +def CondBranchOp : Cxx_Op<"cond_br", [ AttrSizedOperandSegments, Terminator ]> { + let arguments = (ins Cxx_BoolType:$condition, Variadic:$trueDestOperands, Variadic:$falseDestOperands); + + let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); +} + // // todo ops // diff --git a/src/mlir/cxx/mlir/codegen.cc b/src/mlir/cxx/mlir/codegen.cc index bbc8fd6d..6d8993c1 100644 --- a/src/mlir/cxx/mlir/codegen.cc +++ b/src/mlir/cxx/mlir/codegen.cc @@ -25,6 +25,9 @@ #include #include +// mlir +#include + #include namespace cxx { @@ -42,6 +45,19 @@ auto Codegen::currentBlockMightHaveTerminator() -> bool { return block->mightHaveTerminator(); } +auto Codegen::newBlock() -> mlir::Block* { + auto region = builder_.getBlock()->getParent(); + auto newBlock = new mlir::Block(); + region->getBlocks().push_back(newBlock); + return newBlock; +} + +void Codegen::branch(mlir::Location loc, mlir::Block* block, + mlir::ValueRange operands) { + if (currentBlockMightHaveTerminator()) return; + builder_.create(loc, block, operands); +} + auto Codegen::findOrCreateLocal(Symbol* symbol) -> std::optional { auto var = symbol_cast(symbol); if (!var) return std::nullopt; diff --git a/src/mlir/cxx/mlir/codegen.h b/src/mlir/cxx/mlir/codegen.h index ef0feafb..02aa8e5b 100644 --- a/src/mlir/cxx/mlir/codegen.h +++ b/src/mlir/cxx/mlir/codegen.h @@ -121,6 +121,9 @@ class Codegen { ExpressionAST* ast, ExpressionFormat format = ExpressionFormat::kValue) -> ExpressionResult; + void condition(ExpressionAST* ast, mlir::Block* trueBlock, + mlir::Block* falseBlock); + [[nodiscard]] auto templateParameter(TemplateParameterAST* ast) -> TemplateParameterResult; @@ -256,6 +259,10 @@ class Codegen { [[nodiscard]] auto findOrCreateLocal(Symbol* symbol) -> std::optional; + [[nodiscard]] auto newBlock() -> mlir::Block*; + void branch(mlir::Location loc, mlir::Block* block, + mlir::ValueRange operands = {}); + struct UnitVisitor; struct DeclarationVisitor; struct StatementVisitor; diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index 5834bd86..ca92b005 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -29,6 +29,9 @@ #include #include +// mlir +#include + namespace cxx { struct Codegen::ExpressionVisitor { @@ -120,6 +123,40 @@ auto Codegen::expression(ExpressionAST* ast, ExpressionFormat format) return {}; } +void Codegen::condition(ExpressionAST* ast, mlir::Block* trueBlock, + mlir::Block* falseBlock) { + if (!ast) return; + + if (auto nested = ast_cast(ast)) { + condition(nested->expression, trueBlock, falseBlock); + return; + } + + if (auto binop = ast_cast(ast)) { + if (binop->op == TokenKind::T_AMP_AMP) { + auto nextBlock = newBlock(); + condition(binop->leftExpression, nextBlock, falseBlock); + builder_.setInsertionPointToEnd(nextBlock); + condition(binop->rightExpression, trueBlock, falseBlock); + return; + } + + if (binop->op == TokenKind::T_BAR_BAR) { + auto nextBlock = newBlock(); + condition(binop->leftExpression, trueBlock, nextBlock); + builder_.setInsertionPointToEnd(nextBlock); + condition(binop->rightExpression, trueBlock, falseBlock); + return; + } + } + + const auto loc = getLocation(ast->firstSourceLocation()); + auto value = expression(ast); + builder_.create(loc, value.value, mlir::ValueRange{}, + mlir::ValueRange{}, trueBlock, + falseBlock); +} + auto Codegen::newInitializer(NewInitializerAST* ast) -> NewInitializerResult { if (ast) return visit(NewInitializerVisitor{*this}, ast); return {}; diff --git a/src/mlir/cxx/mlir/codegen_statements.cc b/src/mlir/cxx/mlir/codegen_statements.cc index 79944777..d3a33713 100644 --- a/src/mlir/cxx/mlir/codegen_statements.cc +++ b/src/mlir/cxx/mlir/codegen_statements.cc @@ -117,14 +117,21 @@ void Codegen::StatementVisitor::operator()(CompoundStatementAST* ast) { } void Codegen::StatementVisitor::operator()(IfStatementAST* ast) { - (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); + auto trueBlock = gen.newBlock(); + auto falseBlock = gen.newBlock(); + auto mergeBlock = gen.newBlock(); -#if false gen.statement(ast->initializer); - auto conditionResult = gen.expression(ast->condition); + gen.condition(ast->condition, trueBlock, falseBlock); + + gen.builder_.setInsertionPointToEnd(trueBlock); gen.statement(ast->statement); + gen.branch(gen.getLocation(ast->statement->lastSourceLocation()), mergeBlock); + gen.builder_.setInsertionPointToEnd(falseBlock); gen.statement(ast->elseStatement); -#endif + gen.branch(gen.getLocation(ast->elseStatement->lastSourceLocation()), + mergeBlock); + gen.builder_.setInsertionPointToEnd(mergeBlock); } void Codegen::StatementVisitor::operator()(ConstevalIfStatementAST* ast) { @@ -147,21 +154,35 @@ void Codegen::StatementVisitor::operator()(SwitchStatementAST* ast) { } void Codegen::StatementVisitor::operator()(WhileStatementAST* ast) { - (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); + auto beginLoopBlock = gen.newBlock(); + auto bodyLoopBlock = gen.newBlock(); + auto endLoopBlock = gen.newBlock(); -#if false - auto conditionResult = gen.expression(ast->condition); + gen.branch(gen.getLocation(ast->condition->firstSourceLocation()), + beginLoopBlock); + + gen.builder_.setInsertionPointToEnd(beginLoopBlock); + gen.condition(ast->condition, bodyLoopBlock, endLoopBlock); + + gen.builder_.setInsertionPointToEnd(bodyLoopBlock); gen.statement(ast->statement); -#endif + + gen.branch(gen.getLocation(ast->statement->lastSourceLocation()), + beginLoopBlock); + gen.builder_.setInsertionPointToEnd(endLoopBlock); } void Codegen::StatementVisitor::operator()(DoStatementAST* ast) { - (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); + auto loopBlock = gen.newBlock(); + auto endLoopBlock = gen.newBlock(); -#if false + gen.branch(gen.getLocation(ast->statement->firstSourceLocation()), loopBlock); + + gen.builder_.setInsertionPointToEnd(loopBlock); gen.statement(ast->statement); - auto expressionResult = gen.expression(ast->expression); -#endif + gen.condition(ast->expression, loopBlock, endLoopBlock); + + gen.builder_.setInsertionPointToEnd(endLoopBlock); } void Codegen::StatementVisitor::operator()(ForRangeStatementAST* ast) { diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index 39fa7fba..ca21c7cf 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -26,7 +26,7 @@ // mlir #include #include -#include +#include #include #include #include @@ -440,6 +440,25 @@ class MulIOpLowering : public OpConversionPattern { } }; +class CondBranchOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::CondBranchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), op.getTrueDest(), + adaptor.getTrueDestOperands(), op.getFalseDest(), + adaptor.getFalseDestOperands()); + + return success(); + } +}; + class CxxToLLVMLoweringPass : public PassWrapper> { public: @@ -461,7 +480,7 @@ void CxxToLLVMLoweringPass::runOnOperation() { auto module = getOperation(); // set up the data layout - mlir::DataLayout dataLayout(module); + DataLayout dataLayout(module); // set up the type converter LLVMTypeConverter typeConverter{context}; @@ -526,13 +545,13 @@ void CxxToLLVMLoweringPass::runOnOperation() { target.addIllegalDialect(); RewritePatternSet patterns(context); - patterns - .insert( - typeConverter, context); + patterns.insert( + typeConverter, context); populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter);