diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 8dbbbce5..50162bdf 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -99,6 +99,22 @@ def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> { let hasVerifier = 0; } +def Cxx_AllocaOp : Cxx_Op<"alloca"> { + let arguments = (ins); + + let results = (outs Cxx_PointerType:$result); +} + +def Cxx_LoadOp : Cxx_Op<"load"> { + let arguments = (ins Cxx_PointerType:$addr); + + let results = (outs AnyType:$result); +} + +def Cxx_StoreOp : Cxx_Op<"store"> { + let arguments = (ins AnyType:$value, Cxx_PointerType:$addr); +} + def Cxx_TodoExprOp : Cxx_Op<"todo.expr"> { let arguments = (ins StrAttr:$message); let results = (outs Cxx_ExprType:$result); diff --git a/src/mlir/cxx/mlir/codegen.h b/src/mlir/cxx/mlir/codegen.h index 4f9cc141..25e1a5c9 100644 --- a/src/mlir/cxx/mlir/codegen.h +++ b/src/mlir/cxx/mlir/codegen.h @@ -203,6 +203,8 @@ class Codegen { mlir::ModuleOp module_; mlir::cxx::FuncOp function_; TranslationUnit* unit_ = nullptr; + mlir::Block* exitBlock_ = nullptr; + mlir::cxx::AllocaOp exitValue_; int count_ = 0; }; diff --git a/src/mlir/cxx/mlir/codegen_declarations.cc b/src/mlir/cxx/mlir/codegen_declarations.cc index 25c5c634..7f550844 100644 --- a/src/mlir/cxx/mlir/codegen_declarations.cc +++ b/src/mlir/cxx/mlir/codegen_declarations.cc @@ -26,6 +26,10 @@ #include #include +// mlir +#include +#include + #include namespace cxx { @@ -267,10 +271,9 @@ auto Codegen::DeclarationVisitor::operator()(OpaqueEnumDeclarationAST* ast) auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast) -> DeclarationResult { auto functionSymbol = ast->symbol; - auto functionType = type_cast(functionSymbol->type()); - auto returnType = functionType->returnType(); - - auto exprType = gen.builder_.getType(); + const auto functionType = type_cast(functionSymbol->type()); + const auto returnType = functionType->returnType(); + const auto needsExitValue = !gen.control()->is_void(returnType); std::vector inputTypes; std::vector resultTypes; @@ -279,12 +282,11 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast) inputTypes.push_back(gen.convertType(paramTy)); } - if (!gen.control()->is_void(functionType->returnType())) { + if (needsExitValue) { resultTypes.push_back(gen.convertType(returnType)); } auto funcType = gen.builder_.getFunctionType(inputTypes, resultTypes); - auto loc = gen.builder_.getUnknownLoc(); std::vector path; for (Symbol* symbol = ast->symbol; symbol; @@ -309,46 +311,70 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast) name += std::format("_{}", ++gen.count_); } - auto savedInsertionPoint = gen.builder_.saveInsertionPoint(); + const auto savedInsertionPoint = gen.builder_.saveInsertionPoint(); - auto func = gen.builder_.create(loc, name, funcType); + const auto loc = gen.getLocation(ast->symbol->location()); + auto func = gen.builder_.create(loc, name, funcType); auto entryBlock = &func.front(); + auto exitBlock = gen.builder_.createBlock(&func.getBody()); + mlir::cxx::AllocaOp exitValue; + // set the insertion point to the entry block gen.builder_.setInsertionPointToEnd(entryBlock); - std::swap(gen.function_, func); + if (needsExitValue) { + auto exitValueLoc = + gen.getLocation(ast->functionBody->firstSourceLocation()); + auto exitValueType = gen.convertType(returnType); + auto ptrType = gen.builder_.getType(exitValueType); + exitValue = gen.builder_.create(exitValueLoc, ptrType); -#if false - for (auto node : ListView{ast->attributeList}) { - auto value = gen(node); + exitBlock->addArgument(ptrType, exitValueLoc); } - for (auto node : ListView{ast->declSpecifierList}) { - auto value = gen(node); - } - - auto declaratorResult = gen(ast->declarator); - auto requiresClauseResult = gen(ast->requiresClause); -#endif + // restore state + std::swap(gen.function_, func); + std::swap(gen.exitBlock_, exitBlock); + std::swap(gen.exitValue_, exitValue); + // generate code for the function body auto functionBodyResult = gen(ast->functionBody); - std::swap(gen.function_, func); + // terminate the function body - auto endLoc = gen.getLocation(ast->lastSourceLocation()); + const auto endLoc = gen.getLocation(ast->lastSourceLocation()); - if (gen.control()->is_void(returnType)) { - // If the function returns void, we don't need to return anything. - gen.builder_.create(endLoc); + if (needsExitValue) { + // We need to return a value of the correct type. + + llvm::SmallVector exitBlockArgs; + exitBlockArgs.push_back(gen.exitValue_.getResult()); + + gen.builder_.create(endLoc, gen.exitBlock_, + exitBlockArgs); + + gen.builder_.setInsertionPointToEnd(gen.exitBlock_); + + auto elementType = gen.exitValue_.getType().getElementType(); + + auto value = gen.builder_.create(endLoc, elementType, + gen.exitValue_); + + gen.builder_.create(endLoc, value->getResults()); } else { - // Otherwise, we need to return a value of the correct type. - auto r = gen.emitTodoExpr(ast->lastSourceLocation(), "result value"); + // If the function returns void, we don't need to return anything. - auto result = - gen.builder_.create(endLoc, r->getResults()); + gen.builder_.create(endLoc, gen.exitBlock_); + gen.builder_.setInsertionPointToEnd(gen.exitBlock_); + gen.builder_.create(endLoc); } + // restore the state + std::swap(gen.function_, func); + std::swap(gen.exitBlock_, exitBlock); + std::swap(gen.exitValue_, exitValue); + gen.builder_.restoreInsertionPoint(savedInsertionPoint); return {};