Skip to content

Commit e36b95d

Browse files
committed
RequirementMachine: Split up list of associated types in ProtocolGraph
Store the protocol's direct associated types separately from the inherited associated types, since in a couple of places we only need the direct associated types. Also, factor out a new ProtocolGraph::compute() method that does all the steps in the right order.
1 parent 00eca7e commit e36b95d

File tree

5 files changed

+45
-45
lines changed

5 files changed

+45
-45
lines changed

lib/AST/RequirementMachine/EquivalenceClassMap.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -572,19 +572,7 @@ void EquivalenceClassMap::concretizeNestedTypesFromConcreteParent(
572572

573573
auto *concrete = conformance.getConcrete();
574574

575-
// We might have duplicates in the list due to diamond inheritance.
576-
// FIXME: Filter those out further upstream?
577-
// FIXME: This should actually be outside of the loop over the conforming protos...
578-
llvm::SmallDenseSet<AssociatedTypeDecl *, 4> visited;
579575
for (auto *assocType : assocTypes) {
580-
if (!visited.insert(assocType).second)
581-
continue;
582-
583-
// Get the actual protocol in case we inherited this associated type.
584-
auto *actualProto = assocType->getProtocol();
585-
if (actualProto != proto)
586-
continue;
587-
588576
if (DebugConcretizeNestedTypes) {
589577
llvm::dbgs() << "^^ " << "Looking up type witness for "
590578
<< proto->getName() << ":" << assocType->getName()

lib/AST/RequirementMachine/ProtocolGraph.cpp

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -109,24 +109,9 @@ void ProtocolGraph::computeInheritedAssociatedTypes() {
109109
for (const auto *proto : llvm::reverse(Protocols)) {
110110
auto &info = Info[proto];
111111

112-
// We might inherit the same associated type multiple times due to
113-
// diamond inheritance, so make sure we only add each associated
114-
// type once.
115-
llvm::SmallDenseSet<const AssociatedTypeDecl *, 4> visited;
116-
117-
for (const auto *inherited : info.Inherited) {
118-
if (inherited == proto)
119-
continue;
120-
112+
for (const auto *inherited : info.AllInherited) {
121113
for (auto *inheritedType : getProtocolInfo(inherited).AssociatedTypes) {
122-
if (!visited.insert(inheritedType).second)
123-
continue;
124-
125-
// The 'if (inherited == proto)' above avoids a potential
126-
// iterator invalidation here, because we're updating
127-
// getProtocolInfo(proto).AssociatedTypes while iterating over
128-
// getProtocolInfo(inherited).AssociatedTypes.
129-
info.AssociatedTypes.push_back(inheritedType);
114+
info.InheritedAssociatedTypes.push_back(inheritedType);
130115
}
131116
}
132117
}
@@ -190,6 +175,14 @@ unsigned ProtocolGraph::computeProtocolDepth(const ProtocolDecl *proto) {
190175
return depth;
191176
}
192177

178+
/// Compute everything in the right order.
179+
void ProtocolGraph::compute() {
180+
computeTransitiveClosure();
181+
computeLinearOrder();
182+
computeInheritedProtocols();
183+
computeInheritedAssociatedTypes();
184+
}
185+
193186
/// Defines a linear order with the property that if a protocol P inherits
194187
/// from another protocol Q, then P < Q. (The converse cannot be true, since
195188
/// this is a linear order.)

lib/AST/RequirementMachine/ProtocolGraph.h

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,14 @@ struct ProtocolInfo {
3434
/// itself. Computed by ProtocolGraph::computeInheritedProtocols().
3535
llvm::TinyPtrVector<const ProtocolDecl *> AllInherited;
3636

37-
/// Transitive closure of inherited associated types together with all
38-
/// associated types from the protocol itself. Computed by
39-
/// ProtocolGraph::computeInheritedAssociatedTypes().
37+
/// Associated types defined in the protocol itself.
4038
llvm::TinyPtrVector<AssociatedTypeDecl *> AssociatedTypes;
4139

40+
/// Associated types from all inherited protocols, not including duplicates or
41+
/// those defined in the protocol itself. Computed by
42+
/// ProtocolGraph::computeInheritedAssociatedTypes().
43+
llvm::TinyPtrVector<AssociatedTypeDecl *> InheritedAssociatedTypes;
44+
4245
/// The protocol's requirement signature.
4346
ArrayRef<Requirement> Requirements;
4447

@@ -76,28 +79,36 @@ struct ProtocolInfo {
7679
/// referenced from a set of generic requirements.
7780
///
7881
/// Out-of-line methods are documented in ProtocolGraph.cpp.
79-
struct ProtocolGraph {
82+
class ProtocolGraph {
8083
llvm::DenseMap<const ProtocolDecl *, ProtocolInfo> Info;
8184
std::vector<const ProtocolDecl *> Protocols;
8285
bool Debug = false;
8386

87+
public:
8488
void visitRequirements(ArrayRef<Requirement> reqs);
8589

8690
bool isKnownProtocol(const ProtocolDecl *proto) const;
8791

92+
/// Returns the sorted list of protocols, with the property
93+
/// that (P refines Q) => P < Q. See compareProtocols()
94+
/// for details.
95+
ArrayRef<const ProtocolDecl *> getProtocols() const {
96+
return Protocols;
97+
}
98+
8899
const ProtocolInfo &getProtocolInfo(
89100
const ProtocolDecl *proto) const;
90101

102+
private:
91103
void addProtocol(const ProtocolDecl *proto);
92-
93104
void computeTransitiveClosure();
94-
95105
void computeLinearOrder();
96-
97106
void computeInheritedAssociatedTypes();
98-
99107
void computeInheritedProtocols();
100108

109+
public:
110+
void compute();
111+
101112
int compareProtocols(const ProtocolDecl *lhs,
102113
const ProtocolDecl *rhs) const;
103114

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,21 +86,21 @@ void RewriteSystemBuilder::addGenericSignature(CanGenericSignature sig) {
8686
// Collect all protocols transitively referenced from the generic signature's
8787
// requirements.
8888
Protocols.visitRequirements(sig->getRequirements());
89-
Protocols.computeTransitiveClosure();
90-
Protocols.computeLinearOrder();
91-
Protocols.computeInheritedProtocols();
92-
Protocols.computeInheritedAssociatedTypes();
89+
Protocols.compute();
9390

9491
// Add rewrite rules for each protocol.
95-
for (auto *proto : Protocols.Protocols) {
92+
for (auto *proto : Protocols.getProtocols()) {
9693
if (Debug) {
9794
llvm::dbgs() << "protocol " << proto->getName() << " {\n";
9895
}
9996

10097
const auto &info = Protocols.getProtocolInfo(proto);
10198

102-
for (auto *type : info.AssociatedTypes)
103-
addAssociatedType(type, proto);
99+
for (auto *assocType : info.AssociatedTypes)
100+
addAssociatedType(assocType, proto);
101+
102+
for (auto *assocType : info.InheritedAssociatedTypes)
103+
addAssociatedType(assocType, proto);
104104

105105
for (auto req : info.Requirements)
106106
addRequirement(req.getCanonical(), proto);

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1051,7 +1051,7 @@ Type getTypeForAtomRange(Iter begin, Iter end, Type root,
10511051
//
10521052
for (auto *proto : atom.getProtocols()) {
10531053
const auto &info = protos.getProtocolInfo(proto);
1054-
for (auto *otherAssocType : info.AssociatedTypes) {
1054+
auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
10551055
otherAssocType = otherAssocType->getAssociatedTypeAnchor();
10561056

10571057
if (otherAssocType->getName() == name &&
@@ -1060,6 +1060,14 @@ Type getTypeForAtomRange(Iter begin, Iter end, Type root,
10601060
assocType->getProtocol()) < 0)) {
10611061
assocType = otherAssocType;
10621062
}
1063+
};
1064+
1065+
for (auto *otherAssocType : info.AssociatedTypes) {
1066+
checkOtherAssocType(otherAssocType);
1067+
}
1068+
1069+
for (auto *otherAssocType : info.InheritedAssociatedTypes) {
1070+
checkOtherAssocType(otherAssocType);
10631071
}
10641072
}
10651073
}

0 commit comments

Comments
 (0)