Skip to content

Commit 5eebfc0

Browse files
authored
Merge pull request #68338 from tshortli/inherited-types-refactor
AST: Refactor representation of inherited types
2 parents bdc6a27 + 01ecd81 commit 5eebfc0

35 files changed

+260
-207
lines changed

include/swift/AST/Decl.h

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "swift/AST/RequirementSignature.h"
3636
#include "swift/AST/StorageImpl.h"
3737
#include "swift/AST/TypeAlignments.h"
38+
#include "swift/AST/TypeResolutionStage.h"
3839
#include "swift/AST/TypeWalker.h"
3940
#include "swift/AST/Types.h"
4041
#include "swift/Basic/ArrayRefView.h"
@@ -1547,6 +1548,53 @@ struct InheritedEntry : public TypeLoc {
15471548
: TypeLoc(typeLoc), isUnchecked(isUnchecked) { }
15481549
};
15491550

1551+
/// A wrapper for the collection of inherited types for either a `TypeDecl` or
1552+
/// an `ExtensionDecl`.
1553+
class InheritedTypes {
1554+
llvm::PointerUnion<const TypeDecl *, const ExtensionDecl *> Decl;
1555+
ArrayRef<InheritedEntry> Entries;
1556+
1557+
public:
1558+
InheritedTypes(
1559+
llvm::PointerUnion<const TypeDecl *, const ExtensionDecl *> decl);
1560+
InheritedTypes(const class Decl *decl);
1561+
InheritedTypes(const TypeDecl *typeDecl);
1562+
InheritedTypes(const ExtensionDecl *extensionDecl);
1563+
1564+
bool empty() const { return Entries.empty(); }
1565+
size_t size() const { return Entries.size(); }
1566+
IntRange<size_t> const getIndices() { return indices(Entries); }
1567+
1568+
/// Returns the `TypeRepr *` for the entry of the inheritance clause at the
1569+
/// given index.
1570+
TypeRepr *getTypeRepr(unsigned i) const { return Entries[i].getTypeRepr(); }
1571+
1572+
/// Returns the `Type` for the entry of the inheritance clause at the given
1573+
/// index, resolved at the given stage, or `Type()` if resolution fails.
1574+
Type getResolvedType(unsigned i, TypeResolutionStage stage =
1575+
TypeResolutionStage::Interface) const;
1576+
1577+
/// Returns the underlying array of inherited type entries.
1578+
///
1579+
/// NOTE: The `Type` associated with an entry may not be resolved yet.
1580+
ArrayRef<InheritedEntry> getEntries() const { return Entries; }
1581+
1582+
/// Returns the entry of the inheritance clause at the given index.
1583+
///
1584+
/// NOTE: The `Type` associated with the entry may not be resolved yet.
1585+
const InheritedEntry &getEntry(unsigned i) const { return Entries[i]; }
1586+
1587+
/// Returns the source location of the beginning of the inheritance clause.
1588+
SourceLoc getStartLoc() const {
1589+
return getEntries().front().getSourceRange().Start;
1590+
}
1591+
1592+
/// Returns the source location of the end of the inheritance clause.
1593+
SourceLoc getEndLoc() const {
1594+
return getEntries().back().getSourceRange().End;
1595+
}
1596+
};
1597+
15501598
/// ExtensionDecl - This represents a type extension containing methods
15511599
/// associated with the type. This is not a ValueDecl and has no Type because
15521600
/// there are no runtime values of the Extension's type.
@@ -1582,6 +1630,7 @@ class ExtensionDecl final : public GenericContext, public Decl,
15821630
friend class MemberLookupTable;
15831631
friend class ConformanceLookupTable;
15841632
friend class IterableDeclContext;
1633+
friend class InheritedTypes;
15851634

15861635
ExtensionDecl(SourceLoc extensionLoc, TypeRepr *extendedType,
15871636
ArrayRef<InheritedEntry> inherited,
@@ -1662,7 +1711,7 @@ class ExtensionDecl final : public GenericContext, public Decl,
16621711

16631712
/// Retrieve the set of protocols that this type inherits (i.e,
16641713
/// explicitly conforms to).
1665-
ArrayRef<InheritedEntry> getInherited() const { return Inherited; }
1714+
InheritedTypes getInherited() const { return InheritedTypes(this); }
16661715

16671716
void setInherited(ArrayRef<InheritedEntry> i) { Inherited = i; }
16681717

@@ -2994,6 +3043,8 @@ class TypeDecl : public ValueDecl {
29943043
ArrayRef<InheritedEntry> inherited) :
29953044
ValueDecl(K, context, name, NameLoc), Inherited(inherited) {}
29963045

3046+
friend class InheritedTypes;
3047+
29973048
public:
29983049
Identifier getName() const { return getBaseIdentifier(); }
29993050

@@ -3009,7 +3060,7 @@ class TypeDecl : public ValueDecl {
30093060

30103061
/// Retrieve the set of protocols that this type inherits (i.e,
30113062
/// explicitly conforms to).
3012-
ArrayRef<InheritedEntry> getInherited() const { return Inherited; }
3063+
InheritedTypes getInherited() const { return InheritedTypes(this); }
30133064

30143065
void setInherited(ArrayRef<InheritedEntry> i) { Inherited = i; }
30153066

include/swift/AST/NameLookup.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -620,17 +620,6 @@ SelfBounds getSelfBoundsFromWhereClause(
620620
/// given protocol or protocol extension.
621621
SelfBounds getSelfBoundsFromGenericSignature(const ExtensionDecl *extDecl);
622622

623-
/// Retrieve the TypeLoc at the given \c index from among the set of
624-
/// type declarations that are directly "inherited" by the given declaration.
625-
inline const TypeLoc &getInheritedTypeLocAtIndex(
626-
llvm::PointerUnion<const TypeDecl *, const ExtensionDecl *> decl,
627-
unsigned index) {
628-
if (auto typeDecl = decl.dyn_cast<const TypeDecl *>())
629-
return typeDecl->getInherited()[index];
630-
631-
return decl.get<const ExtensionDecl *>()->getInherited()[index];
632-
}
633-
634623
namespace namelookup {
635624

636625
/// Searches through statements and patterns for local variable declarations.

include/swift/AST/TypeCheckRequests.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ class InheritedTypeRequest
9696
llvm::PointerUnion<const TypeDecl *, const ExtensionDecl *> decl,
9797
unsigned index, TypeResolutionStage stage) const;
9898

99+
const TypeLoc &getTypeLoc() const;
100+
99101
public:
100102
// Source location
101103
SourceLoc getNearestLoc() const;

lib/AST/ASTDumper.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -569,11 +569,11 @@ namespace {
569569
PrintWithColorRAII(OS, DeclModifierColor) << " trailing_semi";
570570
}
571571

572-
void printInherited(ArrayRef<InheritedEntry> Inherited) {
572+
void printInherited(InheritedTypes Inherited) {
573573
if (Inherited.empty())
574574
return;
575575
OS << " inherits: ";
576-
interleave(Inherited,
576+
interleave(Inherited.getEntries(),
577577
[&](InheritedEntry Super) { Super.getType().print(OS); },
578578
[&] { OS << ", "; });
579579
}

lib/AST/ASTPrinter.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2954,8 +2954,8 @@ static bool usesFeatureRethrowsProtocol(
29542954
return false;
29552955

29562956
// Check an inheritance clause for a marker protocol.
2957-
auto checkInherited = [&](ArrayRef<InheritedEntry> inherited) -> bool {
2958-
for (const auto &inheritedEntry : inherited) {
2957+
auto checkInherited = [&](InheritedTypes inherited) -> bool {
2958+
for (const auto &inheritedEntry : inherited.getEntries()) {
29592959
if (auto inheritedType = inheritedEntry.getType()) {
29602960
if (inheritedType->isExistentialType()) {
29612961
auto layout = inheritedType->getExistentialLayout();
@@ -7681,15 +7681,10 @@ void
76817681
swift::getInheritedForPrinting(
76827682
const Decl *decl, const PrintOptions &options,
76837683
llvm::SmallVectorImpl<InheritedEntry> &Results) {
7684-
ArrayRef<InheritedEntry> inherited;
7685-
if (auto td = dyn_cast<TypeDecl>(decl)) {
7686-
inherited = td->getInherited();
7687-
} else if (auto ed = dyn_cast<ExtensionDecl>(decl)) {
7688-
inherited = ed->getInherited();
7689-
}
7684+
InheritedTypes inherited = InheritedTypes(decl);
76907685

76917686
// Collect explicit inherited types.
7692-
for (auto entry: inherited) {
7687+
for (auto entry : inherited.getEntries()) {
76937688
if (auto ty = entry.getType()) {
76947689
bool foundUnprintable = ty.findIf([&](Type subTy) {
76957690
if (auto aliasTy = dyn_cast<TypeAliasType>(subTy.getPointer()))
@@ -7831,7 +7826,7 @@ void GenericParamList::print(ASTPrinter &Printer,
78317826
if (!P->getInherited().empty()) {
78327827
Printer << " : ";
78337828

7834-
auto loc = P->getInherited()[0];
7829+
auto loc = P->getInherited().getEntry(0);
78357830
if (willUseTypeReprPrinting(loc, nullptr, PO)) {
78367831
loc.getTypeRepr()->print(Printer, PO);
78377832
} else {

lib/AST/ASTWalker.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
164164
if (auto *typeRepr = ED->getExtendedTypeRepr())
165165
if (doIt(typeRepr))
166166
return true;
167-
for (auto &Inherit : ED->getInherited()) {
168-
if (auto *const TyR = Inherit.getTypeRepr())
167+
auto inheritedTypes = ED->getInherited();
168+
for (auto i : inheritedTypes.getIndices()) {
169+
if (auto *const TyR = inheritedTypes.getTypeRepr(i))
169170
if (doIt(TyR))
170171
return true;
171172
}
@@ -276,7 +277,7 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
276277
}
277278

278279
bool visitGenericTypeParamDecl(GenericTypeParamDecl *GTPD) {
279-
for (const auto &Inherit: GTPD->getInherited()) {
280+
for (const auto &Inherit : GTPD->getInherited().getEntries()) {
280281
if (auto *const TyR = Inherit.getTypeRepr())
281282
if (doIt(TyR))
282283
return true;
@@ -285,7 +286,7 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
285286
}
286287

287288
bool visitAssociatedTypeDecl(AssociatedTypeDecl *ATD) {
288-
for (const auto &Inherit: ATD->getInherited()) {
289+
for (const auto &Inherit : ATD->getInherited().getEntries()) {
289290
if (auto *const TyR = Inherit.getTypeRepr())
290291
if (doIt(TyR))
291292
return true;
@@ -310,9 +311,10 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
310311

311312
bool WalkGenerics = visitGenericParamListIfNeeded(NTD);
312313

313-
for (const auto &Inherit : NTD->getInherited()) {
314-
if (auto *const TyR = Inherit.getTypeRepr())
315-
if (doIt(Inherit.getTypeRepr()))
314+
auto inheritedTypes = NTD->getInherited();
315+
for (auto i : inheritedTypes.getIndices()) {
316+
if (auto *const TyR = inheritedTypes.getTypeRepr(i))
317+
if (doIt(TyR))
316318
return true;
317319
}
318320

lib/AST/ConformanceLookupTable.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ void ConformanceLookupTable::inheritConformances(ClassDecl *classDecl,
226226
if (superclassLoc.isValid())
227227
return superclassLoc;
228228

229-
for (const auto &inherited : classDecl->getInherited()) {
229+
for (const auto &inherited : classDecl->getInherited().getEntries()) {
230230
if (auto inheritedType = inherited.getType()) {
231231
if (inheritedType->getClassOrBoundGenericClass()) {
232232
superclassLoc = inherited.getSourceRange().Start;

lib/AST/Decl.cpp

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1527,6 +1527,47 @@ InheritedEntry::InheritedEntry(const TypeLoc &typeLoc)
15271527
isUnchecked = typeRepr->findUncheckedAttrLoc().isValid();
15281528
}
15291529

1530+
InheritedTypes::InheritedTypes(
1531+
llvm::PointerUnion<const TypeDecl *, const ExtensionDecl *> decl)
1532+
: Decl(decl) {
1533+
if (auto *typeDecl = decl.dyn_cast<const TypeDecl *>()) {
1534+
Entries = typeDecl->Inherited;
1535+
} else {
1536+
Entries = decl.get<const ExtensionDecl *>()->Inherited;
1537+
}
1538+
}
1539+
1540+
InheritedTypes::InheritedTypes(const class Decl *decl) {
1541+
if (auto typeDecl = dyn_cast<TypeDecl>(decl)) {
1542+
Decl = typeDecl;
1543+
Entries = typeDecl->Inherited;
1544+
} else if (auto extensionDecl = dyn_cast<ExtensionDecl>(decl)) {
1545+
Decl = extensionDecl;
1546+
Entries = extensionDecl->Inherited;
1547+
} else {
1548+
Decl = nullptr;
1549+
Entries = ArrayRef<InheritedEntry>();
1550+
}
1551+
}
1552+
1553+
InheritedTypes::InheritedTypes(const TypeDecl *typeDecl) : Decl(typeDecl) {
1554+
Entries = typeDecl->Inherited;
1555+
}
1556+
1557+
InheritedTypes::InheritedTypes(const ExtensionDecl *extensionDecl)
1558+
: Decl(extensionDecl) {
1559+
Entries = extensionDecl->Inherited;
1560+
}
1561+
1562+
Type InheritedTypes::getResolvedType(unsigned i,
1563+
TypeResolutionStage stage) const {
1564+
ASTContext &ctx = Decl.is<const ExtensionDecl *>()
1565+
? Decl.get<const ExtensionDecl *>()->getASTContext()
1566+
: Decl.get<const TypeDecl *>()->getASTContext();
1567+
return evaluateOrDefault(ctx.evaluator, InheritedTypeRequest{Decl, i, stage},
1568+
Type());
1569+
}
1570+
15301571
ExtensionDecl::ExtensionDecl(SourceLoc extensionLoc,
15311572
TypeRepr *extendedType,
15321573
ArrayRef<InheritedEntry> inherited,
@@ -5182,7 +5223,7 @@ SourceRange GenericTypeParamDecl::getSourceRange() const {
51825223
startLoc = eachLoc;
51835224

51845225
if (!getInherited().empty())
5185-
endLoc = getInherited().back().getSourceRange().End;
5226+
endLoc = getInherited().getEndLoc();
51865227

51875228
return {startLoc, endLoc};
51885229
}
@@ -5220,7 +5261,7 @@ SourceRange AssociatedTypeDecl::getSourceRange() const {
52205261
} else if (auto defaultDefinition = getDefaultDefinitionTypeRepr()) {
52215262
endLoc = defaultDefinition->getEndLoc();
52225263
} else if (!getInherited().empty()) {
5223-
endLoc = getInherited().back().getSourceRange().End;
5264+
endLoc = getInherited().getEndLoc();
52245265
} else {
52255266
endLoc = getNameLoc();
52265267
}

lib/AST/NameLookup.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3045,7 +3045,7 @@ DirectlyReferencedTypeDecls InheritedDeclsReferencedRequest::evaluate(
30453045
unsigned index) const {
30463046

30473047
// Prefer syntactic information when we have it.
3048-
const TypeLoc &typeLoc = getInheritedTypeLocAtIndex(decl, index);
3048+
const TypeLoc &typeLoc = InheritedTypes(decl).getEntry(index);
30493049
if (auto typeRepr = typeLoc.getTypeRepr()) {
30503050
// Figure out the context in which name lookup will occur.
30513051
DeclContext *dc;
@@ -3112,7 +3112,7 @@ SuperclassDeclRequest::evaluate(Evaluator &evaluator,
31123112
return classDecl;
31133113
}
31143114

3115-
for (unsigned i : indices(subject->getInherited())) {
3115+
for (unsigned i : subject->getInherited().getIndices()) {
31163116
// Find the inherited declarations referenced at this position.
31173117
auto inheritedTypes = evaluateOrDefault(evaluator,
31183118
InheritedDeclsReferencedRequest{subject, i}, {});
@@ -3592,8 +3592,8 @@ void swift::getDirectlyInheritedNominalTypeDecls(
35923592
// InheritedDeclsReferencedRequest to make this work.
35933593
SourceLoc loc;
35943594
SourceLoc uncheckedLoc;
3595-
if (TypeRepr *typeRepr = typeDecl ? typeDecl->getInherited()[i].getTypeRepr()
3596-
: extDecl->getInherited()[i].getTypeRepr()){
3595+
auto inheritedTypes = InheritedTypes(decl);
3596+
if (TypeRepr *typeRepr = inheritedTypes.getTypeRepr(i)) {
35973597
loc = typeRepr->getLoc();
35983598
uncheckedLoc = typeRepr->findUncheckedAttrLoc();
35993599
}
@@ -3608,17 +3608,15 @@ SmallVector<InheritedNominalEntry, 4>
36083608
swift::getDirectlyInheritedNominalTypeDecls(
36093609
llvm::PointerUnion<const TypeDecl *, const ExtensionDecl *> decl,
36103610
bool &anyObject) {
3611-
auto typeDecl = decl.dyn_cast<const TypeDecl *>();
3612-
auto extDecl = decl.dyn_cast<const ExtensionDecl *>();
3611+
auto inheritedTypes = InheritedTypes(decl);
36133612

36143613
// Gather results from all of the inherited types.
3615-
unsigned numInherited = typeDecl ? typeDecl->getInherited().size()
3616-
: extDecl->getInherited().size();
36173614
SmallVector<InheritedNominalEntry, 4> result;
3618-
for (unsigned i : range(numInherited)) {
3615+
for (unsigned i : inheritedTypes.getIndices()) {
36193616
getDirectlyInheritedNominalTypeDecls(decl, i, result, anyObject);
36203617
}
36213618

3619+
auto *typeDecl = decl.dyn_cast<const TypeDecl *>();
36223620
auto *protoDecl = dyn_cast_or_null<ProtocolDecl>(typeDecl);
36233621
if (protoDecl == nullptr)
36243622
return result;

lib/AST/NameLookupRequests.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ namespace swift {
4040

4141
SourceLoc InheritedDeclsReferencedRequest::getNearestLoc() const {
4242
const auto &storage = getStorage();
43-
auto &typeLoc = getInheritedTypeLocAtIndex(std::get<0>(storage),
44-
std::get<1>(storage));
45-
return typeLoc.getLoc();
43+
auto inheritedTypes = InheritedTypes(std::get<0>(storage));
44+
return inheritedTypes.getEntry(std::get<1>(storage)).getLoc();
4645
}
4746

4847
//----------------------------------------------------------------------------//

0 commit comments

Comments
 (0)