diff --git a/src/frontend/cxx/frontend.cc b/src/frontend/cxx/frontend.cc index 4e02d18a..7f1e0fd5 100644 --- a/src/frontend/cxx/frontend.cc +++ b/src/frontend/cxx/frontend.cc @@ -384,7 +384,7 @@ auto runOnFile(const CLI& cli, const std::string& fileName) -> bool { mlir::OpPrintingFlags flags; if (cli.opt_g) { - flags.enableDebugInfo(true, true); + flags.enableDebugInfo(true, false); } ir.module->print(llvm::outs(), flags); } diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index c5aed44c..d0e03cbe 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -170,6 +170,12 @@ def Cxx_FloatConstantOp : Cxx_Op<"constant.float", [ let results = (outs Cxx_FloatType:$result); } +def Cxx_IntegralCastOp : Cxx_Op<"cast.integral"> { + let arguments = (ins AnyType:$value); + + let results = (outs AnyType:$result); +} + // // todo ops // diff --git a/src/mlir/cxx/mlir/codegen.cc b/src/mlir/cxx/mlir/codegen.cc index fd100581..bbc8fd6d 100644 --- a/src/mlir/cxx/mlir/codegen.cc +++ b/src/mlir/cxx/mlir/codegen.cc @@ -22,6 +22,7 @@ // cxx #include +#include #include #include @@ -41,6 +42,25 @@ auto Codegen::currentBlockMightHaveTerminator() -> bool { return block->mightHaveTerminator(); } +auto Codegen::findOrCreateLocal(Symbol* symbol) -> std::optional { + auto var = symbol_cast(symbol); + if (!var) return std::nullopt; + + if (auto local = locals_.find(var); local != locals_.end()) { + return local->second; + } + + auto type = convertType(var->type()); + auto ptrType = builder_.getType(type); + + auto loc = getLocation(var->location()); + auto allocaOp = builder_.create(loc, ptrType); + + locals_.emplace(var, allocaOp); + + return allocaOp; +} + auto Codegen::getLocation(SourceLocation location) -> mlir::Location { auto [filename, line, column] = unit_->tokenStartPosition(location); diff --git a/src/mlir/cxx/mlir/codegen.h b/src/mlir/cxx/mlir/codegen.h index b4d1aa22..7f8fead8 100644 --- a/src/mlir/cxx/mlir/codegen.h +++ b/src/mlir/cxx/mlir/codegen.h @@ -246,6 +246,9 @@ class Codegen { [[nodiscard]] auto currentBlockMightHaveTerminator() -> bool; + [[nodiscard]] auto findOrCreateLocal(Symbol* symbol) + -> std::optional; + struct UnitVisitor; struct DeclarationVisitor; struct StatementVisitor; @@ -277,6 +280,7 @@ class Codegen { mlir::Block* exitBlock_ = nullptr; mlir::cxx::AllocaOp exitValue_; std::unordered_map classNames_; + std::unordered_map locals_; int count_ = 0; }; diff --git a/src/mlir/cxx/mlir/codegen_declarations.cc b/src/mlir/cxx/mlir/codegen_declarations.cc index 757b3ad1..9ed2a670 100644 --- a/src/mlir/cxx/mlir/codegen_declarations.cc +++ b/src/mlir/cxx/mlir/codegen_declarations.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include // mlir @@ -150,20 +151,44 @@ auto Codegen::DeclarationVisitor::operator()(SimpleDeclarationAST* ast) -> DeclarationResult { #if false for (auto node : ListView{ast->attributeList}) { - auto value = gen(node); + auto value = gen.attributeSpecifier(node); } for (auto node : ListView{ast->declSpecifierList}) { - auto value = gen(node); + auto value = gen.specifier(node); } for (auto node : ListView{ast->initDeclaratorList}) { - auto value = gen(node); + auto value = gen.initDeclarator(node); } - auto requiresClauseResult = gen(ast->requiresClause); + auto requiresClauseResult = gen.requiresClause(ast->requiresClause); #endif + for (auto node : ListView{ast->initDeclaratorList}) { + auto var = symbol_cast(node->symbol); + if (!var) continue; + if (!node->initializer) continue; + + const auto loc = gen.getLocation(var->location()); + + auto local = gen.findOrCreateLocal(var); + + if (!local.has_value()) { + gen.unit_->error(node->initializer->firstSourceLocation(), + std::format("cannot find local variable '{}'", + to_string(var->name()))); + continue; + } + + auto expressionResult = gen.expression(node->initializer); + + const auto elementType = gen.convertType(var->type()); + + gen.builder_.create(loc, expressionResult.value, + local.value()); + } + return {}; } @@ -319,10 +344,11 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast) name += std::format("_{}", ++gen.count_); } - const auto savedInsertionPoint = gen.builder_.saveInsertionPoint(); + auto guard = mlir::OpBuilder::InsertionGuard(gen.builder_); const auto loc = gen.getLocation(ast->symbol->location()); + std::unordered_map locals; auto func = gen.builder_.create(loc, name, funcType); auto entryBlock = &func.front(); auto exitBlock = gen.builder_.createBlock(&func.getBody()); @@ -345,6 +371,7 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast) std::swap(gen.function_, func); std::swap(gen.exitBlock_, exitBlock); std::swap(gen.exitValue_, exitValue); + std::swap(gen.locals_, locals); // generate code for the function body auto functionBodyResult = gen.functionBody(ast->functionBody); @@ -383,8 +410,7 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast) std::swap(gen.function_, func); std::swap(gen.exitBlock_, exitBlock); std::swap(gen.exitValue_, exitValue); - - gen.builder_.restoreInsertionPoint(savedInsertionPoint); + std::swap(gen.locals_, locals); return {}; } diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index 054cf55c..a741c24c 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -22,7 +22,9 @@ // cxx #include +#include #include +#include #include namespace cxx { @@ -255,6 +257,25 @@ auto Codegen::ExpressionVisitor::operator()(NestedExpressionAST* ast) auto Codegen::ExpressionVisitor::operator()(IdExpressionAST* ast) -> ExpressionResult { + if (auto local = gen.findOrCreateLocal(ast->symbol)) { + return {local.value()}; + } + + if (auto enumerator = symbol_cast(ast->symbol)) { + auto value = enumerator->value().and_then([&](const ConstValue& value) { + ASTInterpreter interp{gen.unit_}; + return interp.toInt(value); + }); + + if (value.has_value()) { + auto loc = gen.getLocation(ast->firstSourceLocation()); + auto type = gen.convertType(enumerator->type()); + auto op = + gen.builder_.create(loc, type, *value); + return {op}; + } + } + auto op = gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); @@ -731,19 +752,44 @@ auto Codegen::ExpressionVisitor::operator()(DeleteExpressionAST* ast) auto Codegen::ExpressionVisitor::operator()(CastExpressionAST* ast) -> ExpressionResult { - auto op = - gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); - -#if false - auto typeIdResult = gen.typeId(ast->typeId); auto expressionResult = gen.expression(ast->expression); -#endif - return {op}; + return expressionResult; } auto Codegen::ExpressionVisitor::operator()(ImplicitCastExpressionAST* ast) -> ExpressionResult { + auto loc = gen.getLocation(ast->firstSourceLocation()); + + switch (ast->castKind) { + case ImplicitCastKind::kLValueToRValueConversion: { + // generate a load + auto expressionResult = gen.expression(ast->expression); + auto resultType = gen.convertType(ast->type); + + auto op = gen.builder_.create(loc, resultType, + expressionResult.value); + + return {op}; + } + + case ImplicitCastKind::kIntegralConversion: + case ImplicitCastKind::kIntegralPromotion: { + // generate a cast + auto expressionResult = gen.expression(ast->expression); + auto resultType = gen.convertType(ast->type); + + auto op = gen.builder_.create( + loc, resultType, expressionResult.value); + + return {op}; + } + + default: + break; + + } // switch + auto op = gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); @@ -891,14 +937,12 @@ auto Codegen::ExpressionVisitor::operator()(ConditionExpressionAST* ast) auto Codegen::ExpressionVisitor::operator()(EqualInitializerAST* ast) -> ExpressionResult { - auto op = - gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); + // auto op = + // gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); -#if false auto expressionResult = gen.expression(ast->expression); -#endif - return {op}; + return expressionResult; } auto Codegen::ExpressionVisitor::operator()(BracedInitListAST* ast) diff --git a/src/mlir/cxx/mlir/codegen_statements.cc b/src/mlir/cxx/mlir/codegen_statements.cc index 82e21f1a..7d8600b4 100644 --- a/src/mlir/cxx/mlir/codegen_statements.cc +++ b/src/mlir/cxx/mlir/codegen_statements.cc @@ -231,11 +231,7 @@ void Codegen::StatementVisitor::operator()(GotoStatementAST* ast) { } void Codegen::StatementVisitor::operator()(DeclarationStatementAST* ast) { - (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); - -#if false - auto declarationResult = gen(ast->declaration); -#endif + auto declarationResult = gen.declaration(ast->declaration); } void Codegen::StatementVisitor::operator()(TryBlockStatementAST* ast) { diff --git a/src/mlir/cxx/mlir/convert_type.cc b/src/mlir/cxx/mlir/convert_type.cc index b4537334..0306ea23 100644 --- a/src/mlir/cxx/mlir/convert_type.cc +++ b/src/mlir/cxx/mlir/convert_type.cc @@ -355,7 +355,9 @@ auto Codegen::ConvertType::operator()(const OverloadSetType* type) auto Codegen::ConvertType::operator()(const BuiltinVaListType* type) -> mlir::Type { - return getExprType(); + // todo: toolchain specific + auto voidType = gen.builder_.getType(); + return gen.builder_.getType(voidType); } } // namespace cxx diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index 32377f5c..292b035b 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -119,6 +119,11 @@ class AllocaOpLowering : public OpConversionPattern { auto resultType = LLVM::LLVMPointerType::get(context); auto elementType = typeConverter->convertType(ptrTy.getElementType()); + if (!elementType) { + return rewriter.notifyMatchFailure( + op, "failed to convert element type of alloca"); + } + auto size = rewriter.create( op.getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(1)); @@ -240,6 +245,58 @@ class FloatConstantOpLowering } }; +class IntegralCastOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::IntegralCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert integral cast type"); + } + + const auto sourceType = dyn_cast(op.getValue().getType()); + const auto targetType = dyn_cast(op.getType()); + const auto isSigned = targetType.getIsSigned(); + + if (sourceType.getWidth() == targetType.getWidth()) { + // no conversion needed, just replace the op with the value + rewriter.replaceOp(op, adaptor.getValue()); + return success(); + } + + if (targetType.getWidth() < sourceType.getWidth()) { + // truncation + if (isSigned) { + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); + } else { + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); + } + return success(); + } + + // extension + + if (isSigned) { + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); + } else { + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); + } + + return success(); + } +}; + class CxxToLLVMLoweringPass : public PassWrapper> { public: @@ -328,8 +385,8 @@ void CxxToLLVMLoweringPass::runOnOperation() { RewritePatternSet patterns(context); patterns.insert(typeConverter, - context); + IntConstantOpLowering, FloatConstantOpLowering, + IntegralCastOpLowering>(typeConverter, context); populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); diff --git a/src/parser/cxx/type_checker.cc b/src/parser/cxx/type_checker.cc index 9df595be..0b4ce931 100644 --- a/src/parser/cxx/type_checker.cc +++ b/src/parser/cxx/type_checker.cc @@ -136,9 +136,12 @@ struct TypeChecker::Visitor { CvQualifiers source) const -> bool; void check_cpp_cast_expression(CppCastExpressionAST* ast); - [[nodiscard]] auto check_static_cast(CppCastExpressionAST* ast) -> bool; - [[nodiscard]] auto check_cast_to_derived(const Type* targetType, - ExpressionAST* expression) -> bool; + + [[nodiscard]] auto check_static_cast(ExpressionAST*& expression, + const Type* targetType) -> bool; + + [[nodiscard]] auto check_cast_to_derived(ExpressionAST* expression, + const Type* targetType) -> bool; void check_addition(BinaryExpressionAST* ast); void check_subtraction(BinaryExpressionAST* ast); @@ -562,11 +565,21 @@ void TypeChecker::Visitor::operator()(PostIncrExpressionAST* ast) { } void TypeChecker::Visitor::operator()(CppCastExpressionAST* ast) { + if (ast->typeId) ast->type = ast->typeId->type; + + if (control()->is_lvalue_reference(ast->type)) { + ast->valueCategory = ValueCategory::kLValue; + } else if (control()->is_rvalue_reference(ast->type)) { + ast->valueCategory = ValueCategory::kXValue; + } else { + ast->valueCategory = ValueCategory::kPrValue; + } + check_cpp_cast_expression(ast); switch (check.unit_->tokenKind(ast->castLoc)) { case TokenKind::T_STATIC_CAST: - if (check_static_cast(ast)) break; + if (check_static_cast(ast->expression, ast->type)) break; error( ast->firstSourceLocation(), std::format("invalid static_cast of '{}' to '{}'", @@ -607,23 +620,20 @@ void TypeChecker::Visitor::check_cpp_cast_expression( } } -auto TypeChecker::Visitor::check_static_cast(CppCastExpressionAST* ast) - -> bool { - if (!ast->typeId) return false; - auto targetType = ast->typeId->type; - +auto TypeChecker::Visitor::check_static_cast(ExpressionAST*& expression, + const Type* targetType) -> bool { if (control()->is_void(targetType)) return true; - if (check_cast_to_derived(targetType, ast->expression)) return true; + if (check_cast_to_derived(expression, targetType)) return true; - const auto cv1 = control()->get_cv_qualifiers(ast->expression->type); + const auto cv1 = control()->get_cv_qualifiers(expression->type); const auto cv2 = control()->get_cv_qualifiers(targetType); if (!check_cv_qualifiers(cv2, cv1)) return false; - if (implicit_conversion(ast->expression, ast->type)) return true; + if (implicit_conversion(expression, targetType)) return true; - auto source = ast->expression; + auto source = expression; (void)ensure_prvalue(source); adjust_cv(source); @@ -637,13 +647,13 @@ auto TypeChecker::Visitor::check_static_cast(CppCastExpressionAST* ast) if (!control()->is_object(targetPtr->elementType())) return false; - ast->expression = source; + expression = source; return true; } -auto TypeChecker::Visitor::check_cast_to_derived(const Type* targetType, - ExpressionAST* expression) +auto TypeChecker::Visitor::check_cast_to_derived(ExpressionAST* expression, + const Type* targetType) -> bool { if (!is_lvalue(expression)) return false; @@ -943,17 +953,20 @@ void TypeChecker::Visitor::operator()(DeleteExpressionAST* ast) { } void TypeChecker::Visitor::operator()(CastExpressionAST* ast) { - if (ast->typeId) { - ast->type = control()->remove_reference(ast->typeId->type); - if (control()->is_lvalue_reference(ast->typeId->type)) - ast->valueCategory = ValueCategory::kLValue; - else if (control()->is_rvalue_reference(ast->typeId->type)) - ast->valueCategory = ValueCategory::kXValue; - else { - ast->valueCategory = ValueCategory::kPrValue; - adjust_cv(ast); - } + if (!ast->typeId) return; + + if (ast->typeId) ast->type = ast->typeId->type; + + if (control()->is_lvalue_reference(ast->type)) { + ast->valueCategory = ValueCategory::kLValue; + } else if (control()->is_rvalue_reference(ast->type)) { + ast->valueCategory = ValueCategory::kXValue; + } else { + ast->valueCategory = ValueCategory::kPrValue; } + + if (check_static_cast(ast->expression, ast->type)) return; + // check the other casts } void TypeChecker::Visitor::operator()(ImplicitCastExpressionAST* ast) {}