diff --git a/src/mlir/cxx/mlir/codegen.h b/src/mlir/cxx/mlir/codegen.h index 02aa8e5b..502372a4 100644 --- a/src/mlir/cxx/mlir/codegen.h +++ b/src/mlir/cxx/mlir/codegen.h @@ -263,6 +263,16 @@ class Codegen { void branch(mlir::Location loc, mlir::Block* block, mlir::ValueRange operands = {}); + struct Loop { + mlir::Block* continueBlock = nullptr; + mlir::Block* breakBlock = nullptr; + + Loop() = default; + + Loop(mlir::Block* continueBlock, mlir::Block* breakBlock) + : continueBlock(continueBlock), breakBlock(breakBlock) {} + }; + struct UnitVisitor; struct DeclarationVisitor; struct StatementVisitor; @@ -295,6 +305,7 @@ class Codegen { mlir::cxx::AllocaOp exitValue_; std::unordered_map classNames_; std::unordered_map locals_; + Loop loop_; int count_ = 0; }; diff --git a/src/mlir/cxx/mlir/codegen_statements.cc b/src/mlir/cxx/mlir/codegen_statements.cc index d3a33713..fb765315 100644 --- a/src/mlir/cxx/mlir/codegen_statements.cc +++ b/src/mlir/cxx/mlir/codegen_statements.cc @@ -129,7 +129,9 @@ void Codegen::StatementVisitor::operator()(IfStatementAST* ast) { gen.branch(gen.getLocation(ast->statement->lastSourceLocation()), mergeBlock); gen.builder_.setInsertionPointToEnd(falseBlock); gen.statement(ast->elseStatement); - gen.branch(gen.getLocation(ast->elseStatement->lastSourceLocation()), + gen.branch(gen.getLocation(ast->elseStatement + ? ast->elseStatement->lastSourceLocation() + : ast->elseLoc), mergeBlock); gen.builder_.setInsertionPointToEnd(mergeBlock); } @@ -158,6 +160,10 @@ void Codegen::StatementVisitor::operator()(WhileStatementAST* ast) { auto bodyLoopBlock = gen.newBlock(); auto endLoopBlock = gen.newBlock(); + Loop loop{beginLoopBlock, endLoopBlock}; + + std::swap(gen.loop_, loop); + gen.branch(gen.getLocation(ast->condition->firstSourceLocation()), beginLoopBlock); @@ -170,12 +176,17 @@ void Codegen::StatementVisitor::operator()(WhileStatementAST* ast) { gen.branch(gen.getLocation(ast->statement->lastSourceLocation()), beginLoopBlock); gen.builder_.setInsertionPointToEnd(endLoopBlock); + + std::swap(gen.loop_, loop); } void Codegen::StatementVisitor::operator()(DoStatementAST* ast) { auto loopBlock = gen.newBlock(); auto endLoopBlock = gen.newBlock(); + Loop loop{loopBlock, endLoopBlock}; + std::swap(gen.loop_, loop); + gen.branch(gen.getLocation(ast->statement->firstSourceLocation()), loopBlock); gen.builder_.setInsertionPointToEnd(loopBlock); @@ -183,6 +194,8 @@ void Codegen::StatementVisitor::operator()(DoStatementAST* ast) { gen.condition(ast->expression, loopBlock, endLoopBlock); gen.builder_.setInsertionPointToEnd(endLoopBlock); + + std::swap(gen.loop_, loop); } void Codegen::StatementVisitor::operator()(ForRangeStatementAST* ast) { @@ -197,21 +210,61 @@ void Codegen::StatementVisitor::operator()(ForRangeStatementAST* ast) { } void Codegen::StatementVisitor::operator()(ForStatementAST* ast) { - (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); - -#if false gen.statement(ast->initializer); - auto conditionResult = gen.expression(ast->condition); - auto expressionResult = gen.expression(ast->expression); + + auto beginLoopBlock = gen.newBlock(); + auto loopBodyBlock = gen.newBlock(); + auto stepLoopBlock = gen.newBlock(); + auto endLoopBlock = gen.newBlock(); + + Loop loop{stepLoopBlock, endLoopBlock}; + std::swap(gen.loop_, loop); + + gen.branch( + gen.getLocation(ast->condition ? ast->condition->firstSourceLocation() + : ast->semicolonLoc), + beginLoopBlock); + + gen.builder_.setInsertionPointToEnd(beginLoopBlock); + gen.condition(ast->condition, loopBodyBlock, endLoopBlock); + + gen.builder_.setInsertionPointToEnd(loopBodyBlock); gen.statement(ast->statement); -#endif + + gen.branch(gen.getLocation(ast->statement->lastSourceLocation()), + stepLoopBlock); + + gen.builder_.setInsertionPointToEnd(stepLoopBlock); + + (void)gen.expression(ast->expression, ExpressionFormat::kSideEffect); + + gen.branch( + gen.getLocation(ast->expression ? ast->expression->lastSourceLocation() + : ast->rparenLoc), + beginLoopBlock); + + gen.builder_.setInsertionPointToEnd(endLoopBlock); + + std::swap(gen.loop_, loop); } void Codegen::StatementVisitor::operator()(BreakStatementAST* ast) { + if (auto target = gen.loop_.breakBlock) { + gen.builder_.create( + gen.getLocation(ast->firstSourceLocation()), target); + return; + } + (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); } void Codegen::StatementVisitor::operator()(ContinueStatementAST* ast) { + if (auto target = gen.loop_.continueBlock) { + gen.builder_.create( + gen.getLocation(ast->firstSourceLocation()), target); + return; + } + (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); }