diff --git a/src/parser/cxx/name_lookup.cc b/src/parser/cxx/name_lookup.cc index ccb89565..a5eb79a2 100644 --- a/src/parser/cxx/name_lookup.cc +++ b/src/parser/cxx/name_lookup.cc @@ -53,15 +53,40 @@ auto Lookup::qualifiedLookup(Symbol* scopedSymbol, const Name* name) const case SymbolKind::kNamespace: return qualifiedLookup( symbol_cast(scopedSymbol)->scope(), name); + case SymbolKind::kClass: return qualifiedLookup(symbol_cast(scopedSymbol)->scope(), name); + case SymbolKind::kEnum: return qualifiedLookup(symbol_cast(scopedSymbol)->scope(), name); + case SymbolKind::kScopedEnum: return qualifiedLookup( symbol_cast(scopedSymbol)->scope(), name); + + case SymbolKind::kTypeAlias: { + auto alias = symbol_cast(scopedSymbol); + + if (auto classType = type_cast(alias->type())) { + auto classSymbol = classType->symbol(); + return qualifiedLookup(classSymbol->scope(), name); + } + + if (auto enumType = type_cast(alias->type())) { + auto enumSymbol = enumType->symbol(); + return qualifiedLookup(enumSymbol->scope(), name); + } + + if (auto scopedEnumType = type_cast(alias->type())) { + auto scopedEnumSymbol = scopedEnumType->symbol(); + return qualifiedLookup(scopedEnumSymbol->scope(), name); + } + + return nullptr; + } + default: return nullptr; } // switch diff --git a/src/parser/cxx/parser.cc b/src/parser/cxx/parser.cc index 7f8c5da6..362b6021 100644 --- a/src/parser/cxx/parser.cc +++ b/src/parser/cxx/parser.cc @@ -5272,29 +5272,33 @@ auto Parser::parse_static_assert_declaration(DeclarationAST*& yyast) -> bool { expect(TokenKind::T_SEMICOLON, ast->semicolonLoc); - bool value = false; + if (!inTemplate_) { + // not in a template context - if (constValue.has_value()) { - value = visit(to_bool, *constValue); - } + bool value = false; - if (!value && config_.staticAssert) { - SourceLocation loc = ast->firstSourceLocation(); + if (constValue.has_value()) { + value = visit(to_bool, *constValue); + } - if (!ast->expression || !constValue.has_value()) { - parse_error( - loc, - "static assertion expression is not an integral constant expression"); - } else { - if (ast->literalLoc) - loc = ast->literalLoc; - else if (ast->expression) - ast->expression->firstSourceLocation(); + if (!value && config_.staticAssert) { + SourceLocation loc = ast->firstSourceLocation(); + + if (!ast->expression || !constValue.has_value()) { + parse_error(loc, + "static assertion expression is not an integral constant " + "expression"); + } else { + if (ast->literalLoc) + loc = ast->literalLoc; + else if (ast->expression) + ast->expression->firstSourceLocation(); - std::string message = - ast->literal ? ast->literal->value() : "static assert failed"; + std::string message = + ast->literal ? ast->literal->value() : "static assert failed"; - unit->error(loc, std::move(message)); + unit->error(loc, std::move(message)); + } } } @@ -9431,66 +9435,11 @@ auto Parser::parse_class_specifier( expect(TokenKind::T_RBRACE, ast->rbraceLoc); } - if (!is_template(classSymbol)) { - int offset = 0; - int alignment = 1; - - for (auto base : classSymbol->baseClasses()) { - auto baseClassSymbol = symbol_cast(base->symbol()); - - if (!baseClassSymbol) { - if (config_.checkTypes) { - parse_error(base->location(), std::format("base class '{}' not found", - to_string(base->name()))); - } - continue; - } - - offset = align_to(offset, baseClassSymbol->alignment()); - offset += baseClassSymbol->sizeInBytes(); - alignment = std::max(alignment, baseClassSymbol->alignment()); - } - - for (auto member : classSymbol->scope()->symbols()) { - auto field = symbol_cast(member); - if (!field) continue; - if (field->isStatic()) continue; - - if (!field->alignment()) { - if (config_.checkTypes) { - parse_error(field->location(), - std::format("alignment of incomplete type '{}'", - to_string(field->type(), field->name()))); - } - continue; - } - - auto size = control_->memoryLayout()->sizeOf(field->type()); - - if (!size.has_value()) { - if (config_.checkTypes) { - parse_error(field->location(), - std::format("size of incomplete type '{}'", - to_string(field->type(), field->name()))); - } - continue; - } - - if (classSymbol->isUnion()) { - offset = std::max(offset, int(size.value())); - } else { - offset = align_to(offset, field->alignment()); - field->setOffset(offset); - offset += size.value(); - } - - alignment = std::max(alignment, field->alignment()); + if (!inTemplate_) { + auto status = classSymbol->buildClassLayout(control_); + if (!status.has_value() && config_.checkTypes) { + parse_error(classSymbol->location(), status.error()); } - - offset = align_to(offset, alignment); - - classSymbol->setAlignment(alignment); - classSymbol->setSizeInBytes(offset); } classSymbol->setComplete(true); @@ -11506,7 +11455,18 @@ void Parser::completePendingFunctionDefinitions() { } } -void Parser::setScope(Scope* scope) { scope_ = scope; } +void Parser::setScope(Scope* scope) { + scope_ = scope; + + inTemplate_ = false; + + for (auto current = scope_; current; current = current->parent()) { + if (current->isTemplateParametersScope()) { + inTemplate_ = true; + break; + } + } +} void Parser::setScope(ScopedSymbol* symbol) { setScope(symbol->scope()); } diff --git a/src/parser/cxx/parser.h b/src/parser/cxx/parser.h index c9724fbe..ad798d1a 100644 --- a/src/parser/cxx/parser.h +++ b/src/parser/cxx/parser.h @@ -911,6 +911,7 @@ class Parser final { int templateParameterCount_ = 0; bool didAcceptCompletionToken_ = false; std::vector pendingFunctionDefinitions_; + bool inTemplate_ = false; template class CachedAST { diff --git a/src/parser/cxx/symbol_instantiation.cc b/src/parser/cxx/symbol_instantiation.cc index 18424343..d56c7998 100644 --- a/src/parser/cxx/symbol_instantiation.cc +++ b/src/parser/cxx/symbol_instantiation.cc @@ -23,9 +23,12 @@ // cxx #include #include +#include +#include #include #include #include +#include #include #include @@ -185,9 +188,19 @@ auto SymbolInstantiation::findOrCreateReplacement(Symbol* symbol) -> Symbol* { replacements_[symbol] = newSymbol; auto enclosingSymbol = replacement(symbol->enclosingSymbol()); + newSymbol->setEnclosingScope(enclosingSymbol->scope()); + if (symbol->type()) { newSymbol->setType(visit(VisitType{*this}, symbol->type())); + + auto field = symbol_cast(newSymbol); + + if (field) { + auto memoryLayout = control()->memoryLayout(); + auto alignment = memoryLayout->alignmentOf(newSymbol->type()); + field->setAlignment(alignment.value_or(0)); + } } newSymbol->setName(symbol->name()); @@ -217,6 +230,10 @@ auto SymbolInstantiation::VisitSymbol::operator()(ClassSymbol* symbol) auto newSymbol = self.replacement(symbol); newSymbol->setFlags(symbol->flags()); + if (symbol->isComplete()) { + newSymbol->setComplete(true); + } + if (symbol != self.current_) { newSymbol->setTemplateParameters( self.instantiate(symbol->templateParameters())); @@ -226,14 +243,19 @@ auto SymbolInstantiation::VisitSymbol::operator()(ClassSymbol* symbol) auto newBaseClass = self.instantiate(baseClass); newSymbol->addBaseClass(newBaseClass); } + for (auto ctor : symbol->constructors()) { auto newCtor = self.instantiate(ctor); newSymbol->addConstructor(newCtor); } + for (auto member : views::members(symbol)) { auto newMember = self.instantiate(member); newSymbol->addMember(newMember); } + + auto status = newSymbol->buildClassLayout(self.control()); + return newSymbol; } diff --git a/src/parser/cxx/symbols.cc b/src/parser/cxx/symbols.cc index 84fdc67c..0042df39 100644 --- a/src/parser/cxx/symbols.cc +++ b/src/parser/cxx/symbols.cc @@ -22,8 +22,15 @@ // cxx #include +#include +#include +#include #include +#include #include +#include + +#include namespace cxx { @@ -184,7 +191,7 @@ void ClassSymbol::setSizeInBytes(int sizeInBytes) { sizeInBytes_ = sizeInBytes; } -auto ClassSymbol::alignment() const -> int { return alignment_; } +auto ClassSymbol::alignment() const -> int { return std::max(alignment_, 1); } void ClassSymbol::setAlignment(int alignment) { alignment_ = alignment; } @@ -241,6 +248,64 @@ void ClassSymbol::addSpecialization(std::vector arguments, templateInfo_->addSpecialization(std::move(arguments), specialization); } +auto ClassSymbol::buildClassLayout(Control* control) + -> std::expected { + int offset = 0; + int alignment = 1; + + auto memoryLayout = control->memoryLayout(); + + for (auto base : baseClasses()) { + auto baseClassSymbol = symbol_cast(base->symbol()); + + if (!baseClassSymbol) { + return std::unexpected( + std::format("base class '{}' not found", to_string(base->name()))); + } + + offset = align_to(offset, baseClassSymbol->alignment()); + offset += baseClassSymbol->sizeInBytes(); + alignment = std::max(alignment, baseClassSymbol->alignment()); + } + + for (auto member : scope()->symbols()) { + auto field = symbol_cast(member); + if (!field) continue; + if (field->isStatic()) continue; + + if (!field->alignment()) { + return std::unexpected( + std::format("alignment of incomplete type '{}'", + to_string(field->type(), field->name()))); + } + + auto size = memoryLayout->sizeOf(field->type()); + + if (!size.has_value()) { + return std::unexpected( + std::format("size of incomplete type '{}'", + to_string(field->type(), field->name()))); + } + + if (isUnion()) { + offset = std::max(offset, int(size.value())); + } else { + offset = align_to(offset, field->alignment()); + field->setOffset(offset); + offset += size.value(); + } + + alignment = std::max(alignment, field->alignment()); + } + + offset = align_to(offset, alignment); + + setAlignment(alignment); + setSizeInBytes(offset); + + return true; +} + EnumSymbol::EnumSymbol(Scope* enclosingScope) : ScopedSymbol(Kind, enclosingScope) {} diff --git a/src/parser/cxx/symbols.h b/src/parser/cxx/symbols.h index bb6c4cb2..cdb127e7 100644 --- a/src/parser/cxx/symbols.h +++ b/src/parser/cxx/symbols.h @@ -27,6 +27,7 @@ #include #include +#include #include #include #include @@ -298,6 +299,9 @@ class ClassSymbol final : public ScopedSymbol { return templateSepcializationIndex_; } + [[nodiscard]] auto buildClassLayout(Control* control) + -> std::expected; + private: [[nodiscard]] auto hasBaseClass(Symbol* symbol, std::unordered_set&) const diff --git a/tests/unit_tests/sema/template_class_04.cc b/tests/unit_tests/sema/template_class_04.cc new file mode 100644 index 00000000..224868a9 --- /dev/null +++ b/tests/unit_tests/sema/template_class_04.cc @@ -0,0 +1,28 @@ +// RUN: %cxx -verify -ftemplates -fcheck %s + +template +struct X { + Key key; + Value value; + + void check() { static_assert(sizeof(*this)); } +}; + +auto main() -> int { + using U = X; + U u; + + static_assert(sizeof(U) == 8); + static_assert(sizeof(u) == 8); + + static_assert(__builtin_offsetof(U, key) == 0); + static_assert(__builtin_offsetof(U, value) == 4); + + static_assert(sizeof(X::key) == 1); + static_assert(sizeof(X::value) == 4); + + static_assert(sizeof(U::key) == 1); + static_assert(sizeof(U::value) == 4); + + return 0; +} \ No newline at end of file