diff --git a/src/parser/cxx/ast.h b/src/parser/cxx/ast.h index 120ad799..6a41e604 100644 --- a/src/parser/cxx/ast.h +++ b/src/parser/cxx/ast.h @@ -2999,6 +2999,7 @@ class ElaboratedTypeSpecifierAST final : public SpecifierAST { UnqualifiedIdAST* unqualifiedId = nullptr; TokenKind classKey = TokenKind::T_EOF_SYMBOL; bool isTemplateIntroduced = false; + Symbol* symbol = nullptr; void accept(ASTVisitor* visitor) override { visitor->visit(this); } diff --git a/src/parser/cxx/parser.cc b/src/parser/cxx/parser.cc index 9723f16b..31fa7ea9 100644 --- a/src/parser/cxx/parser.cc +++ b/src/parser/cxx/parser.cc @@ -38,6 +38,7 @@ #include #include #include +#include namespace cxx { @@ -6568,33 +6569,56 @@ auto Parser::parse_elaborated_type_specifier_helper( SourceLocation templateLoc; const auto isTemplateIntroduced = match(TokenKind::T_TEMPLATE, templateLoc); - UnqualifiedIdAST* unqualifiedId = nullptr; + auto ast = make_node(pool_); + yyast = ast; + + ast->classLoc = classLoc; + ast->attributeList = attributes; + ast->nestedNameSpecifier = nestedNameSpecifier; + ast->templateLoc = templateLoc; + ast->classKey = unit->tokenKind(classLoc); + ast->isTemplateIntroduced = isTemplateIntroduced; + const auto loc = currentLocation(); if (lookat(TokenKind::T_IDENTIFIER, TokenKind::T_LESS)) { if (SimpleTemplateIdAST* templateId = nullptr; parse_simple_template_id( templateId, nestedNameSpecifier, isTemplateIntroduced)) { - unqualifiedId = templateId; + ast->unqualifiedId = templateId; } else { parse_error(loc, "expected a template-id"); } } else if (NameIdAST* nameId = nullptr; parse_name_id(nameId)) { - unqualifiedId = nameId; + ast->unqualifiedId = nameId; + auto symbol = Lookup{scope_}(nestedNameSpecifier, nameId->identifier); - } else { - parse_error("expected a name"); - } - auto ast = make_node(pool_); - yyast = ast; + auto class_symbols_view = std::views::filter(&Symbol::isClass); - ast->classLoc = classLoc; - ast->attributeList = attributes; - ast->nestedNameSpecifier = nestedNameSpecifier; - ast->templateLoc = templateLoc; - ast->unqualifiedId = unqualifiedId; - ast->classKey = unit->tokenKind(classLoc); - ast->isTemplateIntroduced = isTemplateIntroduced; + auto enum_symbols_view = std::views::filter([](Symbol* symbol) { + return symbol->isEnum() || symbol->isScopedEnum(); + }); + + if (ast->classKey == TokenKind::T_CLASS || + ast->classKey == TokenKind::T_STRUCT || + ast->classKey == TokenKind::T_UNION) { + for (auto symbol : SymbolChainView(symbol) | class_symbols_view) { + if (symbol->name() != nameId->identifier) continue; + ast->symbol = symbol; + specs.type = symbol->type(); + break; + } + } else if (ast->classKey == TokenKind::T_ENUM) { + for (auto symbol : SymbolChainView(symbol) | enum_symbols_view) { + if (symbol->name() != nameId->identifier) continue; + ast->symbol = symbol; + specs.type = symbol->type(); + break; + } + } + } else { + parse_error(loc, "expected a name"); + } return true; } diff --git a/src/parser/cxx/symbols.h b/src/parser/cxx/symbols.h index 3608fcff..69afd7f0 100644 --- a/src/parser/cxx/symbols.h +++ b/src/parser/cxx/symbols.h @@ -150,6 +150,46 @@ class Symbol { SourceLocation location_; }; +class SymbolChainIterator { + public: + using difference_type = std::ptrdiff_t; + using value_type = Symbol*; + + SymbolChainIterator() = default; + explicit SymbolChainIterator(Symbol* symbol) : symbol_(symbol) {} + + auto operator*() const -> Symbol* { return symbol_; } + + auto operator++() -> SymbolChainIterator& { + symbol_ = symbol_->next(); + return *this; + } + + auto operator++(int) -> SymbolChainIterator { + auto it = *this; + ++*this; + return it; + } + + auto operator==(const SymbolChainIterator&) const -> bool = default; + + private: + Symbol* symbol_ = nullptr; +}; + +static_assert(std::forward_iterator); + +class SymbolChainView : public std::ranges::view_interface { + public: + explicit SymbolChainView(Symbol* symbol) : begin_{symbol} {} + + auto begin() const -> SymbolChainIterator { return begin_; } + auto end() const -> SymbolChainIterator { return SymbolChainIterator(); } + + private: + SymbolChainIterator begin_; +}; + class ScopedSymbol : public Symbol { public: ScopedSymbol(SymbolKind kind, Scope* enclosingScope);