@@ -30,6 +30,34 @@ struct ProtocolInfo {
30
30
ArrayRef<ProtocolDecl *> Inherited;
31
31
llvm::TinyPtrVector<AssociatedTypeDecl *> AssociatedTypes;
32
32
ArrayRef<Requirement> Requirements;
33
+
34
+ // Used by computeDepth() to detect circularity.
35
+ unsigned Mark : 1 ;
36
+
37
+ // Longest chain of protocol refinements, including this one.
38
+ // Greater than zero on valid code, might be zero if there's
39
+ // a cycle.
40
+ unsigned Depth : 31 ;
41
+
42
+ // Index of the protocol in the linear order.
43
+ unsigned Index : 32 ;
44
+
45
+ ProtocolInfo () {
46
+ Mark = 0 ;
47
+ Depth = 0 ;
48
+ Index = 0 ;
49
+ }
50
+
51
+ ProtocolInfo (ArrayRef<ProtocolDecl *> inherited,
52
+ llvm::TinyPtrVector<AssociatedTypeDecl *> &&types,
53
+ ArrayRef<Requirement> reqs)
54
+ : Inherited(inherited),
55
+ AssociatedTypes (types),
56
+ Requirements(reqs) {
57
+ Mark = 0 ;
58
+ Depth = 0 ;
59
+ Index = 0 ;
60
+ }
33
61
};
34
62
35
63
struct ProtocolGraph {
@@ -61,6 +89,79 @@ struct ProtocolGraph {
61
89
visitRequirements (Info[proto].Requirements );
62
90
}
63
91
}
92
+
93
+ void computeLinearOrder () {
94
+ for (const auto *proto : Protocols) {
95
+ (void ) computeProtocolDepth (proto);
96
+ }
97
+
98
+ std::sort (
99
+ Protocols.begin (), Protocols.end (),
100
+ [&](const ProtocolDecl *lhs,
101
+ const ProtocolDecl *rhs) -> int {
102
+ const auto &lhsInfo = Info[lhs];
103
+ const auto &rhsInfo = Info[rhs];
104
+
105
+ // protocol Base {} // depth 1
106
+ // protocol Derived : Base {} // depth 2
107
+ //
108
+ // Derived < Base in the linear order.
109
+ if (lhsInfo.Depth != rhsInfo.Depth )
110
+ return lhsInfo.Depth - rhsInfo.Depth ;
111
+
112
+ return TypeDecl::compare (lhs, rhs);
113
+ });
114
+
115
+ for (unsigned i : indices (Protocols)) {
116
+ Info[Protocols[i]].Index = i;
117
+ }
118
+ }
119
+
120
+ void computeInheritedAssociatedTypes () {
121
+ for (const auto *proto : Protocols) {
122
+ auto &info = Info[proto];
123
+
124
+ llvm::SmallDenseSet<const AssociatedTypeDecl *, 4 > visited;
125
+ for (const auto *inherited : info.Inherited ) {
126
+ if (inherited == proto)
127
+ continue ;
128
+
129
+ for (auto *inheritedType : Info[inherited].AssociatedTypes ) {
130
+ if (!visited.insert (inheritedType).second )
131
+ continue ;
132
+
133
+ // The 'if (inherited == proto)' above avoids a potential
134
+ // iterator invalidation here.
135
+ info.AssociatedTypes .push_back (inheritedType);
136
+ }
137
+ }
138
+ }
139
+ }
140
+
141
+ private:
142
+ unsigned computeProtocolDepth (const ProtocolDecl *proto) {
143
+ auto &info = Info[proto];
144
+
145
+ if (info.Mark ) {
146
+ // Already computed, or we have a cycle. Cycles are diagnosed
147
+ // elsewhere in the type checker, so we don't have to do
148
+ // anything here.
149
+ return info.Depth ;
150
+ }
151
+
152
+ info.Mark = true ;
153
+ unsigned depth = 0 ;
154
+
155
+ for (auto *inherited : info.Inherited ) {
156
+ unsigned inheritedDepth = computeProtocolDepth (inherited);
157
+ depth = std::max (inheritedDepth, depth);
158
+ }
159
+
160
+ depth++;
161
+
162
+ info.Depth = depth;
163
+ return depth;
164
+ }
64
165
};
65
166
66
167
class RewriteSystemBuilder {
@@ -73,6 +174,9 @@ class RewriteSystemBuilder {
73
174
void addGenericSignature (CanGenericSignature sig);
74
175
void addAssociatedType (const AssociatedTypeDecl *type,
75
176
const ProtocolDecl *proto);
177
+ void addInheritedAssociatedType (const AssociatedTypeDecl *type,
178
+ const ProtocolDecl *inherited,
179
+ const ProtocolDecl *proto);
76
180
void addRequirement (const Requirement &req,
77
181
const ProtocolDecl *proto);
78
182
@@ -85,6 +189,8 @@ void RewriteSystemBuilder::addGenericSignature(CanGenericSignature sig) {
85
189
ProtocolGraph graph;
86
190
graph.visitRequirements (sig->getRequirements ());
87
191
graph.computeTransitiveClosure ();
192
+ graph.computeLinearOrder ();
193
+ graph.computeInheritedAssociatedTypes ();
88
194
89
195
for (auto *proto : graph.Protocols ) {
90
196
if (Context.LangOpts .DebugRequirementMachine ) {
@@ -96,6 +202,12 @@ void RewriteSystemBuilder::addGenericSignature(CanGenericSignature sig) {
96
202
for (auto *type : info.AssociatedTypes )
97
203
addAssociatedType (type, proto);
98
204
205
+ for (auto *inherited : info.Inherited ) {
206
+ for (auto *inheritedType : graph.Info [inherited].AssociatedTypes ) {
207
+ addInheritedAssociatedType (inheritedType, inherited, proto);
208
+ }
209
+ }
210
+
99
211
for (auto req : info.Requirements )
100
212
addRequirement (req.getCanonical (), proto);
101
213
@@ -115,7 +227,21 @@ void RewriteSystemBuilder::addAssociatedType(const AssociatedTypeDecl *type,
115
227
lhs.add (Atom::forName (type->getName ()));
116
228
117
229
Term rhs;
118
- rhs.add (Atom::forAssociatedType (type));
230
+ rhs.add (Atom::forAssociatedType (proto, type->getName ()));
231
+
232
+ Rules.emplace_back (lhs, rhs);
233
+ }
234
+
235
+ void RewriteSystemBuilder::addInheritedAssociatedType (
236
+ const AssociatedTypeDecl *type,
237
+ const ProtocolDecl *inherited,
238
+ const ProtocolDecl *proto) {
239
+ Term lhs;
240
+ lhs.add (Atom::forProtocol (proto));
241
+ lhs.add (Atom::forAssociatedType (inherited, type->getName ()));
242
+
243
+ Term rhs;
244
+ rhs.add (Atom::forAssociatedType (proto, type->getName ()));
119
245
120
246
Rules.emplace_back (lhs, rhs);
121
247
}
0 commit comments