diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 2557519b..eb0bbe96 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -144,8 +144,38 @@ def Cxx_LoadOp : Cxx_Op<"load"> { def Cxx_StoreOp : Cxx_Op<"store"> { let arguments = (ins AnyType:$value, Cxx_PointerType:$addr); + + let hasVerifier = 1; +} + +def Cxx_BoolConstantOp : Cxx_Op<"constant.bool", [ + Pure +]> { + let arguments = (ins BoolAttr:$value); + + let results = (outs Cxx_BoolType:$result); +} + +def Cxx_IntConstantOp : Cxx_Op<"constant.int", [ + Pure +]> { + let arguments = (ins I64Attr:$value); + + let results = (outs Cxx_IntegerType:$result); } +def Cxx_FloatConstantOp : Cxx_Op<"constant.float", [ + Pure +]> { + let arguments = (ins F64Attr:$value); + + let results = (outs Cxx_FloatType:$result); +} + +// +// todo ops +// + def Cxx_TodoExprOp : Cxx_Op<"todo.expr"> { let arguments = (ins StrAttr:$message); let results = (outs Cxx_ExprType:$result); diff --git a/src/mlir/cxx/mlir/codegen.cc b/src/mlir/cxx/mlir/codegen.cc index a3610979..fd100581 100644 --- a/src/mlir/cxx/mlir/codegen.cc +++ b/src/mlir/cxx/mlir/codegen.cc @@ -35,6 +35,12 @@ Codegen::~Codegen() {} auto Codegen::control() const -> Control* { return unit_->control(); } +auto Codegen::currentBlockMightHaveTerminator() -> bool { + auto block = builder_.getInsertionBlock(); + if (!block) return true; + return block->mightHaveTerminator(); +} + auto Codegen::getLocation(SourceLocation location) -> mlir::Location { auto [filename, line, column] = unit_->tokenStartPosition(location); diff --git a/src/mlir/cxx/mlir/codegen.h b/src/mlir/cxx/mlir/codegen.h index 3498bdc9..b4d1aa22 100644 --- a/src/mlir/cxx/mlir/codegen.h +++ b/src/mlir/cxx/mlir/codegen.h @@ -244,6 +244,8 @@ class Codegen { [[nodiscard]] auto convertType(const Type* type) -> mlir::Type; + [[nodiscard]] auto currentBlockMightHaveTerminator() -> bool; + struct UnitVisitor; struct DeclarationVisitor; struct StatementVisitor; diff --git a/src/mlir/cxx/mlir/codegen_declarations.cc b/src/mlir/cxx/mlir/codegen_declarations.cc index 71f3d79a..757b3ad1 100644 --- a/src/mlir/cxx/mlir/codegen_declarations.cc +++ b/src/mlir/cxx/mlir/codegen_declarations.cc @@ -353,17 +353,21 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast) const auto endLoc = gen.getLocation(ast->lastSourceLocation()); - if (needsExitValue) { - // We need to return a value of the correct type. - + if (!gen.builder_.getBlock()->mightHaveTerminator()) { llvm::SmallVector exitBlockArgs; - exitBlockArgs.push_back(gen.exitValue_.getResult()); - gen.builder_.create(endLoc, gen.exitBlock_, - exitBlockArgs); + if (gen.exitValue_) { + exitBlockArgs.push_back(gen.exitValue_.getResult()); + } - gen.builder_.setInsertionPointToEnd(gen.exitBlock_); + gen.builder_.create(endLoc, exitBlockArgs, + gen.exitBlock_); + } + gen.builder_.setInsertionPointToEnd(gen.exitBlock_); + + if (gen.exitValue_) { + // We need to return a value of the correct type. auto elementType = gen.exitValue_.getType().getElementType(); auto value = gen.builder_.create(endLoc, elementType, @@ -372,9 +376,6 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast) gen.builder_.create(endLoc, value->getResults()); } else { // If the function returns void, we don't need to return anything. - - gen.builder_.create(endLoc, gen.exitBlock_); - gen.builder_.setInsertionPointToEnd(gen.exitBlock_); gen.builder_.create(endLoc); } diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index 7b789e99..4c69018c 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -22,6 +22,7 @@ // cxx #include +#include namespace cxx { @@ -129,37 +130,49 @@ auto Codegen::ExpressionVisitor::operator()(GeneratedLiteralExpressionAST* ast) auto Codegen::ExpressionVisitor::operator()(CharLiteralExpressionAST* ast) -> ExpressionResult { - auto op = - gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); + auto loc = gen.getLocation(ast->literalLoc); + + auto type = gen.convertType(ast->type); + auto value = gen.builder_.getI64IntegerAttr(ast->literal->charValue()); + + auto op = gen.builder_.create(loc, type, value); + return {op}; } auto Codegen::ExpressionVisitor::operator()(BoolLiteralExpressionAST* ast) -> ExpressionResult { - auto op = - gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); + auto loc = gen.getLocation(ast->literalLoc); + + auto type = gen.convertType(ast->type); + auto value = gen.builder_.getBoolAttr(ast->isTrue); + + auto op = gen.builder_.create(loc, type, value); + return {op}; } auto Codegen::ExpressionVisitor::operator()(IntLiteralExpressionAST* ast) -> ExpressionResult { - auto op = - gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); - -#if false auto loc = gen.getLocation(ast->literalLoc); - auto op = gen.builder_.create( - loc, ast->literal->integerValue()); -#endif + auto type = gen.convertType(ast->type); + auto value = gen.builder_.getI64IntegerAttr(ast->literal->integerValue()); + + auto op = gen.builder_.create(loc, type, value); return {op}; } auto Codegen::ExpressionVisitor::operator()(FloatLiteralExpressionAST* ast) -> ExpressionResult { - auto op = - gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); + auto loc = gen.getLocation(ast->literalLoc); + + auto type = gen.convertType(ast->type); + auto value = gen.builder_.getF64FloatAttr(ast->literal->floatValue()); + + auto op = gen.builder_.create(loc, type, value); + return {op}; } diff --git a/src/mlir/cxx/mlir/codegen_statements.cc b/src/mlir/cxx/mlir/codegen_statements.cc index 227308ae..82e21f1a 100644 --- a/src/mlir/cxx/mlir/codegen_statements.cc +++ b/src/mlir/cxx/mlir/codegen_statements.cc @@ -22,12 +22,18 @@ // cxx #include +#include + +// mlir +#include namespace cxx { struct Codegen::StatementVisitor { Codegen& gen; + [[nodiscard]] auto control() const -> Control* { return gen.control(); } + void operator()(LabeledStatementAST* ast); void operator()(CaseStatementAST* ast); void operator()(DefaultStatementAST* ast); @@ -61,6 +67,9 @@ struct Codegen::ExceptionDeclarationVisitor { void Codegen::statement(StatementAST* ast) { if (!ast) return; + + if (currentBlockMightHaveTerminator()) return; + visit(StatementVisitor{*this}, ast); } @@ -75,6 +84,7 @@ auto Codegen::handler(HandlerAST* ast) -> HandlerResult { auto exceptionDeclarationResult = exceptionDeclaration(ast->exceptionDeclaration); + statement(ast->statement); return {}; @@ -185,11 +195,26 @@ void Codegen::StatementVisitor::operator()(ContinueStatementAST* ast) { } void Codegen::StatementVisitor::operator()(ReturnStatementAST* ast) { - (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); + // (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); + + auto value = gen.expression(ast->expression); #if false auto expressionResult = gen.expression(ast->expression); #endif + + auto loc = gen.getLocation(ast->firstSourceLocation()); + + mlir::SmallVector results; + + if (gen.exitValue_) { + gen.builder_.create(loc, value.value, + gen.exitValue_.getResult()); + + results.push_back(gen.exitValue_); + } + + gen.builder_.create(loc, results, gen.exitBlock_); } void Codegen::StatementVisitor::operator()(CoroutineReturnStatementAST* ast) { diff --git a/src/mlir/cxx/mlir/cxx_dialect.cc b/src/mlir/cxx/mlir/cxx_dialect.cc index 0bb035e2..d197d393 100644 --- a/src/mlir/cxx/mlir/cxx_dialect.cc +++ b/src/mlir/cxx/mlir/cxx_dialect.cc @@ -145,6 +145,24 @@ auto FuncOp::parse(OpAsmParser &parser, OperationState &result) -> ParseResult { getResAttrsAttrName(result.name)); } +auto StoreOp::verify() -> LogicalResult { +#if false + auto addrType = dyn_cast(getAddr().getType()); + if (!addrType) { + return emitOpError("addr must be a pointer type"); + } + + auto valueType = getValue().getType(); + if (addrType.getElementType() != valueType) { + return emitOpError("addr must be a pointer to the value type (") + << valueType << " but found " << addrType << ")"; + } + +#endif + + return success(); +} + auto ClassType::getNamed(MLIRContext *context, StringRef name) -> ClassType { return Base::get(context, name); }