Skip to content

Commit 936ddbd

Browse files
committed
Clean up symbol instantiation
1 parent ed5f487 commit 936ddbd

File tree

8 files changed

+175
-163
lines changed

8 files changed

+175
-163
lines changed

src/parser/cxx/ast_rewriter.cc

Lines changed: 109 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,98 @@
3434

3535
namespace cxx {
3636

37+
namespace {
38+
struct GetTemplateDeclaration {
39+
auto operator()(ClassSymbol* symbol) -> TemplateDeclarationAST* {
40+
return symbol->templateDeclaration();
41+
}
42+
43+
auto operator()(VariableSymbol* symbol) -> TemplateDeclarationAST* {
44+
return symbol->templateDeclaration();
45+
}
46+
47+
auto operator()(TypeAliasSymbol* symbol) -> TemplateDeclarationAST* {
48+
return symbol->templateDeclaration();
49+
}
50+
51+
auto operator()(Symbol*) -> TemplateDeclarationAST* { return nullptr; }
52+
};
53+
54+
struct GetDeclaration {
55+
auto operator()(ClassSymbol* symbol) -> AST* { return symbol->declaration(); }
56+
57+
auto operator()(VariableSymbol* symbol) -> AST* {
58+
return symbol->templateDeclaration()->declaration;
59+
}
60+
61+
auto operator()(TypeAliasSymbol* symbol) -> AST* {
62+
return symbol->templateDeclaration()->declaration;
63+
}
64+
65+
auto operator()(Symbol*) -> AST* { return nullptr; }
66+
};
67+
68+
struct GetSpecialization {
69+
const std::vector<TemplateArgument>& templateArguments;
70+
71+
auto operator()(ClassSymbol* symbol) -> Symbol* {
72+
return symbol->findSpecialization(templateArguments);
73+
}
74+
75+
auto operator()(VariableSymbol* symbol) -> Symbol* {
76+
return symbol->findSpecialization(templateArguments);
77+
}
78+
79+
auto operator()(TypeAliasSymbol* symbol) -> Symbol* {
80+
return symbol->findSpecialization(templateArguments);
81+
}
82+
83+
auto operator()(Symbol*) -> Symbol* { return nullptr; }
84+
};
85+
86+
struct Instantiate {
87+
ASTRewriter& rewriter;
88+
89+
auto operator()(ClassSymbol* symbol) -> Symbol* {
90+
auto classSpecifier = ast_cast<ClassSpecifierAST>(symbol->declaration());
91+
if (!classSpecifier) return nullptr;
92+
93+
auto instance =
94+
ast_cast<ClassSpecifierAST>(rewriter.specifier(classSpecifier));
95+
96+
if (!instance) return nullptr;
97+
98+
return instance->symbol;
99+
}
100+
101+
auto operator()(VariableSymbol* symbol) -> Symbol* {
102+
auto declaration = symbol->templateDeclaration()->declaration;
103+
auto instance = ast_cast<SimpleDeclarationAST>(
104+
rewriter.declaration(ast_cast<SimpleDeclarationAST>(declaration)));
105+
106+
if (!instance) return nullptr;
107+
108+
auto instantiatedSymbol = instance->initDeclaratorList->value->symbol;
109+
auto instantiatedVariable = symbol_cast<VariableSymbol>(instantiatedSymbol);
110+
111+
return instantiatedVariable;
112+
}
113+
114+
auto operator()(TypeAliasSymbol* symbol) -> Symbol* {
115+
auto declaration = symbol->templateDeclaration()->declaration;
116+
117+
auto instance = ast_cast<AliasDeclarationAST>(
118+
rewriter.declaration(ast_cast<AliasDeclarationAST>(declaration)));
119+
120+
if (!instance) return nullptr;
121+
122+
return instance->symbol;
123+
}
124+
125+
auto operator()(Symbol*) -> Symbol* { return nullptr; }
126+
};
127+
} // namespace
128+
37129
ASTRewriter::ASTRewriter(TranslationUnit* unit, ScopeSymbol* scope,
38130
const std::vector<TemplateArgument>& templateArguments)
39131
: unit_(unit), templateArguments_(templateArguments), binder_(unit_) {
@@ -91,69 +183,18 @@ auto ASTRewriter::getParameterPack(ExpressionAST* ast) -> ParameterPackSymbol* {
91183
return nullptr;
92184
}
93185

94-
auto ASTRewriter::instantiateClassTemplate(
95-
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
96-
ClassSymbol* classSymbol) -> ClassSymbol* {
97-
auto templateDecl = classSymbol->templateDeclaration();
98-
99-
if (!classSymbol->declaration()) return nullptr;
100-
101-
auto templateArguments =
102-
make_substitution(unit, templateDecl, templateArgumentList);
103-
104-
auto is_primary_template = [&]() -> bool {
105-
int expected = 0;
106-
for (const auto& arg : templateArguments) {
107-
if (!std::holds_alternative<Symbol*>(arg)) return false;
108-
109-
auto ty = type_cast<TypeParameterType>(std::get<Symbol*>(arg)->type());
110-
if (!ty) return false;
111-
112-
if (ty->index() != expected) return false;
113-
++expected;
114-
}
115-
return true;
116-
};
117-
118-
if (is_primary_template()) {
119-
// if this is a primary template, we can just return the class symbol
120-
return classSymbol;
121-
}
122-
123-
auto subst = classSymbol->findSpecialization(templateArguments);
124-
if (subst) {
125-
return subst;
126-
}
127-
128-
auto classSpecifier = ast_cast<ClassSpecifierAST>(classSymbol->declaration());
129-
if (!classSpecifier) return nullptr;
130-
131-
auto parentScope = classSymbol->enclosingNonTemplateParametersScope();
132-
133-
auto rewriter = ASTRewriter{unit, parentScope, templateArguments};
134-
rewriter.depth_ = templateDecl->depth;
135-
136-
rewriter.binder().setInstantiatingSymbol(classSymbol);
137-
138-
auto instance =
139-
ast_cast<ClassSpecifierAST>(rewriter.specifier(classSpecifier));
140-
141-
if (!instance) return nullptr;
142-
143-
auto classInstance = instance->symbol;
144-
145-
return classInstance;
146-
}
147-
148-
auto ASTRewriter::instantiateTypeAliasTemplate(
149-
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
150-
TypeAliasSymbol* typeAliasSymbol) -> TypeAliasSymbol* {
151-
auto templateDecl = typeAliasSymbol->templateDeclaration();
186+
auto ASTRewriter::instantiate(TranslationUnit* unit,
187+
List<TemplateArgumentAST*>* templateArgumentList,
188+
Symbol* symbol) -> Symbol* {
189+
auto classSymbol = symbol_cast<ClassSymbol>(symbol);
190+
auto variableSymbol = symbol_cast<VariableSymbol>(symbol);
191+
auto typeAliasSymbol = symbol_cast<TypeAliasSymbol>(symbol);
152192

153-
auto aliasDeclaration =
154-
ast_cast<AliasDeclarationAST>(templateDecl->declaration);
193+
auto templateDecl = visit(GetTemplateDeclaration{}, symbol);
194+
if (!templateDecl) return nullptr;
155195

156-
if (!aliasDeclaration) return nullptr;
196+
auto declaration = visit(GetDeclaration{}, symbol);
197+
if (!declaration) return nullptr;
157198

158199
auto templateArguments =
159200
make_substitution(unit, templateDecl, templateArgumentList);
@@ -174,93 +215,21 @@ auto ASTRewriter::instantiateTypeAliasTemplate(
174215

175216
if (is_primary_template()) {
176217
// if this is a primary template, we can just return the class symbol
177-
return typeAliasSymbol;
218+
return symbol;
178219
}
179220

180-
#if false
181-
auto subst = typeAliasSymbol->findSpecialization(templateArguments);
182-
if (subst) {
183-
return subst;
184-
}
185-
#endif
186-
187-
auto parentScope = typeAliasSymbol->parent();
188-
while (parentScope->isTemplateParameters()) {
189-
parentScope = parentScope->parent();
190-
}
191-
192-
auto rewriter = ASTRewriter{unit, parentScope, templateArguments};
193-
194-
rewriter.binder().setInstantiatingSymbol(typeAliasSymbol);
195-
196-
auto instance =
197-
ast_cast<AliasDeclarationAST>(rewriter.declaration(aliasDeclaration));
198-
199-
if (!instance) return nullptr;
200-
201-
return instance->symbol;
202-
}
203-
204-
auto ASTRewriter::instantiateVariableTemplate(
205-
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
206-
VariableSymbol* variableSymbol) -> VariableSymbol* {
207-
auto templateDecl = variableSymbol->templateDeclaration();
208-
209-
if (!templateDecl) {
210-
unit->error(variableSymbol->location(), "not a template");
211-
return nullptr;
212-
}
221+
auto specialization = visit(GetSpecialization{templateArguments}, symbol);
213222

214-
auto variableDeclaration =
215-
ast_cast<SimpleDeclarationAST>(templateDecl->declaration);
223+
if (specialization) return specialization;
216224

217-
if (!variableDeclaration) return nullptr;
218-
219-
auto templateArguments =
220-
make_substitution(unit, templateDecl, templateArgumentList);
221-
222-
auto is_primary_template = [&]() -> bool {
223-
int expected = 0;
224-
for (const auto& arg : templateArguments) {
225-
if (!std::holds_alternative<Symbol*>(arg)) return false;
226-
227-
auto ty = type_cast<TypeParameterType>(std::get<Symbol*>(arg)->type());
228-
if (!ty) return false;
229-
230-
if (ty->index() != expected) return false;
231-
++expected;
232-
}
233-
return true;
234-
};
235-
236-
if (is_primary_template()) {
237-
// if this is a primary template, we can just return the class symbol
238-
return variableSymbol;
239-
}
240-
241-
auto subst = variableSymbol->findSpecialization(templateArguments);
242-
if (subst) {
243-
return subst;
244-
}
245-
246-
auto parentScope = variableSymbol->parent();
247-
while (parentScope->isTemplateParameters()) {
248-
parentScope = parentScope->parent();
249-
}
225+
auto parentScope = symbol->enclosingNonTemplateParametersScope();
250226

251227
auto rewriter = ASTRewriter{unit, parentScope, templateArguments};
228+
rewriter.depth_ = templateDecl->depth;
252229

253-
rewriter.binder().setInstantiatingSymbol(variableSymbol);
254-
255-
auto instance =
256-
ast_cast<SimpleDeclarationAST>(rewriter.declaration(variableDeclaration));
257-
258-
if (!instance) return nullptr;
259-
260-
auto instantiatedSymbol = instance->initDeclaratorList->value->symbol;
261-
auto instantiatedVariable = symbol_cast<VariableSymbol>(instantiatedSymbol);
230+
rewriter.binder().setInstantiatingSymbol(symbol);
262231

263-
return instantiatedVariable;
232+
return visit(Instantiate{rewriter}, symbol);
264233
}
265234

266235
auto ASTRewriter::make_substitution(

src/parser/cxx/ast_rewriter.h

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,9 @@ class Arena;
3434

3535
class ASTRewriter {
3636
public:
37-
[[nodiscard]] static auto instantiateClassTemplate(
37+
[[nodiscard]] static auto instantiate(
3838
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
39-
ClassSymbol* symbol) -> ClassSymbol*;
40-
41-
[[nodiscard]] static auto instantiateTypeAliasTemplate(
42-
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
43-
TypeAliasSymbol* typeAliasSymbol) -> TypeAliasSymbol*;
44-
45-
[[nodiscard]] static auto instantiateVariableTemplate(
46-
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
47-
VariableSymbol* variableSymbol) -> VariableSymbol*;
39+
Symbol* symbol) -> Symbol*;
4840

4941
explicit ASTRewriter(TranslationUnit* unit, ScopeSymbol* scope,
5042
const std::vector<TemplateArgument>& templateArguments);
@@ -58,6 +50,10 @@ class ASTRewriter {
5850
TemplateDeclarationAST* templateHead = nullptr)
5951
-> DeclarationAST*;
6052

53+
[[nodiscard]] auto specifier(SpecifierAST* ast,
54+
TemplateDeclarationAST* templateHead = nullptr)
55+
-> SpecifierAST*;
56+
6157
[[nodiscard]] static auto make_substitution(
6258
TranslationUnit* unit, TemplateDeclarationAST* templateDecl,
6359
List<TemplateArgumentAST*>* templateArgumentList)
@@ -91,9 +87,6 @@ class ASTRewriter {
9187
[[nodiscard]] auto designator(DesignatorAST* ast) -> DesignatorAST*;
9288
[[nodiscard]] auto templateParameter(TemplateParameterAST* ast)
9389
-> TemplateParameterAST*;
94-
[[nodiscard]] auto specifier(SpecifierAST* ast,
95-
TemplateDeclarationAST* templateHead = nullptr)
96-
-> SpecifierAST*;
9790
[[nodiscard]] auto ptrOperator(PtrOperatorAST* ast) -> PtrOperatorAST*;
9891
[[nodiscard]] auto coreDeclarator(CoreDeclaratorAST* ast)
9992
-> CoreDeclaratorAST*;

src/parser/cxx/ast_rewriter_declarations.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,9 @@ auto ASTRewriter::DeclarationVisitor::operator()(AliasDeclarationAST* ast)
462462

463463
auto symbol = binder()->declareTypeAlias(copy->identifierLoc, copy->typeId,
464464
addSymbolToParentScope);
465+
if (!addSymbolToParentScope) {
466+
ast->symbol->addSpecialization(rewrite.templateArguments(), symbol);
467+
}
465468
// symbol->setTemplateDeclaration(templateHead);
466469

467470
copy->symbol = symbol;

src/parser/cxx/ast_rewriter_names.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,10 @@ auto ASTRewriter::NestedNameSpecifierVisitor::operator()(
305305

306306
auto classSymbol = symbol_cast<ClassSymbol>(copy->symbol);
307307

308-
auto instance = ASTRewriter::instantiateClassTemplate(
308+
auto instance = ASTRewriter::instantiate(
309309
translationUnit(), copy->templateId->templateArgumentList, classSymbol);
310310

311-
copy->symbol = instance;
311+
copy->symbol = symbol_cast<ClassSymbol>(instance);
312312

313313
return copy;
314314
}

src/parser/cxx/binder.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -914,15 +914,15 @@ auto Binder::resolve(NestedNameSpecifierAST* nestedNameSpecifier,
914914

915915
if (auto classSymbol = symbol_cast<ClassSymbol>(templateId->symbol)) {
916916
// todo: delay
917-
auto instance = ASTRewriter::instantiateClassTemplate(
917+
auto instance = ASTRewriter::instantiate(
918918
unit_, templateId->templateArgumentList, classSymbol);
919919

920920
return instance;
921921
}
922922

923923
if (auto typeAliasSymbol =
924924
symbol_cast<TypeAliasSymbol>(templateId->symbol)) {
925-
auto instance = ASTRewriter::instantiateTypeAliasTemplate(
925+
auto instance = ASTRewriter::instantiate(
926926
unit_, templateId->templateArgumentList, typeAliasSymbol);
927927

928928
return instance;
@@ -973,7 +973,7 @@ void Binder::bind(IdExpressionAST* ast) {
973973
if (!var) {
974974
error(templateId->firstSourceLocation(), std::format("not a template"));
975975
} else {
976-
auto instance = ASTRewriter::instantiateVariableTemplate(
976+
auto instance = ASTRewriter::instantiate(
977977
unit_, templateId->templateArgumentList, var);
978978

979979
ast->symbol = instance;

0 commit comments

Comments
 (0)