Skip to content

Commit 86207b1

Browse files
committed
Improve binding of template class declarations
1 parent 22e68c7 commit 86207b1

File tree

8 files changed

+162
-216
lines changed

8 files changed

+162
-216
lines changed

src/parser/cxx/ast_rewriter.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,7 @@ auto ASTRewriter::instantiateClassTemplate(
9696
ClassSymbol* classSymbol) -> ClassSymbol* {
9797
auto templateDecl = classSymbol->templateDeclaration();
9898

99-
ClassSpecifierAST* classSpecifier =
100-
ast_cast<ClassSpecifierAST>(classSymbol->declaration());
101-
102-
if (!classSpecifier) return nullptr;
99+
if (!classSymbol->declaration()) return nullptr;
103100

104101
auto templateArguments =
105102
make_substitution(unit, templateDecl, templateArgumentList);
@@ -128,6 +125,9 @@ auto ASTRewriter::instantiateClassTemplate(
128125
return subst;
129126
}
130127

128+
auto classSpecifier = ast_cast<ClassSpecifierAST>(classSymbol->declaration());
129+
if (!classSpecifier) return nullptr;
130+
131131
auto parentScope = classSymbol->enclosingNonTemplateParametersScope();
132132

133133
auto rewriter = ASTRewriter{unit, parentScope, templateArguments};

src/parser/cxx/ast_rewriter.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ class ASTRewriter {
5454
TemplateDeclarationAST* templateHead = nullptr)
5555
-> DeclarationAST*;
5656

57+
[[nodiscard]] static auto make_substitution(
58+
TranslationUnit* unit, TemplateDeclarationAST* templateDecl,
59+
List<TemplateArgumentAST*>* templateArgumentList)
60+
-> std::vector<TemplateArgument>;
61+
5762
private:
5863
[[nodiscard]] auto templateArguments() const
5964
-> const std::vector<TemplateArgument>& {
@@ -73,11 +78,6 @@ class ASTRewriter {
7378
[[nodiscard]] auto restrictedToDeclarations() const -> bool;
7479
void setRestrictedToDeclarations(bool restrictedToDeclarations);
7580

76-
[[nodiscard]] static auto make_substitution(
77-
TranslationUnit* unit, TemplateDeclarationAST* templateDecl,
78-
List<TemplateArgumentAST*>* templateArgumentList)
79-
-> std::vector<TemplateArgument>;
80-
8181
// run on the base nodes
8282
[[nodiscard]] auto unit(UnitAST* ast) -> UnitAST*;
8383
[[nodiscard]] auto statement(StatementAST* ast) -> StatementAST*;

src/parser/cxx/ast_rewriter_specifiers.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -787,10 +787,6 @@ auto ASTRewriter::SpecifierVisitor::operator()(ClassSpecifierAST* ast)
787787
classSymbol->setDeclaration(copy);
788788
classSymbol->setTemplateDeclaration(templateHead);
789789

790-
if (templateHead) {
791-
classSymbol->setTemplateParameters(binder()->currentTemplateParameters());
792-
}
793-
794790
if (ast->symbol == rewrite.binder().instantiatingSymbol()) {
795791
ast->symbol->addSpecialization(rewrite.templateArguments(), classSymbol);
796792
} else {

src/parser/cxx/binder.cc

Lines changed: 112 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ void Binder::setScope(ScopeSymbol* scope) {
108108
inTemplate_ = false;
109109

110110
for (auto current = scope_; current; current = current->parent()) {
111-
if (current->isTemplateParameters()) {
111+
if (auto params = current->templateParameters()) {
112112
inTemplate_ = true;
113113
break;
114114
}
@@ -212,7 +212,6 @@ void Binder::bind(ElaboratedTypeSpecifierAST* ast, DeclSpecs& declSpecs,
212212

213213
classSymbol->setIsUnion(isUnion);
214214
classSymbol->setName(name);
215-
classSymbol->setTemplateParameters(currentTemplateParameters());
216215
classSymbol->setTemplateDeclaration(declSpecs.templateHead);
217216
declaringScope()->addSymbol(classSymbol);
218217

@@ -230,42 +229,99 @@ void Binder::bind(ElaboratedTypeSpecifierAST* ast, DeclSpecs& declSpecs,
230229
}
231230

232231
void Binder::bind(ClassSpecifierAST* ast, DeclSpecs& declSpecs) {
233-
auto templateParameters = currentTemplateParameters();
232+
auto check_optional_nested_name_specifier = [&] {
233+
if (!ast->nestedNameSpecifier) return;
234234

235-
if (ast->nestedNameSpecifier) {
236235
auto parent = ast->nestedNameSpecifier->symbol;
237236

238-
if (parent && parent->isClassOrNamespace()) {
239-
setScope(static_cast<ScopeSymbol*>(parent));
237+
if (!parent || !parent->isClassOrNamespace()) {
238+
error(ast->nestedNameSpecifier->firstSourceLocation(),
239+
"nested name specifier must be a class or namespace");
240+
return;
240241
}
241-
}
242242

243-
auto className = get_name(control(), ast->unqualifiedId);
244-
auto templateId = ast_cast<SimpleTemplateIdAST>(ast->unqualifiedId);
245-
if (templateId) {
246-
className = templateId->identifier;
247-
}
243+
setScope(static_cast<ScopeSymbol*>(parent));
244+
};
248245

249-
auto location = ast->classLoc;
250-
if (templateId) {
251-
location = templateId->identifierLoc;
252-
} else if (ast->unqualifiedId) {
253-
location = ast->unqualifiedId->firstSourceLocation();
254-
}
246+
auto check_template_specialization = [&] {
247+
auto templateId = ast_cast<SimpleTemplateIdAST>(ast->unqualifiedId);
248+
if (!templateId) return false;
255249

256-
ClassSymbol* primaryTemplate = nullptr;
250+
const auto location = templateId->identifierLoc;
257251

258-
if (templateId && scope()->isTemplateParameters()) {
259-
for (auto candidate : declaringScope()->find(className) | views::classes) {
260-
primaryTemplate = candidate;
252+
ClassSymbol* primaryTemplateSymbol = nullptr;
253+
254+
for (auto candidate :
255+
declaringScope()->find(templateId->identifier) | views::classes) {
256+
primaryTemplateSymbol = candidate;
261257
break;
262258
}
263259

264-
if (!primaryTemplate) {
260+
if (!primaryTemplateSymbol ||
261+
!primaryTemplateSymbol->templateParameters()) {
265262
error(location, std::format("specialization of undeclared template '{}'",
266263
templateId->identifier->name()));
264+
// return true;
267265
}
268-
}
266+
267+
std::vector<TemplateArgument> templateArguments;
268+
ClassSymbol* specialization = nullptr;
269+
270+
if (primaryTemplateSymbol) {
271+
templateArguments = ASTRewriter::make_substitution(
272+
unit_, primaryTemplateSymbol->templateDeclaration(),
273+
templateId->templateArgumentList);
274+
275+
specialization =
276+
primaryTemplateSymbol
277+
? primaryTemplateSymbol->findSpecialization(templateArguments)
278+
: nullptr;
279+
280+
if (specialization) {
281+
error(location, std::format("redefinition of specialization '{}'",
282+
templateId->identifier->name()));
283+
// return true;
284+
}
285+
}
286+
287+
const auto isUnion = ast->classKey == TokenKind::T_UNION;
288+
289+
auto classSymbol = control()->newClassSymbol(declaringScope(), location);
290+
ast->symbol = classSymbol;
291+
292+
classSymbol->setIsUnion(isUnion);
293+
classSymbol->setName(templateId->identifier);
294+
ast->symbol->setDeclaration(ast);
295+
ast->symbol->setFinal(ast->isFinal);
296+
297+
// if (declSpecs.templateHead) {
298+
// warning(location, "setting template head");
299+
// ast->symbol->setTemplateDeclaration(declSpecs.templateHead);
300+
// }
301+
302+
declSpecs.setTypeSpecifier(ast);
303+
declSpecs.setType(ast->symbol->type());
304+
305+
if (primaryTemplateSymbol) {
306+
primaryTemplateSymbol->addSpecialization(std::move(templateArguments),
307+
classSymbol);
308+
}
309+
310+
return true;
311+
};
312+
313+
check_optional_nested_name_specifier();
314+
315+
if (check_template_specialization()) return;
316+
317+
// get the component anme
318+
const Identifier* className = nullptr;
319+
if (auto nameId = ast_cast<NameIdAST>(ast->unqualifiedId))
320+
className = nameId->identifier;
321+
322+
const auto location = ast->unqualifiedId
323+
? ast->unqualifiedId->firstSourceLocation()
324+
: ast->classLoc;
269325

270326
ClassSymbol* classSymbol = nullptr;
271327

@@ -277,6 +333,9 @@ void Binder::bind(ClassSpecifierAST* ast, DeclSpecs& declSpecs) {
277333
}
278334

279335
if (classSymbol && classSymbol->isComplete()) {
336+
// not a template-id, but a class with the same name already exists
337+
error(location,
338+
std::format("redefinition of class '{}'", to_string(className)));
280339
classSymbol = nullptr;
281340
}
282341

@@ -285,29 +344,22 @@ void Binder::bind(ClassSpecifierAST* ast, DeclSpecs& declSpecs) {
285344
classSymbol = control()->newClassSymbol(scope(), location);
286345
classSymbol->setIsUnion(isUnion);
287346
classSymbol->setName(className);
288-
classSymbol->setTemplateParameters(templateParameters);
289347

290-
if (!primaryTemplate) {
291-
declaringScope()->addSymbol(classSymbol);
292-
} else {
293-
std::vector<TemplateArgument> arguments;
294-
// TODO: parse template arguments
295-
primaryTemplate->addSpecialization(arguments, classSymbol);
296-
}
348+
declaringScope()->addSymbol(classSymbol);
297349
}
298350

299-
classSymbol->setDeclaration(ast);
351+
ast->symbol = classSymbol;
352+
353+
ast->symbol->setDeclaration(ast);
300354

301355
if (declSpecs.templateHead) {
302-
classSymbol->setTemplateDeclaration(declSpecs.templateHead);
356+
ast->symbol->setTemplateDeclaration(declSpecs.templateHead);
303357
}
304358

305-
classSymbol->setFinal(ast->isFinal);
306-
307-
ast->symbol = classSymbol;
359+
ast->symbol->setFinal(ast->isFinal);
308360

309361
declSpecs.setTypeSpecifier(ast);
310-
declSpecs.setType(classSymbol->type());
362+
declSpecs.setType(ast->symbol->type());
311363
}
312364

313365
void Binder::complete(ClassSpecifierAST* ast) {
@@ -417,7 +469,6 @@ auto Binder::declareTypeAlias(SourceLocation identifierLoc, TypeIdAST* typeId,
417469
symbol->setName(name);
418470

419471
if (typeId) symbol->setType(typeId->type);
420-
symbol->setTemplateParameters(currentTemplateParameters());
421472

422473
if (auto classType = type_cast<ClassType>(symbol->type())) {
423474
auto classSymbol = classType->symbol();
@@ -565,7 +616,6 @@ void Binder::bind(ConceptDefinitionAST* ast) {
565616

566617
auto symbol = control()->newConceptSymbol(scope(), ast->identifierLoc);
567618
symbol->setName(ast->identifier);
568-
symbol->setTemplateParameters(templateParameters);
569619

570620
declaringScope()->addSymbol(symbol);
571621
}
@@ -708,7 +758,6 @@ auto Binder::declareFunction(DeclaratorAST* declarator, const Decl& decl)
708758
applySpecifiers(functionSymbol, decl.specs);
709759
functionSymbol->setName(name);
710760
functionSymbol->setType(type);
711-
functionSymbol->setTemplateParameters(currentTemplateParameters());
712761

713762
if (isConstructor(functionSymbol)) {
714763
auto enclosingClass = symbol_cast<ClassSymbol>(scope());
@@ -775,7 +824,6 @@ auto Binder::declareVariable(DeclaratorAST* declarator, const Decl& decl)
775824
applySpecifiers(symbol, decl.specs);
776825
symbol->setName(name);
777826
symbol->setType(type);
778-
symbol->setTemplateParameters(currentTemplateParameters());
779827
declaringScope()->addSymbol(symbol);
780828
return symbol;
781829
}
@@ -890,13 +938,29 @@ auto Binder::resolve(NestedNameSpecifierAST* nestedNameSpecifier,
890938
}
891939

892940
void Binder::bind(IdExpressionAST* ast) {
893-
if (ast->unqualifiedId) {
894-
auto name = get_name(control(), ast->unqualifiedId);
895-
const Name* componentName = name;
896-
if (auto templateId = name_cast<TemplateId>(name))
897-
componentName = templateId->name();
898-
ast->symbol = Lookup{scope()}(ast->nestedNameSpecifier, componentName);
941+
if (!ast->unqualifiedId) {
942+
error(ast->firstSourceLocation(),
943+
"expected an unqualified identifier in id expression");
944+
return;
899945
}
946+
947+
auto name = get_name(control(), ast->unqualifiedId);
948+
949+
const Name* componentName = name;
950+
951+
if (auto templateId = name_cast<TemplateId>(name)) {
952+
componentName = templateId->name();
953+
}
954+
955+
if (ast->nestedNameSpecifier) {
956+
if (!ast->nestedNameSpecifier->symbol) {
957+
error(ast->nestedNameSpecifier->firstSourceLocation(),
958+
"nested name specifier must be a class or namespace");
959+
return;
960+
}
961+
}
962+
963+
ast->symbol = Lookup{scope()}(ast->nestedNameSpecifier, componentName);
900964
}
901965

902966
auto Binder::getFunction(ScopeSymbol* scope, const Name* name, const Type* type)

src/parser/cxx/external_name_encoder.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,31 @@ struct ExternalNameEncoder::EncodeUnqualifiedName {
329329
ExternalNameEncoder& encoder;
330330
Symbol* symbol = nullptr;
331331

332+
void encodeTemplateArguments(Symbol* symbol) {
333+
if (!symbol) return;
334+
335+
std::span<const TemplateArgument> args;
336+
337+
if (auto classSymbol = symbol_cast<ClassSymbol>(symbol)) {
338+
args = classSymbol->templateArguments();
339+
}
340+
341+
if (args.empty()) return;
342+
343+
encoder.out("I");
344+
345+
for (const auto& arg : args) {
346+
if (auto sym = std::get_if<Symbol*>(&arg)) {
347+
auto type = (*sym)->type();
348+
encoder.encodeType(type);
349+
} else {
350+
cxx_runtime_error("template argument not supported yet");
351+
}
352+
}
353+
354+
encoder.out("E");
355+
}
356+
332357
void operator()(const Identifier* id) {
333358
if (auto function = symbol_cast<FunctionSymbol>(symbol)) {
334359
if (function->isConstructor()) {
@@ -338,6 +363,7 @@ struct ExternalNameEncoder::EncodeUnqualifiedName {
338363
}
339364

340365
out(std::format("{}{}", id->name().length(), id->name()));
366+
encodeTemplateArguments(symbol);
341367
}
342368

343369
void operator()(const OperatorId* name) {

0 commit comments

Comments
 (0)