diff --git a/src/parser/cxx/ast_rewriter.cc b/src/parser/cxx/ast_rewriter.cc index dc9381ee..a441ef53 100644 --- a/src/parser/cxx/ast_rewriter.cc +++ b/src/parser/cxx/ast_rewriter.cc @@ -201,6 +201,68 @@ auto ASTRewriter::instantiateTypeAliasTemplate( return instance->symbol; } +auto ASTRewriter::instantiateVariableTemplate( + TranslationUnit* unit, List* templateArgumentList, + VariableSymbol* variableSymbol) -> VariableSymbol* { + auto templateDecl = variableSymbol->templateDeclaration(); + + if (!templateDecl) { + unit->error(variableSymbol->location(), "not a template"); + return nullptr; + } + + auto variableDeclaration = + ast_cast(templateDecl->declaration); + + if (!variableDeclaration) return nullptr; + + auto templateArguments = + make_substitution(unit, templateDecl, templateArgumentList); + + auto is_primary_template = [&]() -> bool { + int expected = 0; + for (const auto& arg : templateArguments) { + if (!std::holds_alternative(arg)) return false; + + auto ty = type_cast(std::get(arg)->type()); + if (!ty) return false; + + if (ty->index() != expected) return false; + ++expected; + } + return true; + }; + + if (is_primary_template()) { + // if this is a primary template, we can just return the class symbol + return variableSymbol; + } + + auto subst = variableSymbol->findSpecialization(templateArguments); + if (subst) { + return subst; + } + + auto parentScope = variableSymbol->parent(); + while (parentScope->isTemplateParameters()) { + parentScope = parentScope->parent(); + } + + auto rewriter = ASTRewriter{unit, parentScope, templateArguments}; + + rewriter.binder().setInstantiatingSymbol(variableSymbol); + + auto instance = + ast_cast(rewriter.declaration(variableDeclaration)); + + if (!instance) return nullptr; + + auto instantiatedSymbol = instance->initDeclaratorList->value->symbol; + auto instantiatedVariable = symbol_cast(instantiatedSymbol); + + return instantiatedVariable; +} + auto ASTRewriter::make_substitution( TranslationUnit* unit, TemplateDeclarationAST* templateDecl, List* templateArgumentList) diff --git a/src/parser/cxx/ast_rewriter.h b/src/parser/cxx/ast_rewriter.h index 0e8f7a84..99181236 100644 --- a/src/parser/cxx/ast_rewriter.h +++ b/src/parser/cxx/ast_rewriter.h @@ -42,6 +42,10 @@ class ASTRewriter { TranslationUnit* unit, List* templateArgumentList, TypeAliasSymbol* typeAliasSymbol) -> TypeAliasSymbol*; + [[nodiscard]] static auto instantiateVariableTemplate( + TranslationUnit* unit, List* templateArgumentList, + VariableSymbol* variableSymbol) -> VariableSymbol*; + explicit ASTRewriter(TranslationUnit* unit, ScopeSymbol* scope, const std::vector& templateArguments); ~ASTRewriter(); diff --git a/src/parser/cxx/ast_rewriter_declarators.cc b/src/parser/cxx/ast_rewriter_declarators.cc index 5daf79a6..3d00d4aa 100644 --- a/src/parser/cxx/ast_rewriter_declarators.cc +++ b/src/parser/cxx/ast_rewriter_declarators.cc @@ -22,6 +22,7 @@ // cxx #include +#include #include #include #include @@ -208,6 +209,9 @@ auto ASTRewriter::initDeclarator(InitDeclaratorAST* ast, auto type = getDeclaratorType(translationUnit(), copy->declarator, declSpecs.type()); + const auto addSymbolToParentScope = + binder().instantiatingSymbol() != ast->symbol; + // ### fix scope if (binder_.scope()->isClass()) { auto symbol = binder_.declareMemberSymbol(copy->declarator, decl); @@ -222,16 +226,30 @@ auto ASTRewriter::initDeclarator(InitDeclaratorAST* ast, auto functionSymbol = binder_.declareFunction(copy->declarator, decl); copy->symbol = functionSymbol; } else { - auto variableSymbol = binder_.declareVariable(copy->declarator, decl); + auto variableSymbol = binder_.declareVariable(copy->declarator, decl, + addSymbolToParentScope); // variableSymbol->setTemplateDeclaration(templateHead); copy->symbol = variableSymbol; + + if (!addSymbolToParentScope) { + auto templateVariable = symbol_cast(ast->symbol); + templateVariable->addSpecialization(templateArguments(), + variableSymbol); + } } } } copy->requiresClause = requiresClause(ast->requiresClause); copy->initializer = expression(ast->initializer); - // copy->symbol = ast->symbol; // TODO remove, done above + + if (auto variableSymbol = symbol_cast(copy->symbol)) { + if (variableSymbol->isConstexpr()) { + auto interp = ASTInterpreter{unit_}; + auto constValue = interp.evaluate(copy->initializer); + variableSymbol->setConstValue(constValue); + } + } return copy; } diff --git a/src/parser/cxx/binder.cc b/src/parser/cxx/binder.cc index faca18d3..d80e1dcf 100644 --- a/src/parser/cxx/binder.cc +++ b/src/parser/cxx/binder.cc @@ -818,15 +818,17 @@ auto Binder::declareField(DeclaratorAST* declarator, const Decl& decl) return fieldSymbol; } -auto Binder::declareVariable(DeclaratorAST* declarator, const Decl& decl) - -> VariableSymbol* { +auto Binder::declareVariable(DeclaratorAST* declarator, const Decl& decl, + bool addSymbolToParentScope) -> VariableSymbol* { auto name = decl.getName(); auto symbol = control()->newVariableSymbol(scope(), decl.location()); auto type = getDeclaratorType(unit_, declarator, decl.specs.type()); applySpecifiers(symbol, decl.specs); symbol->setName(name); symbol->setType(type); - declaringScope()->addSymbol(symbol); + if (addSymbolToParentScope) { + declaringScope()->addSymbol(symbol); + } return symbol; } @@ -963,6 +965,21 @@ void Binder::bind(IdExpressionAST* ast) { } ast->symbol = Lookup{scope()}(ast->nestedNameSpecifier, componentName); + + if (unit_->config().checkTypes) { + if (auto templateId = ast_cast(ast->unqualifiedId)) { + auto var = symbol_cast(ast->symbol); + + if (!var) { + error(templateId->firstSourceLocation(), std::format("not a template")); + } else { + auto instance = ASTRewriter::instantiateVariableTemplate( + unit_, templateId->templateArgumentList, var); + + ast->symbol = instance; + } + } + } } auto Binder::getFunction(ScopeSymbol* scope, const Name* name, const Type* type) diff --git a/src/parser/cxx/binder.h b/src/parser/cxx/binder.h index f648c8f9..56f83475 100644 --- a/src/parser/cxx/binder.h +++ b/src/parser/cxx/binder.h @@ -82,7 +82,9 @@ class Binder { -> FieldSymbol*; [[nodiscard]] auto declareVariable(DeclaratorAST* declarator, - const Decl& decl) -> VariableSymbol*; + const Decl& decl, + bool addSymbolToParentScope) + -> VariableSymbol*; [[nodiscard]] auto declareMemberSymbol(DeclaratorAST* declarator, const Decl& decl) -> Symbol*; diff --git a/src/parser/cxx/parser.cc b/src/parser/cxx/parser.cc index bb294b9e..98658c00 100644 --- a/src/parser/cxx/parser.cc +++ b/src/parser/cxx/parser.cc @@ -3348,7 +3348,8 @@ void Parser::parse_condition(ExpressionAST*& yyast, const ExprContext& ctx) { if (!parse_declarator(declarator, decl)) return false; - auto symbol = binder_.declareVariable(declarator, decl); + auto symbol = binder_.declareVariable(declarator, decl, + /*addSymbolToParentScope=*/true); ExpressionAST* initializer = nullptr; @@ -5573,7 +5574,8 @@ auto Parser::parse_init_declarator(InitDeclaratorAST*& yyast, auto functionSymbol = binder_.declareFunction(declarator, decl); symbol = functionSymbol; } else { - auto variableSymbol = binder_.declareVariable(declarator, decl); + auto variableSymbol = binder_.declareVariable( + declarator, decl, /*addSymbolToParentScope=*/true); variableSymbol->setTemplateDeclaration(templateHead); symbol = variableSymbol; } diff --git a/src/parser/cxx/symbol_printer.cc b/src/parser/cxx/symbol_printer.cc index 1556846e..b9461d4c 100644 --- a/src/parser/cxx/symbol_printer.cc +++ b/src/parser/cxx/symbol_printer.cc @@ -288,11 +288,35 @@ struct DumpSymbols { if (symbol->isConstinit()) out << " constinit"; if (symbol->isInline()) out << " inline"; - out << std::format(" {}\n", to_string(symbol->type(), symbol->name())); + out << std::format(" {}", to_string(symbol->type(), symbol->name())); + + if (!symbol->templateArguments().empty()) { + out << "<"; + std::string_view sep = ""; + for (auto arg : symbol->templateArguments()) { + auto symbol = std::get_if(&arg); + if (!symbol) continue; + auto sym = *symbol; + if (sym->isTypeAlias()) { + out << std::format("{}{}", sep, to_string(sym->type())); + } else if (auto var = symbol_cast(sym)) { + auto cst = std::get(var->constValue().value()); + out << std::format("{}{}", sep, cst); + } else { + cxx_runtime_error("todo"); + } + sep = ", "; + } + out << std::format(">"); + } + + out << "\n"; if (symbol->templateParameters()) { dumpScope(symbol->templateParameters()); } + + dumpSpecializations(symbol->specializations()); } void operator()(FieldSymbol* symbol) { diff --git a/src/parser/cxx/symbols.cc b/src/parser/cxx/symbols.cc index 93d0082f..e788837b 100644 --- a/src/parser/cxx/symbols.cc +++ b/src/parser/cxx/symbols.cc @@ -769,6 +769,28 @@ void VariableSymbol::setTemplateDeclaration( templateDeclaration_ = declaration; } +auto VariableSymbol::specializations() const + -> std::span> { + if (!templateInfo_) return {}; + return templateInfo_->specializations(); +} + +auto VariableSymbol::findSpecialization( + const std::vector& arguments) const -> VariableSymbol* { + if (!templateInfo_) return {}; + return templateInfo_->findSpecialization(arguments); +} + +void VariableSymbol::addSpecialization(std::vector arguments, + VariableSymbol* specialization) { + if (!templateInfo_) { + templateInfo_ = std::make_unique>(this); + } + auto index = templateInfo_->specializations().size(); + specialization->setSpecializationInfo(this, index); + templateInfo_->addSpecialization(std::move(arguments), specialization); +} + auto VariableSymbol::initializer() const -> ExpressionAST* { return initializer_; } diff --git a/src/parser/cxx/symbols.h b/src/parser/cxx/symbols.h index 703a3efb..0cdd63ea 100644 --- a/src/parser/cxx/symbols.h +++ b/src/parser/cxx/symbols.h @@ -595,6 +595,28 @@ class VariableSymbol final : public Symbol { [[nodiscard]] auto templateDeclaration() const -> TemplateDeclarationAST*; void setTemplateDeclaration(TemplateDeclarationAST* declaration); + [[nodiscard]] auto specializations() const + -> std::span>; + + [[nodiscard]] auto findSpecialization( + const std::vector& arguments) const -> VariableSymbol*; + + void addSpecialization(std::vector arguments, + VariableSymbol* specialization); + + void setSpecializationInfo(VariableSymbol* templateVariable, + std::size_t index) { + templateVariable_ = templateVariable; + templateSepcializationIndex_ = index; + } + + [[nodiscard]] auto templateArguments() const + -> std::span { + if (!templateVariable_) return {}; + return templateVariable_->specializations()[templateSepcializationIndex_] + .arguments; + } + [[nodiscard]] auto initializer() const -> ExpressionAST*; void setInitializer(ExpressionAST*); @@ -605,6 +627,9 @@ class VariableSymbol final : public Symbol { TemplateDeclarationAST* templateDeclaration_ = nullptr; ExpressionAST* initializer_ = nullptr; std::optional constValue_; + std::unique_ptr> templateInfo_; + VariableSymbol* templateVariable_ = nullptr; + std::size_t templateSepcializationIndex_ = 0; union { std::uint32_t flags_{}; diff --git a/tests/unit_tests/sema/template_var_01.cc b/tests/unit_tests/sema/template_var_01.cc new file mode 100644 index 00000000..31ad337c --- /dev/null +++ b/tests/unit_tests/sema/template_var_01.cc @@ -0,0 +1,17 @@ +// RUN: %cxx -verify -fcheck -dump-symbols %s + +template +constexpr bool is_integral_v = __is_integral(T); + +static_assert(is_integral_v); +static_assert(is_integral_v); +static_assert(is_integral_v == false); + +// clang-format off +// CHECK:namespace +// CHECK-NEXT: template variable constexpr const bool is_integral_v +// CHECK-NEXT: parameter typename<0, 0> T +// CHECK-NEXT: [specializations] +// CHECK-NEXT: variable constexpr bool is_integral_v +// CHECK-NEXT: variable constexpr bool is_integral_v +// CHECK-NEXT: variable constexpr bool is_integral_v