Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,22 @@ def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> {
let hasVerifier = 0;
}

def Cxx_AllocaOp : Cxx_Op<"alloca"> {
let arguments = (ins);

let results = (outs Cxx_PointerType:$result);
}

def Cxx_LoadOp : Cxx_Op<"load"> {
let arguments = (ins Cxx_PointerType:$addr);

let results = (outs AnyType:$result);
}

def Cxx_StoreOp : Cxx_Op<"store"> {
let arguments = (ins AnyType:$value, Cxx_PointerType:$addr);
}

def Cxx_TodoExprOp : Cxx_Op<"todo.expr"> {
let arguments = (ins StrAttr:$message);
let results = (outs Cxx_ExprType:$result);
Expand Down
2 changes: 2 additions & 0 deletions src/mlir/cxx/mlir/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ class Codegen {
mlir::ModuleOp module_;
mlir::cxx::FuncOp function_;
TranslationUnit* unit_ = nullptr;
mlir::Block* exitBlock_ = nullptr;
mlir::cxx::AllocaOp exitValue_;
int count_ = 0;
};

Expand Down
82 changes: 54 additions & 28 deletions src/mlir/cxx/mlir/codegen_declarations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
#include <cxx/symbols.h>
#include <cxx/types.h>

// mlir
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
#include <mlir/IR/Block.h>

#include <format>

namespace cxx {
Expand Down Expand Up @@ -267,10 +271,9 @@ auto Codegen::DeclarationVisitor::operator()(OpaqueEnumDeclarationAST* ast)
auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
-> DeclarationResult {
auto functionSymbol = ast->symbol;
auto functionType = type_cast<FunctionType>(functionSymbol->type());
auto returnType = functionType->returnType();

auto exprType = gen.builder_.getType<mlir::cxx::ExprType>();
const auto functionType = type_cast<FunctionType>(functionSymbol->type());
const auto returnType = functionType->returnType();
const auto needsExitValue = !gen.control()->is_void(returnType);

std::vector<mlir::Type> inputTypes;
std::vector<mlir::Type> resultTypes;
Expand All @@ -279,12 +282,11 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
inputTypes.push_back(gen.convertType(paramTy));
}

if (!gen.control()->is_void(functionType->returnType())) {
if (needsExitValue) {
resultTypes.push_back(gen.convertType(returnType));
}

auto funcType = gen.builder_.getFunctionType(inputTypes, resultTypes);
auto loc = gen.builder_.getUnknownLoc();

std::vector<std::string> path;
for (Symbol* symbol = ast->symbol; symbol;
Expand All @@ -309,46 +311,70 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
name += std::format("_{}", ++gen.count_);
}

auto savedInsertionPoint = gen.builder_.saveInsertionPoint();
const auto savedInsertionPoint = gen.builder_.saveInsertionPoint();

auto func = gen.builder_.create<mlir::cxx::FuncOp>(loc, name, funcType);
const auto loc = gen.getLocation(ast->symbol->location());

auto func = gen.builder_.create<mlir::cxx::FuncOp>(loc, name, funcType);
auto entryBlock = &func.front();
auto exitBlock = gen.builder_.createBlock(&func.getBody());
mlir::cxx::AllocaOp exitValue;

// set the insertion point to the entry block
gen.builder_.setInsertionPointToEnd(entryBlock);

std::swap(gen.function_, func);
if (needsExitValue) {
auto exitValueLoc =
gen.getLocation(ast->functionBody->firstSourceLocation());
auto exitValueType = gen.convertType(returnType);
auto ptrType = gen.builder_.getType<mlir::cxx::PointerType>(exitValueType);
exitValue = gen.builder_.create<mlir::cxx::AllocaOp>(exitValueLoc, ptrType);

#if false
for (auto node : ListView{ast->attributeList}) {
auto value = gen(node);
exitBlock->addArgument(ptrType, exitValueLoc);
}

for (auto node : ListView{ast->declSpecifierList}) {
auto value = gen(node);
}

auto declaratorResult = gen(ast->declarator);
auto requiresClauseResult = gen(ast->requiresClause);
#endif
// restore state
std::swap(gen.function_, func);
std::swap(gen.exitBlock_, exitBlock);
std::swap(gen.exitValue_, exitValue);

// generate code for the function body
auto functionBodyResult = gen(ast->functionBody);

std::swap(gen.function_, func);
// terminate the function body

auto endLoc = gen.getLocation(ast->lastSourceLocation());
const auto endLoc = gen.getLocation(ast->lastSourceLocation());

if (gen.control()->is_void(returnType)) {
// If the function returns void, we don't need to return anything.
gen.builder_.create<mlir::cxx::ReturnOp>(endLoc);
if (needsExitValue) {
// We need to return a value of the correct type.

llvm::SmallVector<mlir::Value> exitBlockArgs;
exitBlockArgs.push_back(gen.exitValue_.getResult());

gen.builder_.create<mlir::cf::BranchOp>(endLoc, gen.exitBlock_,
exitBlockArgs);

gen.builder_.setInsertionPointToEnd(gen.exitBlock_);

auto elementType = gen.exitValue_.getType().getElementType();

auto value = gen.builder_.create<mlir::cxx::LoadOp>(endLoc, elementType,
gen.exitValue_);

gen.builder_.create<mlir::cxx::ReturnOp>(endLoc, value->getResults());
} else {
// Otherwise, we need to return a value of the correct type.
auto r = gen.emitTodoExpr(ast->lastSourceLocation(), "result value");
// If the function returns void, we don't need to return anything.

auto result =
gen.builder_.create<mlir::cxx::ReturnOp>(endLoc, r->getResults());
gen.builder_.create<mlir::cf::BranchOp>(endLoc, gen.exitBlock_);
gen.builder_.setInsertionPointToEnd(gen.exitBlock_);
gen.builder_.create<mlir::cxx::ReturnOp>(endLoc);
}

// restore the state
std::swap(gen.function_, func);
std::swap(gen.exitBlock_, exitBlock);
std::swap(gen.exitValue_, exitValue);

gen.builder_.restoreInsertionPoint(savedInsertionPoint);

return {};
Expand Down