From 8de3694f5f21aeac9504890df52e2adb0b260778 Mon Sep 17 00:00:00 2001 From: Roberto Raggi Date: Wed, 19 Mar 2025 19:59:36 +0000 Subject: [PATCH] Substitute type parameters with type aliases Enable customization of the template instantiation process in unit tests. --- src/parser/cxx/ast_rewriter.cc | 12 +++ src/parser/cxx/binder.cc | 2 + src/parser/cxx/decl_specs.cc | 15 --- src/parser/cxx/parser.cc | 40 ++++---- src/parser/cxx/parser.h | 4 +- src/parser/cxx/parser_fwd.h | 1 + src/parser/cxx/translation_unit.cc | 8 +- src/parser/cxx/translation_unit.h | 3 + tests/api_tests/test_rewriter.cc | 145 +++++++++++++---------------- tests/api_tests/test_utils.h | 13 ++- 10 files changed, 121 insertions(+), 122 deletions(-) diff --git a/src/parser/cxx/ast_rewriter.cc b/src/parser/cxx/ast_rewriter.cc index aacaab48..4af2a5ad 100644 --- a/src/parser/cxx/ast_rewriter.cc +++ b/src/parser/cxx/ast_rewriter.cc @@ -3457,6 +3457,18 @@ auto ASTRewriter::SpecifierVisitor::operator()(NamedTypeSpecifierAST* ast) copy->isTemplateIntroduced = ast->isTemplateIntroduced; copy->symbol = ast->symbol; + if (auto typeParameter = symbol_cast(copy->symbol)) { + const auto& args = rewrite.templateArguments_; + if (typeParameter && typeParameter->depth() == 0 && + typeParameter->index() < args.size()) { + auto index = typeParameter->index(); + + if (auto sym = std::get_if(&args[index])) { + copy->symbol = *sym; + } + } + } + return copy; } diff --git a/src/parser/cxx/binder.cc b/src/parser/cxx/binder.cc index 2e97aa70..725f1000 100644 --- a/src/parser/cxx/binder.cc +++ b/src/parser/cxx/binder.cc @@ -661,6 +661,8 @@ auto Binder::resolve(NestedNameSpecifierAST* nestedNameSpecifier, } auto Binder::instantiate(SimpleTemplateIdAST* templateId) -> Symbol* { + if (!translationUnit()->config().templateInstantiation) return nullptr; + std::vector args; for (auto it = templateId->templateArgumentList; it; it = it->next) { if (auto arg = ast_cast(it->value)) { diff --git a/src/parser/cxx/decl_specs.cc b/src/parser/cxx/decl_specs.cc index 53cb9997..074d709e 100644 --- a/src/parser/cxx/decl_specs.cc +++ b/src/parser/cxx/decl_specs.cc @@ -256,21 +256,6 @@ void DeclSpecs::Visitor::operator()(ComplexTypeSpecifierAST* ast) { void DeclSpecs::Visitor::operator()(NamedTypeSpecifierAST* ast) { specs.typeSpecifier = ast; - if (specs.rewriter) { - auto typeParameter = symbol_cast(ast->symbol); - const auto& args = specs.rewriter->templateArguments(); - - if (typeParameter && typeParameter->depth() == 0 && - typeParameter->index() < args.size()) { - auto index = typeParameter->index(); - - if (auto ty = std::get_if(&args[index])) { - specs.type = *ty; - return; - } - } - } - if (ast->symbol) specs.type = ast->symbol->type(); else diff --git a/src/parser/cxx/parser.cc b/src/parser/cxx/parser.cc index a645adbd..aa832aa3 100644 --- a/src/parser/cxx/parser.cc +++ b/src/parser/cxx/parser.cc @@ -268,10 +268,8 @@ auto Parser::expect(TokenKind tk, SourceLocation& location) -> bool { void Parser::operator()(UnitAST*& ast) { parse(ast); } -auto Parser::config() const -> const ParserConfiguration& { return config_; } - -void Parser::setConfig(ParserConfiguration config) { - config_ = std::move(config); +auto Parser::config() const -> const ParserConfiguration& { + return unit->config(); } void Parser::parse(UnitAST*& ast) { parse_translation_unit(ast); } @@ -676,7 +674,7 @@ auto Parser::parse_completion(SourceLocation& loc) -> bool { if (didAcceptCompletionToken_) return false; // if there is no completer, return false - if (!config_.complete) return false; + if (!config().complete) return false; if (!match(TokenKind::T_CODE_COMPLETION, loc)) return false; @@ -709,7 +707,7 @@ auto Parser::parse_primary_expression(ExpressionAST*& yyast, } auto Parser::parse_splicer(SplicerAST*& yyast) -> bool { - if (!config_.reflect) return false; + if (!config().reflect) return false; if (!lookat(TokenKind::T_LBRACKET, TokenKind::T_COLON)) return false; @@ -729,7 +727,7 @@ auto Parser::parse_splicer(SplicerAST*& yyast) -> bool { auto Parser::parse_splicer_expression(ExpressionAST*& yyast, const ExprContext& ctx) -> bool { - if (!config_.reflect) return false; + if (!config().reflect) return false; SplicerAST* splicer = nullptr; if (!parse_splicer(splicer)) return false; @@ -741,7 +739,7 @@ auto Parser::parse_splicer_expression(ExpressionAST*& yyast, auto Parser::parse_reflect_expression(ExpressionAST*& yyast, const ExprContext& ctx) -> bool { - if (!config_.reflect) return false; + if (!config().reflect) return false; SourceLocation caretLoc; @@ -1054,7 +1052,7 @@ auto Parser::parse_template_nested_name_specifier( } } - if (!ast->symbol && config_.checkTypes) { + if (!ast->symbol && config().checkTypes) { ast->symbol = binder_.instantiate(templateId); } @@ -1859,7 +1857,7 @@ auto Parser::parse_member_expression(ExpressionAST*& yyast) -> bool { auto objectType = ast->baseExpression->type; // trigger the completion - config_.complete(MemberCompletionContext{ + config().complete(MemberCompletionContext{ .objectType = objectType, .accessOp = ast->accessOp, }); @@ -4082,7 +4080,7 @@ auto Parser::parse_simple_declaration( if (auto scope = decl.getScope()) { setScope(scope); - } else if (q && config_.checkTypes) { + } else if (q && config().checkTypes) { parse_error(q->firstSourceLocation(), std::format("unresolved class or namespace")); } @@ -4091,7 +4089,7 @@ auto Parser::parse_simple_declaration( auto functionSymbol = getFunction(scope(), functionName, functionType); if (!functionSymbol) { - if (q && config_.checkTypes) { + if (q && config().checkTypes) { parse_error(q->firstSourceLocation(), std::format("class or namespace has no member named '{}'", to_string(functionName))); @@ -4196,7 +4194,7 @@ auto Parser::parse_notypespec_function_definition( if (auto scope = decl.getScope()) { setScope(scope); } else if (auto q = decl.getNestedNameSpecifier()) { - if (config_.checkTypes) { + if (config().checkTypes) { parse_error(q->firstSourceLocation(), std::format("unresolved class or namespace")); } @@ -4325,7 +4323,7 @@ auto Parser::parse_static_assert_declaration(DeclarationAST*& yyast) -> bool { value = visit(to_bool, *constValue); } - if (!value && config_.checkTypes) { + if (!value && config().checkTypes) { SourceLocation loc = ast->firstSourceLocation(); if (!ast->expression || !constValue.has_value()) { @@ -4847,12 +4845,12 @@ auto Parser::parse_named_type_specifier(SpecifierAST*& yyast, DeclSpecs& specs) if (conceptSymbol && !lookat(TokenKind::T_AUTO)) return false; } - const auto canInstantiate = config_.checkTypes; + const auto canInstantiate = config().checkTypes; auto symbol = binder_.resolve(nestedNameSpecifier, unqualifiedId, canInstantiate); - if (config_.checkTypes && !symbol && ast_cast(unqualifiedId)) { + if (config().checkTypes && !symbol && ast_cast(unqualifiedId)) { return false; } @@ -4987,14 +4985,14 @@ auto Parser::parse_primitive_type_specifier(SpecifierAST*& yyast, } auto Parser::maybe_template_name(const Identifier* id) -> bool { - if (!config_.fuzzyTemplateResolution) return true; + if (!config().fuzzyTemplateResolution) return true; if (template_names_.contains(id)) return true; if (concept_names_.contains(id)) return true; return false; } void Parser::mark_maybe_template_name(const Identifier* id) { - if (!config_.fuzzyTemplateResolution) return; + if (!config().fuzzyTemplateResolution) return; if (!id) return; template_names_.insert(id); } @@ -5314,7 +5312,7 @@ auto Parser::parse_declarator(DeclaratorAST*& yyast, Decl& decl, if (auto scope = decl.getScope()) { setScope(scope); - } else if (q && config_.checkTypes) { + } else if (q && config().checkTypes) { parse_error(q->firstSourceLocation(), std::format("unresolved class or namespace")); } @@ -9041,7 +9039,7 @@ auto Parser::parse_concept_definition(DeclarationAST*& yyast) -> bool { auto Parser::parse_splicer_specifier(SpecifierAST*& yyast, DeclSpecs& specs) -> bool { - if (!config_.reflect) return false; + if (!config().reflect) return false; if (specs.typeSpecifier) return false; LookaheadParser lookahead{this}; SourceLocation typenameLoc; @@ -9391,7 +9389,7 @@ void Parser::check(ExpressionAST* ast) { if (binder_.inTemplate()) return; TypeChecker check{unit}; check.setScope(scope()); - check.setReportErrors(config_.checkTypes); + check.setReportErrors(config().checkTypes); check(ast); } diff --git a/src/parser/cxx/parser.h b/src/parser/cxx/parser.h index 4c090474..38f71109 100644 --- a/src/parser/cxx/parser.h +++ b/src/parser/cxx/parser.h @@ -69,7 +69,6 @@ class Parser final { void operator()(UnitAST*& ast); [[nodiscard]] auto config() const -> const ParserConfiguration&; - void setConfig(ParserConfiguration config); private: struct TemplateHeadContext; @@ -111,7 +110,7 @@ class Parser final { [[nodiscard]] auto shouldStopParsing() const -> bool { if (didAcceptCompletionToken_) return true; - if (config_.stopParsingPredicate) return config_.stopParsingPredicate(); + if (config().stopParsingPredicate) return config().stopParsingPredicate(); return false; } @@ -804,7 +803,6 @@ class Parser final { DiagnosticsClient* diagnosticClient_ = nullptr; Scope* globalScope_ = nullptr; Binder binder_; - ParserConfiguration config_{}; bool skipFunctionBody_ = false; bool moduleUnit_ = false; const Identifier* moduleId_ = nullptr; diff --git a/src/parser/cxx/parser_fwd.h b/src/parser/cxx/parser_fwd.h index 8da87191..671ed47b 100644 --- a/src/parser/cxx/parser_fwd.h +++ b/src/parser/cxx/parser_fwd.h @@ -47,6 +47,7 @@ using CodeCompletionContext = struct ParserConfiguration { bool checkTypes = false; bool fuzzyTemplateResolution = false; + bool templateInstantiation = true; bool reflect = true; std::function stopParsingPredicate; std::function complete; diff --git a/src/parser/cxx/translation_unit.cc b/src/parser/cxx/translation_unit.cc index b31e8a12..29a73791 100644 --- a/src/parser/cxx/translation_unit.cc +++ b/src/parser/cxx/translation_unit.cc @@ -167,12 +167,18 @@ void TranslationUnit::parse(ParserConfiguration config) { if (ast_) { cxx_runtime_error("translation unit already parsed"); } + + config_ = std::move(config); + preprocessor_->squeeze(); Parser parse(this); - parse.setConfig(std::move(config)); parse(ast_); } +auto TranslationUnit::config() const -> const ParserConfiguration& { + return config_; +} + auto TranslationUnit::globalScope() const -> Scope* { if (!globalNamespace_) return nullptr; return globalNamespace_->scope(); diff --git a/src/parser/cxx/translation_unit.h b/src/parser/cxx/translation_unit.h index a025dbe5..97e4fca1 100644 --- a/src/parser/cxx/translation_unit.h +++ b/src/parser/cxx/translation_unit.h @@ -68,6 +68,8 @@ class TranslationUnit { void parse(ParserConfiguration config = {}); + [[nodiscard]] auto config() const -> const ParserConfiguration&; + // set source and preprocess, deprecated. void setSource(std::string source, std::string fileName); @@ -137,6 +139,7 @@ class TranslationUnit { const char* yyptr = nullptr; DiagnosticsClient* diagnosticsClient_ = nullptr; NamespaceSymbol* globalNamespace_ = nullptr; + ParserConfiguration config_; }; } // namespace cxx diff --git a/tests/api_tests/test_rewriter.cc b/tests/api_tests/test_rewriter.cc index 142f3830..e72c13e5 100644 --- a/tests/api_tests/test_rewriter.cc +++ b/tests/api_tests/test_rewriter.cc @@ -39,8 +39,43 @@ using namespace cxx; +namespace { + +[[nodiscard]] auto make_substitution( + TranslationUnit* unit, TemplateDeclarationAST* templateDecl, + List* templateArgumentList) + -> std::vector { + auto control = unit->control(); + auto interp = ASTInterpreter{unit}; + + std::vector templateArguments; + + for (auto arg : ListView{templateArgumentList}) { + if (auto exprArg = ast_cast(arg)) { + auto expr = exprArg->expression; + // ### need to set scope and location + auto templArg = control->newVariableSymbol(nullptr, {}); + templArg->setInitializer(expr); + templArg->setType(control->add_const(expr->type)); + templArg->setConstValue(interp.evaluate(expr)); + if (!templArg->constValue().has_value()) + cxx_runtime_error("template argument is not a constant expression"); + templateArguments.push_back(templArg); + } else if (auto typeArg = ast_cast(arg)) { + auto type = typeArg->typeId->type; + // ### need to set scope and location + auto templArg = control->newTypeAliasSymbol(nullptr, {}); + templArg->setType(type); + templateArguments.push_back(templArg); + } + } + + return templateArguments; +} + template -auto subst(Source& source, Node* ast, std::vector args) { +[[nodiscard]] auto substitute(Source& source, Node* ast, + std::vector args) { auto control = source.control(); TypeChecker typeChecker(&source.unit); ASTRewriter rewrite{&typeChecker, args}; @@ -59,92 +94,50 @@ template return ast_cast(getTemplateBody(ast)); } +} // namespace + TEST(Rewriter, TypeAlias) { auto source = R"( -template -using Ptr = const T*; - template using Func = void(T, const U&); - )"_cxx; - - auto control = source.control(); - - auto ptrInstance = - subst(source, - getTemplateBodyAs( - source.getAs("Ptr")->templateDeclaration()), - {control->getIntType()}); - ASSERT_EQ(to_string(ptrInstance->typeId->type), "const int*"); - - auto funcInstance = - subst(source, - getTemplateBodyAs( - source.getAs("Func")->templateDeclaration()), - {control->getIntType(), control->getFloatType()}); - ASSERT_EQ(to_string(funcInstance->typeId->type), "void (int, const float&)"); -} - -TEST(Rewriter, Var) { - auto source = R"( -template -const int c = i + 321 + i; - -constexpr int x = 123 * 2; - -constexpr int y = c<123 * 2>; -)"_cxx; - - auto interp = ASTInterpreter{&source.unit}; +using Func1 = Func; + )"_cxx_no_templates; auto control = source.control(); - auto c = source.getAs("c"); - ASSERT_TRUE(c != nullptr); - auto templateDeclaration = c->templateDeclaration(); - ASSERT_TRUE(templateDeclaration != nullptr); - - // extract the expression 123 * 2 from the AST - auto x = source.getAs("x"); - ASSERT_TRUE(x != nullptr); - auto xinit = ast_cast(x->initializer())->expression; - ASSERT_TRUE(xinit != nullptr); - - // synthesize const auto i = 123 * 2; - - // ### need to set scope and location - auto templArg = control->newVariableSymbol(nullptr, {}); - templArg->setInitializer(xinit); - templArg->setType(control->add_const(x->type())); - templArg->setConstValue(interp.evaluate(xinit)); - ASSERT_TRUE(templArg->constValue().has_value()); + auto func1 = source.getAs("Func1"); + ASSERT_TRUE(func1 != nullptr); - auto instance = subst( - source, getTemplateBodyAs(templateDeclaration), - {templArg}); + std::cout << "Func1: " << to_string(func1->type()) << "\n"; - auto decl = instance->initDeclaratorList->value; - ASSERT_TRUE(decl != nullptr); + auto func1Type = type_cast(func1->type()); + ASSERT_TRUE(func1Type != nullptr); - auto init = ast_cast(decl->initializer); - ASSERT_TRUE(init); + auto templateId = ast_cast(func1Type->unqualifiedId()); + ASSERT_TRUE(templateId != nullptr); + auto templateSym = + symbol_cast(templateId->primaryTemplateSymbol); + ASSERT_TRUE(templateSym != nullptr); - auto value = interp.evaluate(init->expression); + auto templateArguments = + make_substitution(&source.unit, templateSym->templateDeclaration(), + templateId->templateArgumentList); - ASSERT_TRUE(value.has_value()); - - ASSERT_EQ(std::visit(ArithmeticCast{}, *value), 123 * 2 + 321 + 123 * 2); + auto funcInstance = substitute(source, + getTemplateBodyAs( + templateSym->templateDeclaration()), + templateArguments); + ASSERT_EQ(to_string(funcInstance->typeId->type), "void (int, const float&)"); } -// simulate a template-id instantiation -TEST(Rewriter, TemplateId) { +TEST(Rewriter, Var) { auto source = R"( template const int c = i + 321 + i; constexpr int y = c<123 * 2>; -)"_cxx; +)"_cxx_no_templates; auto interp = ASTInterpreter{&source.unit}; @@ -171,21 +164,11 @@ constexpr int y = c<123 * 2>; templateSym->templateDeclaration()); ASSERT_TRUE(templateDecl != nullptr); - std::vector templateArguments; - for (auto arg : ListView{templateId->templateArgumentList}) { - if (auto exprArg = ast_cast(arg)) { - auto expr = exprArg->expression; - // ### need to set scope and location - auto templArg = control->newVariableSymbol(nullptr, {}); - templArg->setInitializer(expr); - templArg->setType(control->add_const(expr->type)); - templArg->setConstValue(interp.evaluate(expr)); - ASSERT_TRUE(templArg->constValue().has_value()); - templateArguments.push_back(templArg); - } - } + std::vector templateArguments = + make_substitution(&source.unit, templateSym->templateDeclaration(), + templateId->templateArgumentList); - auto instance = subst(source, templateDecl, templateArguments); + auto instance = substitute(source, templateDecl, templateArguments); ASSERT_TRUE(instance != nullptr); auto decl = instance->initDeclaratorList->value; diff --git a/tests/api_tests/test_utils.h b/tests/api_tests/test_utils.h index 3a16a984..4ef6179c 100644 --- a/tests/api_tests/test_utils.h +++ b/tests/api_tests/test_utils.h @@ -29,6 +29,7 @@ #include #include +#include namespace cxx { @@ -42,11 +43,13 @@ struct Source { DiagnosticsClient diagnosticsClient; TranslationUnit unit{&diagnosticsClient}; - explicit Source(std::string_view source) { + explicit Source(std::string_view source, bool templateInstantiation = true) { unit.setSource(std::string(source), ""); + unit.parse({ .checkTypes = true, .fuzzyTemplateResolution = false, + .templateInstantiation = templateInstantiation, .reflect = true, }); } @@ -82,6 +85,14 @@ inline auto operator""_cxx(const char* source, std::size_t size) -> Source { return Source{std::string_view{source, size}}; } +inline auto operator""_cxx_no_templates(const char* source, std::size_t size) + -> Source { + // disable templates to allow overriding the template instantiation algorithm + bool templateInstantiation = false; + auto text = std::string_view{source, size}; + return Source{text, templateInstantiation}; +} + struct LookupMember { Source& source;