Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/parser/cxx/ast_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,68 @@ auto ASTRewriter::instantiateTypeAliasTemplate(
return instance->symbol;
}

auto ASTRewriter::instantiateVariableTemplate(
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
VariableSymbol* variableSymbol) -> VariableSymbol* {
auto templateDecl = variableSymbol->templateDeclaration();

if (!templateDecl) {
unit->error(variableSymbol->location(), "not a template");
return nullptr;
}

auto variableDeclaration =
ast_cast<SimpleDeclarationAST>(templateDecl->declaration);

if (!variableDeclaration) return nullptr;

auto templateArguments =
make_substitution(unit, templateDecl, templateArgumentList);

auto is_primary_template = [&]() -> bool {
int expected = 0;
for (const auto& arg : templateArguments) {
if (!std::holds_alternative<Symbol*>(arg)) return false;

auto ty = type_cast<TypeParameterType>(std::get<Symbol*>(arg)->type());
if (!ty) return false;

if (ty->index() != expected) return false;
++expected;
}
return true;
};

if (is_primary_template()) {
// if this is a primary template, we can just return the class symbol
return variableSymbol;
}

auto subst = variableSymbol->findSpecialization(templateArguments);
if (subst) {
return subst;
}

auto parentScope = variableSymbol->parent();
while (parentScope->isTemplateParameters()) {
parentScope = parentScope->parent();
}

auto rewriter = ASTRewriter{unit, parentScope, templateArguments};

rewriter.binder().setInstantiatingSymbol(variableSymbol);

auto instance =
ast_cast<SimpleDeclarationAST>(rewriter.declaration(variableDeclaration));

if (!instance) return nullptr;

auto instantiatedSymbol = instance->initDeclaratorList->value->symbol;
auto instantiatedVariable = symbol_cast<VariableSymbol>(instantiatedSymbol);

return instantiatedVariable;
}

auto ASTRewriter::make_substitution(
TranslationUnit* unit, TemplateDeclarationAST* templateDecl,
List<TemplateArgumentAST*>* templateArgumentList)
Expand Down
4 changes: 4 additions & 0 deletions src/parser/cxx/ast_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class ASTRewriter {
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
TypeAliasSymbol* typeAliasSymbol) -> TypeAliasSymbol*;

[[nodiscard]] static auto instantiateVariableTemplate(
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
VariableSymbol* variableSymbol) -> VariableSymbol*;

explicit ASTRewriter(TranslationUnit* unit, ScopeSymbol* scope,
const std::vector<TemplateArgument>& templateArguments);
~ASTRewriter();
Expand Down
22 changes: 20 additions & 2 deletions src/parser/cxx/ast_rewriter_declarators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

// cxx
#include <cxx/ast.h>
#include <cxx/ast_interpreter.h>
#include <cxx/binder.h>
#include <cxx/decl.h>
#include <cxx/decl_specs.h>
Expand Down Expand Up @@ -208,6 +209,9 @@ auto ASTRewriter::initDeclarator(InitDeclaratorAST* ast,
auto type =
getDeclaratorType(translationUnit(), copy->declarator, declSpecs.type());

const auto addSymbolToParentScope =
binder().instantiatingSymbol() != ast->symbol;

// ### fix scope
if (binder_.scope()->isClass()) {
auto symbol = binder_.declareMemberSymbol(copy->declarator, decl);
Expand All @@ -222,16 +226,30 @@ auto ASTRewriter::initDeclarator(InitDeclaratorAST* ast,
auto functionSymbol = binder_.declareFunction(copy->declarator, decl);
copy->symbol = functionSymbol;
} else {
auto variableSymbol = binder_.declareVariable(copy->declarator, decl);
auto variableSymbol = binder_.declareVariable(copy->declarator, decl,
addSymbolToParentScope);
// variableSymbol->setTemplateDeclaration(templateHead);
copy->symbol = variableSymbol;

if (!addSymbolToParentScope) {
auto templateVariable = symbol_cast<VariableSymbol>(ast->symbol);
templateVariable->addSpecialization(templateArguments(),
variableSymbol);
}
}
}
}

copy->requiresClause = requiresClause(ast->requiresClause);
copy->initializer = expression(ast->initializer);
// copy->symbol = ast->symbol; // TODO remove, done above

if (auto variableSymbol = symbol_cast<VariableSymbol>(copy->symbol)) {
if (variableSymbol->isConstexpr()) {
auto interp = ASTInterpreter{unit_};
auto constValue = interp.evaluate(copy->initializer);
variableSymbol->setConstValue(constValue);
}
}

return copy;
}
Expand Down
23 changes: 20 additions & 3 deletions src/parser/cxx/binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -818,15 +818,17 @@ auto Binder::declareField(DeclaratorAST* declarator, const Decl& decl)
return fieldSymbol;
}

auto Binder::declareVariable(DeclaratorAST* declarator, const Decl& decl)
-> VariableSymbol* {
auto Binder::declareVariable(DeclaratorAST* declarator, const Decl& decl,
bool addSymbolToParentScope) -> VariableSymbol* {
auto name = decl.getName();
auto symbol = control()->newVariableSymbol(scope(), decl.location());
auto type = getDeclaratorType(unit_, declarator, decl.specs.type());
applySpecifiers(symbol, decl.specs);
symbol->setName(name);
symbol->setType(type);
declaringScope()->addSymbol(symbol);
if (addSymbolToParentScope) {
declaringScope()->addSymbol(symbol);
}
return symbol;
}

Expand Down Expand Up @@ -963,6 +965,21 @@ void Binder::bind(IdExpressionAST* ast) {
}

ast->symbol = Lookup{scope()}(ast->nestedNameSpecifier, componentName);

if (unit_->config().checkTypes) {
if (auto templateId = ast_cast<SimpleTemplateIdAST>(ast->unqualifiedId)) {
auto var = symbol_cast<VariableSymbol>(ast->symbol);

if (!var) {
error(templateId->firstSourceLocation(), std::format("not a template"));
} else {
auto instance = ASTRewriter::instantiateVariableTemplate(
unit_, templateId->templateArgumentList, var);

ast->symbol = instance;
}
}
}
}

auto Binder::getFunction(ScopeSymbol* scope, const Name* name, const Type* type)
Expand Down
4 changes: 3 additions & 1 deletion src/parser/cxx/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ class Binder {
-> FieldSymbol*;

[[nodiscard]] auto declareVariable(DeclaratorAST* declarator,
const Decl& decl) -> VariableSymbol*;
const Decl& decl,
bool addSymbolToParentScope)
-> VariableSymbol*;

[[nodiscard]] auto declareMemberSymbol(DeclaratorAST* declarator,
const Decl& decl) -> Symbol*;
Expand Down
6 changes: 4 additions & 2 deletions src/parser/cxx/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3348,7 +3348,8 @@ void Parser::parse_condition(ExpressionAST*& yyast, const ExprContext& ctx) {

if (!parse_declarator(declarator, decl)) return false;

auto symbol = binder_.declareVariable(declarator, decl);
auto symbol = binder_.declareVariable(declarator, decl,
/*addSymbolToParentScope=*/true);

ExpressionAST* initializer = nullptr;

Expand Down Expand Up @@ -5573,7 +5574,8 @@ auto Parser::parse_init_declarator(InitDeclaratorAST*& yyast,
auto functionSymbol = binder_.declareFunction(declarator, decl);
symbol = functionSymbol;
} else {
auto variableSymbol = binder_.declareVariable(declarator, decl);
auto variableSymbol = binder_.declareVariable(
declarator, decl, /*addSymbolToParentScope=*/true);
variableSymbol->setTemplateDeclaration(templateHead);
symbol = variableSymbol;
}
Expand Down
26 changes: 25 additions & 1 deletion src/parser/cxx/symbol_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,35 @@ struct DumpSymbols {
if (symbol->isConstinit()) out << " constinit";
if (symbol->isInline()) out << " inline";

out << std::format(" {}\n", to_string(symbol->type(), symbol->name()));
out << std::format(" {}", to_string(symbol->type(), symbol->name()));

if (!symbol->templateArguments().empty()) {
out << "<";
std::string_view sep = "";
for (auto arg : symbol->templateArguments()) {
auto symbol = std::get_if<Symbol*>(&arg);
if (!symbol) continue;
auto sym = *symbol;
if (sym->isTypeAlias()) {
out << std::format("{}{}", sep, to_string(sym->type()));
} else if (auto var = symbol_cast<VariableSymbol>(sym)) {
auto cst = std::get<std::intmax_t>(var->constValue().value());
out << std::format("{}{}", sep, cst);
} else {
cxx_runtime_error("todo");
}
sep = ", ";
}
out << std::format(">");
}

out << "\n";

if (symbol->templateParameters()) {
dumpScope(symbol->templateParameters());
}

dumpSpecializations(symbol->specializations());
}

void operator()(FieldSymbol* symbol) {
Expand Down
22 changes: 22 additions & 0 deletions src/parser/cxx/symbols.cc
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,28 @@ void VariableSymbol::setTemplateDeclaration(
templateDeclaration_ = declaration;
}

auto VariableSymbol::specializations() const
-> std::span<const TemplateSpecialization<VariableSymbol>> {
if (!templateInfo_) return {};
return templateInfo_->specializations();
}

auto VariableSymbol::findSpecialization(
const std::vector<TemplateArgument>& arguments) const -> VariableSymbol* {
if (!templateInfo_) return {};
return templateInfo_->findSpecialization(arguments);
}

void VariableSymbol::addSpecialization(std::vector<TemplateArgument> arguments,
VariableSymbol* specialization) {
if (!templateInfo_) {
templateInfo_ = std::make_unique<TemplateInfo<VariableSymbol>>(this);
}
auto index = templateInfo_->specializations().size();
specialization->setSpecializationInfo(this, index);
templateInfo_->addSpecialization(std::move(arguments), specialization);
}

auto VariableSymbol::initializer() const -> ExpressionAST* {
return initializer_;
}
Expand Down
25 changes: 25 additions & 0 deletions src/parser/cxx/symbols.h
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,28 @@ class VariableSymbol final : public Symbol {
[[nodiscard]] auto templateDeclaration() const -> TemplateDeclarationAST*;
void setTemplateDeclaration(TemplateDeclarationAST* declaration);

[[nodiscard]] auto specializations() const
-> std::span<const TemplateSpecialization<VariableSymbol>>;

[[nodiscard]] auto findSpecialization(
const std::vector<TemplateArgument>& arguments) const -> VariableSymbol*;

void addSpecialization(std::vector<TemplateArgument> arguments,
VariableSymbol* specialization);

void setSpecializationInfo(VariableSymbol* templateVariable,
std::size_t index) {
templateVariable_ = templateVariable;
templateSepcializationIndex_ = index;
}

[[nodiscard]] auto templateArguments() const
-> std::span<const TemplateArgument> {
if (!templateVariable_) return {};
return templateVariable_->specializations()[templateSepcializationIndex_]
.arguments;
}

[[nodiscard]] auto initializer() const -> ExpressionAST*;
void setInitializer(ExpressionAST*);

Expand All @@ -605,6 +627,9 @@ class VariableSymbol final : public Symbol {
TemplateDeclarationAST* templateDeclaration_ = nullptr;
ExpressionAST* initializer_ = nullptr;
std::optional<ConstValue> constValue_;
std::unique_ptr<TemplateInfo<VariableSymbol>> templateInfo_;
VariableSymbol* templateVariable_ = nullptr;
std::size_t templateSepcializationIndex_ = 0;

union {
std::uint32_t flags_{};
Expand Down
17 changes: 17 additions & 0 deletions tests/unit_tests/sema/template_var_01.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: %cxx -verify -fcheck -dump-symbols %s

template <typename T>
constexpr bool is_integral_v = __is_integral(T);

static_assert(is_integral_v<int>);
static_assert(is_integral_v<char>);
static_assert(is_integral_v<void*> == false);

// clang-format off
// CHECK:namespace
// CHECK-NEXT: template variable constexpr const bool is_integral_v
// CHECK-NEXT: parameter typename<0, 0> T
// CHECK-NEXT: [specializations]
// CHECK-NEXT: variable constexpr bool is_integral_v<int>
// CHECK-NEXT: variable constexpr bool is_integral_v<char>
// CHECK-NEXT: variable constexpr bool is_integral_v<void*>
Loading