Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,38 @@ def Cxx_LoadOp : Cxx_Op<"load"> {

def Cxx_StoreOp : Cxx_Op<"store"> {
let arguments = (ins AnyType:$value, Cxx_PointerType:$addr);

let hasVerifier = 1;
}

def Cxx_BoolConstantOp : Cxx_Op<"constant.bool", [
Pure
]> {
let arguments = (ins BoolAttr:$value);

let results = (outs Cxx_BoolType:$result);
}

def Cxx_IntConstantOp : Cxx_Op<"constant.int", [
Pure
]> {
let arguments = (ins I64Attr:$value);

let results = (outs Cxx_IntegerType:$result);
}

def Cxx_FloatConstantOp : Cxx_Op<"constant.float", [
Pure
]> {
let arguments = (ins F64Attr:$value);

let results = (outs Cxx_FloatType:$result);
}

//
// todo ops
//

def Cxx_TodoExprOp : Cxx_Op<"todo.expr"> {
let arguments = (ins StrAttr:$message);
let results = (outs Cxx_ExprType:$result);
Expand Down
6 changes: 6 additions & 0 deletions src/mlir/cxx/mlir/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ Codegen::~Codegen() {}

auto Codegen::control() const -> Control* { return unit_->control(); }

auto Codegen::currentBlockMightHaveTerminator() -> bool {
auto block = builder_.getInsertionBlock();
if (!block) return true;
return block->mightHaveTerminator();
}

auto Codegen::getLocation(SourceLocation location) -> mlir::Location {
auto [filename, line, column] = unit_->tokenStartPosition(location);

Expand Down
2 changes: 2 additions & 0 deletions src/mlir/cxx/mlir/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ class Codegen {

[[nodiscard]] auto convertType(const Type* type) -> mlir::Type;

[[nodiscard]] auto currentBlockMightHaveTerminator() -> bool;

struct UnitVisitor;
struct DeclarationVisitor;
struct StatementVisitor;
Expand Down
21 changes: 11 additions & 10 deletions src/mlir/cxx/mlir/codegen_declarations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,17 +353,21 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)

const auto endLoc = gen.getLocation(ast->lastSourceLocation());

if (needsExitValue) {
// We need to return a value of the correct type.

if (!gen.builder_.getBlock()->mightHaveTerminator()) {
llvm::SmallVector<mlir::Value> exitBlockArgs;
exitBlockArgs.push_back(gen.exitValue_.getResult());

gen.builder_.create<mlir::cf::BranchOp>(endLoc, gen.exitBlock_,
exitBlockArgs);
if (gen.exitValue_) {
exitBlockArgs.push_back(gen.exitValue_.getResult());
}

gen.builder_.setInsertionPointToEnd(gen.exitBlock_);
gen.builder_.create<mlir::cf::BranchOp>(endLoc, exitBlockArgs,
gen.exitBlock_);
}

gen.builder_.setInsertionPointToEnd(gen.exitBlock_);

if (gen.exitValue_) {
// We need to return a value of the correct type.
auto elementType = gen.exitValue_.getType().getElementType();

auto value = gen.builder_.create<mlir::cxx::LoadOp>(endLoc, elementType,
Expand All @@ -372,9 +376,6 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
gen.builder_.create<mlir::cxx::ReturnOp>(endLoc, value->getResults());
} else {
// If the function returns void, we don't need to return anything.

gen.builder_.create<mlir::cf::BranchOp>(endLoc, gen.exitBlock_);
gen.builder_.setInsertionPointToEnd(gen.exitBlock_);
gen.builder_.create<mlir::cxx::ReturnOp>(endLoc);
}

Expand Down
39 changes: 26 additions & 13 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

// cxx
#include <cxx/ast.h>
#include <cxx/literals.h>

namespace cxx {

Expand Down Expand Up @@ -129,37 +130,49 @@ auto Codegen::ExpressionVisitor::operator()(GeneratedLiteralExpressionAST* ast)

auto Codegen::ExpressionVisitor::operator()(CharLiteralExpressionAST* ast)
-> ExpressionResult {
auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));
auto loc = gen.getLocation(ast->literalLoc);

auto type = gen.convertType(ast->type);
auto value = gen.builder_.getI64IntegerAttr(ast->literal->charValue());

auto op = gen.builder_.create<mlir::cxx::IntConstantOp>(loc, type, value);

return {op};
}

auto Codegen::ExpressionVisitor::operator()(BoolLiteralExpressionAST* ast)
-> ExpressionResult {
auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));
auto loc = gen.getLocation(ast->literalLoc);

auto type = gen.convertType(ast->type);
auto value = gen.builder_.getBoolAttr(ast->isTrue);

auto op = gen.builder_.create<mlir::cxx::BoolConstantOp>(loc, type, value);

return {op};
}

auto Codegen::ExpressionVisitor::operator()(IntLiteralExpressionAST* ast)
-> ExpressionResult {
auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));

#if false
auto loc = gen.getLocation(ast->literalLoc);

auto op = gen.builder_.create<mlir::cxx::IntLiteralOp>(
loc, ast->literal->integerValue());
#endif
auto type = gen.convertType(ast->type);
auto value = gen.builder_.getI64IntegerAttr(ast->literal->integerValue());

auto op = gen.builder_.create<mlir::cxx::IntConstantOp>(loc, type, value);

return {op};
}

auto Codegen::ExpressionVisitor::operator()(FloatLiteralExpressionAST* ast)
-> ExpressionResult {
auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));
auto loc = gen.getLocation(ast->literalLoc);

auto type = gen.convertType(ast->type);
auto value = gen.builder_.getF64FloatAttr(ast->literal->floatValue());

auto op = gen.builder_.create<mlir::cxx::FloatConstantOp>(loc, type, value);

return {op};
}

Expand Down
27 changes: 26 additions & 1 deletion src/mlir/cxx/mlir/codegen_statements.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@

// cxx
#include <cxx/ast.h>
#include <cxx/control.h>

// mlir
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>

namespace cxx {

struct Codegen::StatementVisitor {
Codegen& gen;

[[nodiscard]] auto control() const -> Control* { return gen.control(); }

void operator()(LabeledStatementAST* ast);
void operator()(CaseStatementAST* ast);
void operator()(DefaultStatementAST* ast);
Expand Down Expand Up @@ -61,6 +67,9 @@ struct Codegen::ExceptionDeclarationVisitor {

void Codegen::statement(StatementAST* ast) {
if (!ast) return;

if (currentBlockMightHaveTerminator()) return;

visit(StatementVisitor{*this}, ast);
}

Expand All @@ -75,6 +84,7 @@ auto Codegen::handler(HandlerAST* ast) -> HandlerResult {

auto exceptionDeclarationResult =
exceptionDeclaration(ast->exceptionDeclaration);

statement(ast->statement);

return {};
Expand Down Expand Up @@ -185,11 +195,26 @@ void Codegen::StatementVisitor::operator()(ContinueStatementAST* ast) {
}

void Codegen::StatementVisitor::operator()(ReturnStatementAST* ast) {
(void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind()));
// (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind()));

auto value = gen.expression(ast->expression);

#if false
auto expressionResult = gen.expression(ast->expression);
#endif

auto loc = gen.getLocation(ast->firstSourceLocation());

mlir::SmallVector<mlir::Value> results;

if (gen.exitValue_) {
gen.builder_.create<mlir::cxx::StoreOp>(loc, value.value,
gen.exitValue_.getResult());

results.push_back(gen.exitValue_);
}

gen.builder_.create<mlir::cf::BranchOp>(loc, results, gen.exitBlock_);
}

void Codegen::StatementVisitor::operator()(CoroutineReturnStatementAST* ast) {
Expand Down
18 changes: 18 additions & 0 deletions src/mlir/cxx/mlir/cxx_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,24 @@ auto FuncOp::parse(OpAsmParser &parser, OperationState &result) -> ParseResult {
getResAttrsAttrName(result.name));
}

auto StoreOp::verify() -> LogicalResult {
#if false
auto addrType = dyn_cast<PointerType>(getAddr().getType());
if (!addrType) {
return emitOpError("addr must be a pointer type");
}

auto valueType = getValue().getType();
if (addrType.getElementType() != valueType) {
return emitOpError("addr must be a pointer to the value type (")
<< valueType << " but found " << addrType << ")";
}

#endif

return success();
}

auto ClassType::getNamed(MLIRContext *context, StringRef name) -> ClassType {
return Base::get(context, name);
}
Expand Down