@@ -72,6 +72,13 @@ struct ProtocolGraph {
72
72
}
73
73
}
74
74
75
+ const ProtocolInfo &getProtocolInfo (
76
+ const ProtocolDecl *proto) const {
77
+ auto found = Info.find (proto);
78
+ assert (found != Info.end ());
79
+ return found->second ;
80
+ }
81
+
75
82
void addProtocol (const ProtocolDecl *proto) {
76
83
if (Info.count (proto) > 0 )
77
84
return ;
@@ -86,7 +93,7 @@ struct ProtocolGraph {
86
93
unsigned i = 0 ;
87
94
while (i < Protocols.size ()) {
88
95
auto *proto = Protocols[i++];
89
- visitRequirements (Info[ proto] .Requirements );
96
+ visitRequirements (getProtocolInfo ( proto) .Requirements );
90
97
}
91
98
}
92
99
@@ -99,8 +106,8 @@ struct ProtocolGraph {
99
106
Protocols.begin (), Protocols.end (),
100
107
[&](const ProtocolDecl *lhs,
101
108
const ProtocolDecl *rhs) -> int {
102
- const auto &lhsInfo = Info[ lhs] ;
103
- const auto &rhsInfo = Info[ rhs] ;
109
+ const auto &lhsInfo = getProtocolInfo ( lhs) ;
110
+ const auto &rhsInfo = getProtocolInfo ( rhs) ;
104
111
105
112
// protocol Base {} // depth 1
106
113
// protocol Derived : Base {} // depth 2
@@ -126,7 +133,7 @@ struct ProtocolGraph {
126
133
if (inherited == proto)
127
134
continue ;
128
135
129
- for (auto *inheritedType : Info[ inherited] .AssociatedTypes ) {
136
+ for (auto *inheritedType : getProtocolInfo ( inherited) .AssociatedTypes ) {
130
137
if (!visited.insert (inheritedType).second )
131
138
continue ;
132
139
@@ -164,12 +171,12 @@ struct ProtocolGraph {
164
171
}
165
172
};
166
173
167
- class RewriteSystemBuilder {
174
+ struct RewriteSystemBuilder {
168
175
ASTContext &Context;
169
176
177
+ ProtocolGraph Protocols;
170
178
std::vector<std::pair<Term, Term>> Rules;
171
179
172
- public:
173
180
RewriteSystemBuilder (ASTContext &ctx) : Context(ctx) {}
174
181
void addGenericSignature (CanGenericSignature sig);
175
182
void addAssociatedType (const AssociatedTypeDecl *type,
@@ -179,31 +186,25 @@ class RewriteSystemBuilder {
179
186
const ProtocolDecl *proto);
180
187
void addRequirement (const Requirement &req,
181
188
const ProtocolDecl *proto);
182
-
183
- void addRulesToRewriteSystem (RewriteSystem &system);
184
189
};
185
190
186
191
} // end namespace
187
192
188
193
void RewriteSystemBuilder::addGenericSignature (CanGenericSignature sig) {
189
- ProtocolGraph graph;
190
- graph.visitRequirements (sig->getRequirements ());
191
- graph.computeTransitiveClosure ();
192
- graph.computeLinearOrder ();
193
- graph.computeInheritedAssociatedTypes ();
194
-
195
- for (auto *proto : graph.Protocols ) {
196
- if (Context.LangOpts .DebugRequirementMachine ) {
197
- llvm::dbgs () << " protocol " << proto->getName () << " {\n " ;
198
- }
194
+ Protocols.visitRequirements (sig->getRequirements ());
195
+ Protocols.computeTransitiveClosure ();
196
+ Protocols.computeLinearOrder ();
197
+ Protocols.computeInheritedAssociatedTypes ();
199
198
200
- const auto &info = graph.Info [proto];
199
+ for (auto *proto : Protocols.Protocols ) {
200
+ const auto &info = Protocols.getProtocolInfo (proto);
201
201
202
202
for (auto *type : info.AssociatedTypes )
203
203
addAssociatedType (type, proto);
204
204
205
205
for (auto *inherited : info.Inherited ) {
206
- for (auto *inheritedType : graph.Info [inherited].AssociatedTypes ) {
206
+ auto inheritedTypes = Protocols.getProtocolInfo (inherited).AssociatedTypes ;
207
+ for (auto *inheritedType : inheritedTypes) {
207
208
addInheritedAssociatedType (inheritedType, inherited, proto);
208
209
}
209
210
}
@@ -286,12 +287,6 @@ void RewriteSystemBuilder::addRequirement(const Requirement &req,
286
287
}
287
288
}
288
289
289
- void RewriteSystemBuilder::addRulesToRewriteSystem (RewriteSystem &system) {
290
- for (auto rule : Rules) {
291
- system.addRule (rule.first , rule.second );
292
- }
293
- }
294
-
295
290
Term swift::rewriting::getTermForType (CanType paramType,
296
291
const ProtocolDecl *proto) {
297
292
assert (paramType->isTypeParameter ());
@@ -314,8 +309,22 @@ Term swift::rewriting::getTermForType(CanType paramType,
314
309
}
315
310
316
311
struct RequirementMachine ::Implementation {
312
+ ProtocolGraph Protocols;
313
+ ProtocolOrder Order;
317
314
RewriteSystem System;
318
315
bool Complete = false ;
316
+
317
+ Implementation ()
318
+ : Order([&](const ProtocolDecl *lhs,
319
+ const ProtocolDecl *rhs) -> int {
320
+ auto infoLHS = Protocols.Info .find (lhs);
321
+ assert (infoLHS != Protocols.Info .end ());
322
+ auto infoRHS = Protocols.Info .find (rhs);
323
+ assert (infoRHS != Protocols.Info .end ());
324
+
325
+ return infoRHS->second .Index - infoLHS->second .Index ;
326
+ }),
327
+ System (Order) {}
319
328
};
320
329
321
330
RequirementMachine::RequirementMachine (ASTContext &ctx) : Context(ctx) {
@@ -336,10 +345,13 @@ void RequirementMachine::addGenericSignature(CanGenericSignature sig) {
336
345
RewriteSystemBuilder builder (Context);
337
346
builder.addGenericSignature (sig);
338
347
339
- builder.addRulesToRewriteSystem (Impl->System );
348
+ Impl->Protocols = builder.Protocols ;
349
+
350
+ for (const auto &rule : builder.Rules )
351
+ Impl->System .addRule (rule.first , rule.second );
340
352
341
353
// FIXME: Add command line flag
342
- Impl->System .computeConfluentCompletion (1000 );
354
+ Impl->System .computeConfluentCompletion (10000 );
343
355
344
356
markComplete ();
345
357
0 commit comments