From 2223cf1f4c525563754aa3291663dcca23ab108a Mon Sep 17 00:00:00 2001 From: Roberto Raggi Date: Sun, 3 Aug 2025 16:01:11 +0200 Subject: [PATCH] Add unstructured conditional branches In this change we convert if, while and do statements using unstructured control flow. Obviously, the plan is to produce structured control flow but we will do that later. Signed-off-by: Roberto Raggi --- src/mlir/cxx/mlir/CxxOps.td | 6 +++ src/mlir/cxx/mlir/codegen.cc | 16 +++++++ src/mlir/cxx/mlir/codegen.h | 7 +++ src/mlir/cxx/mlir/codegen_expressions.cc | 37 ++++++++++++++++ src/mlir/cxx/mlir/codegen_statements.cc | 45 ++++++++++++++------ src/mlir/cxx/mlir/cxx_dialect_conversions.cc | 37 ++++++++++++---- 6 files changed, 127 insertions(+), 21 deletions(-) diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 64f181d2..048d1208 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -212,6 +212,12 @@ def Cxx_MulIOp : Cxx_Op<"muli"> { let results = (outs Cxx_IntegerType:$result); } +def CondBranchOp : Cxx_Op<"cond_br", [ AttrSizedOperandSegments, Terminator ]> { + let arguments = (ins Cxx_BoolType:$condition, Variadic:$trueDestOperands, Variadic:$falseDestOperands); + + let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); +} + // // todo ops // diff --git a/src/mlir/cxx/mlir/codegen.cc b/src/mlir/cxx/mlir/codegen.cc index bbc8fd6d..6d8993c1 100644 --- a/src/mlir/cxx/mlir/codegen.cc +++ b/src/mlir/cxx/mlir/codegen.cc @@ -25,6 +25,9 @@ #include #include +// mlir +#include + #include namespace cxx { @@ -42,6 +45,19 @@ auto Codegen::currentBlockMightHaveTerminator() -> bool { return block->mightHaveTerminator(); } +auto Codegen::newBlock() -> mlir::Block* { + auto region = builder_.getBlock()->getParent(); + auto newBlock = new mlir::Block(); + region->getBlocks().push_back(newBlock); + return newBlock; +} + +void Codegen::branch(mlir::Location loc, mlir::Block* block, + mlir::ValueRange operands) { + if (currentBlockMightHaveTerminator()) return; + builder_.create(loc, block, operands); +} + auto Codegen::findOrCreateLocal(Symbol* symbol) -> std::optional { auto var = symbol_cast(symbol); if (!var) return std::nullopt; diff --git a/src/mlir/cxx/mlir/codegen.h b/src/mlir/cxx/mlir/codegen.h index ef0feafb..02aa8e5b 100644 --- a/src/mlir/cxx/mlir/codegen.h +++ b/src/mlir/cxx/mlir/codegen.h @@ -121,6 +121,9 @@ class Codegen { ExpressionAST* ast, ExpressionFormat format = ExpressionFormat::kValue) -> ExpressionResult; + void condition(ExpressionAST* ast, mlir::Block* trueBlock, + mlir::Block* falseBlock); + [[nodiscard]] auto templateParameter(TemplateParameterAST* ast) -> TemplateParameterResult; @@ -256,6 +259,10 @@ class Codegen { [[nodiscard]] auto findOrCreateLocal(Symbol* symbol) -> std::optional; + [[nodiscard]] auto newBlock() -> mlir::Block*; + void branch(mlir::Location loc, mlir::Block* block, + mlir::ValueRange operands = {}); + struct UnitVisitor; struct DeclarationVisitor; struct StatementVisitor; diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index 5834bd86..ca92b005 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -29,6 +29,9 @@ #include #include +// mlir +#include + namespace cxx { struct Codegen::ExpressionVisitor { @@ -120,6 +123,40 @@ auto Codegen::expression(ExpressionAST* ast, ExpressionFormat format) return {}; } +void Codegen::condition(ExpressionAST* ast, mlir::Block* trueBlock, + mlir::Block* falseBlock) { + if (!ast) return; + + if (auto nested = ast_cast(ast)) { + condition(nested->expression, trueBlock, falseBlock); + return; + } + + if (auto binop = ast_cast(ast)) { + if (binop->op == TokenKind::T_AMP_AMP) { + auto nextBlock = newBlock(); + condition(binop->leftExpression, nextBlock, falseBlock); + builder_.setInsertionPointToEnd(nextBlock); + condition(binop->rightExpression, trueBlock, falseBlock); + return; + } + + if (binop->op == TokenKind::T_BAR_BAR) { + auto nextBlock = newBlock(); + condition(binop->leftExpression, trueBlock, nextBlock); + builder_.setInsertionPointToEnd(nextBlock); + condition(binop->rightExpression, trueBlock, falseBlock); + return; + } + } + + const auto loc = getLocation(ast->firstSourceLocation()); + auto value = expression(ast); + builder_.create(loc, value.value, mlir::ValueRange{}, + mlir::ValueRange{}, trueBlock, + falseBlock); +} + auto Codegen::newInitializer(NewInitializerAST* ast) -> NewInitializerResult { if (ast) return visit(NewInitializerVisitor{*this}, ast); return {}; diff --git a/src/mlir/cxx/mlir/codegen_statements.cc b/src/mlir/cxx/mlir/codegen_statements.cc index 79944777..d3a33713 100644 --- a/src/mlir/cxx/mlir/codegen_statements.cc +++ b/src/mlir/cxx/mlir/codegen_statements.cc @@ -117,14 +117,21 @@ void Codegen::StatementVisitor::operator()(CompoundStatementAST* ast) { } void Codegen::StatementVisitor::operator()(IfStatementAST* ast) { - (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); + auto trueBlock = gen.newBlock(); + auto falseBlock = gen.newBlock(); + auto mergeBlock = gen.newBlock(); -#if false gen.statement(ast->initializer); - auto conditionResult = gen.expression(ast->condition); + gen.condition(ast->condition, trueBlock, falseBlock); + + gen.builder_.setInsertionPointToEnd(trueBlock); gen.statement(ast->statement); + gen.branch(gen.getLocation(ast->statement->lastSourceLocation()), mergeBlock); + gen.builder_.setInsertionPointToEnd(falseBlock); gen.statement(ast->elseStatement); -#endif + gen.branch(gen.getLocation(ast->elseStatement->lastSourceLocation()), + mergeBlock); + gen.builder_.setInsertionPointToEnd(mergeBlock); } void Codegen::StatementVisitor::operator()(ConstevalIfStatementAST* ast) { @@ -147,21 +154,35 @@ void Codegen::StatementVisitor::operator()(SwitchStatementAST* ast) { } void Codegen::StatementVisitor::operator()(WhileStatementAST* ast) { - (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); + auto beginLoopBlock = gen.newBlock(); + auto bodyLoopBlock = gen.newBlock(); + auto endLoopBlock = gen.newBlock(); -#if false - auto conditionResult = gen.expression(ast->condition); + gen.branch(gen.getLocation(ast->condition->firstSourceLocation()), + beginLoopBlock); + + gen.builder_.setInsertionPointToEnd(beginLoopBlock); + gen.condition(ast->condition, bodyLoopBlock, endLoopBlock); + + gen.builder_.setInsertionPointToEnd(bodyLoopBlock); gen.statement(ast->statement); -#endif + + gen.branch(gen.getLocation(ast->statement->lastSourceLocation()), + beginLoopBlock); + gen.builder_.setInsertionPointToEnd(endLoopBlock); } void Codegen::StatementVisitor::operator()(DoStatementAST* ast) { - (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); + auto loopBlock = gen.newBlock(); + auto endLoopBlock = gen.newBlock(); -#if false + gen.branch(gen.getLocation(ast->statement->firstSourceLocation()), loopBlock); + + gen.builder_.setInsertionPointToEnd(loopBlock); gen.statement(ast->statement); - auto expressionResult = gen.expression(ast->expression); -#endif + gen.condition(ast->expression, loopBlock, endLoopBlock); + + gen.builder_.setInsertionPointToEnd(endLoopBlock); } void Codegen::StatementVisitor::operator()(ForRangeStatementAST* ast) { diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index 39fa7fba..ca21c7cf 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -26,7 +26,7 @@ // mlir #include #include -#include +#include #include #include #include @@ -440,6 +440,25 @@ class MulIOpLowering : public OpConversionPattern { } }; +class CondBranchOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::CondBranchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), op.getTrueDest(), + adaptor.getTrueDestOperands(), op.getFalseDest(), + adaptor.getFalseDestOperands()); + + return success(); + } +}; + class CxxToLLVMLoweringPass : public PassWrapper> { public: @@ -461,7 +480,7 @@ void CxxToLLVMLoweringPass::runOnOperation() { auto module = getOperation(); // set up the data layout - mlir::DataLayout dataLayout(module); + DataLayout dataLayout(module); // set up the type converter LLVMTypeConverter typeConverter{context}; @@ -526,13 +545,13 @@ void CxxToLLVMLoweringPass::runOnOperation() { target.addIllegalDialect(); RewritePatternSet patterns(context); - patterns - .insert( - typeConverter, context); + patterns.insert( + typeConverter, context); populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter);