Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 109 additions & 140 deletions src/parser/cxx/ast_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,98 @@

namespace cxx {

namespace {
struct GetTemplateDeclaration {
auto operator()(ClassSymbol* symbol) -> TemplateDeclarationAST* {
return symbol->templateDeclaration();
}

auto operator()(VariableSymbol* symbol) -> TemplateDeclarationAST* {
return symbol->templateDeclaration();
}

auto operator()(TypeAliasSymbol* symbol) -> TemplateDeclarationAST* {
return symbol->templateDeclaration();
}

auto operator()(Symbol*) -> TemplateDeclarationAST* { return nullptr; }
};

struct GetDeclaration {
auto operator()(ClassSymbol* symbol) -> AST* { return symbol->declaration(); }

auto operator()(VariableSymbol* symbol) -> AST* {
return symbol->templateDeclaration()->declaration;
}

auto operator()(TypeAliasSymbol* symbol) -> AST* {
return symbol->templateDeclaration()->declaration;
}

auto operator()(Symbol*) -> AST* { return nullptr; }
};

struct GetSpecialization {
const std::vector<TemplateArgument>& templateArguments;

auto operator()(ClassSymbol* symbol) -> Symbol* {
return symbol->findSpecialization(templateArguments);
}

auto operator()(VariableSymbol* symbol) -> Symbol* {
return symbol->findSpecialization(templateArguments);
}

auto operator()(TypeAliasSymbol* symbol) -> Symbol* {
return symbol->findSpecialization(templateArguments);
}

auto operator()(Symbol*) -> Symbol* { return nullptr; }
};

struct Instantiate {
ASTRewriter& rewriter;

auto operator()(ClassSymbol* symbol) -> Symbol* {
auto classSpecifier = ast_cast<ClassSpecifierAST>(symbol->declaration());
if (!classSpecifier) return nullptr;

auto instance =
ast_cast<ClassSpecifierAST>(rewriter.specifier(classSpecifier));

if (!instance) return nullptr;

return instance->symbol;
}

auto operator()(VariableSymbol* symbol) -> Symbol* {
auto declaration = symbol->templateDeclaration()->declaration;
auto instance = ast_cast<SimpleDeclarationAST>(
rewriter.declaration(ast_cast<SimpleDeclarationAST>(declaration)));

if (!instance) return nullptr;

auto instantiatedSymbol = instance->initDeclaratorList->value->symbol;
auto instantiatedVariable = symbol_cast<VariableSymbol>(instantiatedSymbol);

return instantiatedVariable;
}

auto operator()(TypeAliasSymbol* symbol) -> Symbol* {
auto declaration = symbol->templateDeclaration()->declaration;

auto instance = ast_cast<AliasDeclarationAST>(
rewriter.declaration(ast_cast<AliasDeclarationAST>(declaration)));

if (!instance) return nullptr;

return instance->symbol;
}

auto operator()(Symbol*) -> Symbol* { return nullptr; }
};
} // namespace

ASTRewriter::ASTRewriter(TranslationUnit* unit, ScopeSymbol* scope,
const std::vector<TemplateArgument>& templateArguments)
: unit_(unit), templateArguments_(templateArguments), binder_(unit_) {
Expand Down Expand Up @@ -91,69 +183,18 @@ auto ASTRewriter::getParameterPack(ExpressionAST* ast) -> ParameterPackSymbol* {
return nullptr;
}

auto ASTRewriter::instantiateClassTemplate(
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
ClassSymbol* classSymbol) -> ClassSymbol* {
auto templateDecl = classSymbol->templateDeclaration();

if (!classSymbol->declaration()) return nullptr;

auto templateArguments =
make_substitution(unit, templateDecl, templateArgumentList);

auto is_primary_template = [&]() -> bool {
int expected = 0;
for (const auto& arg : templateArguments) {
if (!std::holds_alternative<Symbol*>(arg)) return false;

auto ty = type_cast<TypeParameterType>(std::get<Symbol*>(arg)->type());
if (!ty) return false;

if (ty->index() != expected) return false;
++expected;
}
return true;
};

if (is_primary_template()) {
// if this is a primary template, we can just return the class symbol
return classSymbol;
}

auto subst = classSymbol->findSpecialization(templateArguments);
if (subst) {
return subst;
}

auto classSpecifier = ast_cast<ClassSpecifierAST>(classSymbol->declaration());
if (!classSpecifier) return nullptr;

auto parentScope = classSymbol->enclosingNonTemplateParametersScope();

auto rewriter = ASTRewriter{unit, parentScope, templateArguments};
rewriter.depth_ = templateDecl->depth;

rewriter.binder().setInstantiatingSymbol(classSymbol);

auto instance =
ast_cast<ClassSpecifierAST>(rewriter.specifier(classSpecifier));

if (!instance) return nullptr;

auto classInstance = instance->symbol;

return classInstance;
}

auto ASTRewriter::instantiateTypeAliasTemplate(
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
TypeAliasSymbol* typeAliasSymbol) -> TypeAliasSymbol* {
auto templateDecl = typeAliasSymbol->templateDeclaration();
auto ASTRewriter::instantiate(TranslationUnit* unit,
List<TemplateArgumentAST*>* templateArgumentList,
Symbol* symbol) -> Symbol* {
auto classSymbol = symbol_cast<ClassSymbol>(symbol);
auto variableSymbol = symbol_cast<VariableSymbol>(symbol);
auto typeAliasSymbol = symbol_cast<TypeAliasSymbol>(symbol);

auto aliasDeclaration =
ast_cast<AliasDeclarationAST>(templateDecl->declaration);
auto templateDecl = visit(GetTemplateDeclaration{}, symbol);
if (!templateDecl) return nullptr;

if (!aliasDeclaration) return nullptr;
auto declaration = visit(GetDeclaration{}, symbol);
if (!declaration) return nullptr;

auto templateArguments =
make_substitution(unit, templateDecl, templateArgumentList);
Expand All @@ -174,93 +215,21 @@ auto ASTRewriter::instantiateTypeAliasTemplate(

if (is_primary_template()) {
// if this is a primary template, we can just return the class symbol
return typeAliasSymbol;
return symbol;
}

#if false
auto subst = typeAliasSymbol->findSpecialization(templateArguments);
if (subst) {
return subst;
}
#endif

auto parentScope = typeAliasSymbol->parent();
while (parentScope->isTemplateParameters()) {
parentScope = parentScope->parent();
}

auto rewriter = ASTRewriter{unit, parentScope, templateArguments};

rewriter.binder().setInstantiatingSymbol(typeAliasSymbol);

auto instance =
ast_cast<AliasDeclarationAST>(rewriter.declaration(aliasDeclaration));

if (!instance) return nullptr;

return instance->symbol;
}

auto ASTRewriter::instantiateVariableTemplate(
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
VariableSymbol* variableSymbol) -> VariableSymbol* {
auto templateDecl = variableSymbol->templateDeclaration();

if (!templateDecl) {
unit->error(variableSymbol->location(), "not a template");
return nullptr;
}
auto specialization = visit(GetSpecialization{templateArguments}, symbol);

auto variableDeclaration =
ast_cast<SimpleDeclarationAST>(templateDecl->declaration);
if (specialization) return specialization;

if (!variableDeclaration) return nullptr;

auto templateArguments =
make_substitution(unit, templateDecl, templateArgumentList);

auto is_primary_template = [&]() -> bool {
int expected = 0;
for (const auto& arg : templateArguments) {
if (!std::holds_alternative<Symbol*>(arg)) return false;

auto ty = type_cast<TypeParameterType>(std::get<Symbol*>(arg)->type());
if (!ty) return false;

if (ty->index() != expected) return false;
++expected;
}
return true;
};

if (is_primary_template()) {
// if this is a primary template, we can just return the class symbol
return variableSymbol;
}

auto subst = variableSymbol->findSpecialization(templateArguments);
if (subst) {
return subst;
}

auto parentScope = variableSymbol->parent();
while (parentScope->isTemplateParameters()) {
parentScope = parentScope->parent();
}
auto parentScope = symbol->enclosingNonTemplateParametersScope();

auto rewriter = ASTRewriter{unit, parentScope, templateArguments};
rewriter.depth_ = templateDecl->depth;

rewriter.binder().setInstantiatingSymbol(variableSymbol);

auto instance =
ast_cast<SimpleDeclarationAST>(rewriter.declaration(variableDeclaration));

if (!instance) return nullptr;

auto instantiatedSymbol = instance->initDeclaratorList->value->symbol;
auto instantiatedVariable = symbol_cast<VariableSymbol>(instantiatedSymbol);
rewriter.binder().setInstantiatingSymbol(symbol);

return instantiatedVariable;
return visit(Instantiate{rewriter}, symbol);
}

auto ASTRewriter::make_substitution(
Expand Down
19 changes: 6 additions & 13 deletions src/parser/cxx/ast_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,9 @@ class Arena;

class ASTRewriter {
public:
[[nodiscard]] static auto instantiateClassTemplate(
[[nodiscard]] static auto instantiate(
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
ClassSymbol* symbol) -> ClassSymbol*;

[[nodiscard]] static auto instantiateTypeAliasTemplate(
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
TypeAliasSymbol* typeAliasSymbol) -> TypeAliasSymbol*;

[[nodiscard]] static auto instantiateVariableTemplate(
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
VariableSymbol* variableSymbol) -> VariableSymbol*;
Symbol* symbol) -> Symbol*;

explicit ASTRewriter(TranslationUnit* unit, ScopeSymbol* scope,
const std::vector<TemplateArgument>& templateArguments);
Expand All @@ -58,6 +50,10 @@ class ASTRewriter {
TemplateDeclarationAST* templateHead = nullptr)
-> DeclarationAST*;

[[nodiscard]] auto specifier(SpecifierAST* ast,
TemplateDeclarationAST* templateHead = nullptr)
-> SpecifierAST*;

[[nodiscard]] static auto make_substitution(
TranslationUnit* unit, TemplateDeclarationAST* templateDecl,
List<TemplateArgumentAST*>* templateArgumentList)
Expand Down Expand Up @@ -91,9 +87,6 @@ class ASTRewriter {
[[nodiscard]] auto designator(DesignatorAST* ast) -> DesignatorAST*;
[[nodiscard]] auto templateParameter(TemplateParameterAST* ast)
-> TemplateParameterAST*;
[[nodiscard]] auto specifier(SpecifierAST* ast,
TemplateDeclarationAST* templateHead = nullptr)
-> SpecifierAST*;
[[nodiscard]] auto ptrOperator(PtrOperatorAST* ast) -> PtrOperatorAST*;
[[nodiscard]] auto coreDeclarator(CoreDeclaratorAST* ast)
-> CoreDeclaratorAST*;
Expand Down
3 changes: 3 additions & 0 deletions src/parser/cxx/ast_rewriter_declarations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,9 @@ auto ASTRewriter::DeclarationVisitor::operator()(AliasDeclarationAST* ast)

auto symbol = binder()->declareTypeAlias(copy->identifierLoc, copy->typeId,
addSymbolToParentScope);
if (!addSymbolToParentScope) {
ast->symbol->addSpecialization(rewrite.templateArguments(), symbol);
}
// symbol->setTemplateDeclaration(templateHead);

copy->symbol = symbol;
Expand Down
4 changes: 2 additions & 2 deletions src/parser/cxx/ast_rewriter_names.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,10 @@ auto ASTRewriter::NestedNameSpecifierVisitor::operator()(

auto classSymbol = symbol_cast<ClassSymbol>(copy->symbol);

auto instance = ASTRewriter::instantiateClassTemplate(
auto instance = ASTRewriter::instantiate(
translationUnit(), copy->templateId->templateArgumentList, classSymbol);

copy->symbol = instance;
copy->symbol = symbol_cast<ClassSymbol>(instance);

return copy;
}
Expand Down
6 changes: 3 additions & 3 deletions src/parser/cxx/binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -914,15 +914,15 @@ auto Binder::resolve(NestedNameSpecifierAST* nestedNameSpecifier,

if (auto classSymbol = symbol_cast<ClassSymbol>(templateId->symbol)) {
// todo: delay
auto instance = ASTRewriter::instantiateClassTemplate(
auto instance = ASTRewriter::instantiate(
unit_, templateId->templateArgumentList, classSymbol);

return instance;
}

if (auto typeAliasSymbol =
symbol_cast<TypeAliasSymbol>(templateId->symbol)) {
auto instance = ASTRewriter::instantiateTypeAliasTemplate(
auto instance = ASTRewriter::instantiate(
unit_, templateId->templateArgumentList, typeAliasSymbol);

return instance;
Expand Down Expand Up @@ -973,7 +973,7 @@ void Binder::bind(IdExpressionAST* ast) {
if (!var) {
error(templateId->firstSourceLocation(), std::format("not a template"));
} else {
auto instance = ASTRewriter::instantiateVariableTemplate(
auto instance = ASTRewriter::instantiate(
unit_, templateId->templateArgumentList, var);

ast->symbol = instance;
Expand Down
Loading
Loading