18
18
#include " swift/AST/Requirement.h"
19
19
#include " swift/AST/RewriteSystem.h"
20
20
#include " llvm/ADT/DenseSet.h"
21
+ #include " llvm/ADT/TinyPtrVector.h"
21
22
#include < vector>
22
23
23
24
using namespace swift ;
24
25
using namespace rewriting ;
25
26
26
- struct RequirementMachine ::Implementation {
27
- llvm::DenseSet<const ProtocolDecl *> VisitedProtocols;
28
- std::vector<const ProtocolDecl *> Worklist;
29
- RewriteSystem System;
30
- bool Complete = false ;
27
+ namespace {
28
+
29
+ struct ProtocolInfo {
30
+ ArrayRef<ProtocolDecl *> Inherited;
31
+ llvm::TinyPtrVector<AssociatedTypeDecl *> AssociatedTypes;
32
+ ArrayRef<Requirement> Requirements;
31
33
};
32
34
33
- RequirementMachine::RequirementMachine (ASTContext &ctx)
34
- : Context(ctx) {
35
- Impl = new Implementation ();
36
- }
35
+ struct ProtocolGraph {
36
+ llvm::DenseMap<const ProtocolDecl *, ProtocolInfo> Info;
37
+ std::vector<const ProtocolDecl *> Protocols;
37
38
38
- RequirementMachine::~RequirementMachine () {
39
- delete Impl;
40
- }
39
+ void visitRequirements (ArrayRef<Requirement> reqs) {
40
+ for (auto req : reqs) {
41
+ if (req.getKind () == RequirementKind::Conformance) {
42
+ addProtocol (req.getProtocolDecl ());
43
+ }
44
+ }
45
+ }
41
46
42
- void RequirementMachine::addGenericSignature (CanGenericSignature sig) {
43
- PrettyStackTraceGenericSignature debugStack (" building rewrite system for" , sig);
47
+ void addProtocol (const ProtocolDecl *proto) {
48
+ if (Info.count (proto) > 0 )
49
+ return ;
44
50
45
- if (Context.LangOpts .DebugRequirementMachine ) {
46
- llvm::dbgs () << " Adding generic signature " << sig << " {\n " ;
51
+ Info[proto] = {proto->getInheritedProtocols (),
52
+ proto->getAssociatedTypeMembers (),
53
+ proto->getRequirementSignature ()};
54
+ Protocols.push_back (proto);
47
55
}
48
56
49
- for (const auto &req : sig->getRequirements ())
50
- addRequirement (req, /* proto=*/ nullptr );
57
+ void computeTransitiveClosure () {
58
+ unsigned i = 0 ;
59
+ while (i < Protocols.size ()) {
60
+ auto *proto = Protocols[i++];
61
+ visitRequirements (Info[proto].Requirements );
62
+ }
63
+ }
64
+ };
51
65
52
- processWorklist ();
66
+ class RewriteSystemBuilder {
67
+ ASTContext &Context;
53
68
54
- // FIXME: Add command line flag
55
- Impl->System .computeConfluentCompletion (100 );
69
+ std::vector<std::pair<Term, Term>> Rules;
56
70
57
- markComplete ();
71
+ public:
72
+ RewriteSystemBuilder (ASTContext &ctx) : Context(ctx) {}
73
+ void addGenericSignature (CanGenericSignature sig);
74
+ void addAssociatedType (const AssociatedTypeDecl *type,
75
+ const ProtocolDecl *proto);
76
+ void addRequirement (const Requirement &req,
77
+ const ProtocolDecl *proto);
58
78
59
- if (Context.LangOpts .DebugRequirementMachine ) {
60
- llvm::dbgs () << " }\n " ;
79
+ void addRulesToRewriteSystem (RewriteSystem &system);
80
+ };
81
+
82
+ } // end namespace
83
+
84
+ void RewriteSystemBuilder::addGenericSignature (CanGenericSignature sig) {
85
+ ProtocolGraph graph;
86
+ graph.visitRequirements (sig->getRequirements ());
87
+ graph.computeTransitiveClosure ();
88
+
89
+ for (auto *proto : graph.Protocols ) {
90
+ if (Context.LangOpts .DebugRequirementMachine ) {
91
+ llvm::dbgs () << " protocol " << proto->getName () << " {\n " ;
92
+ }
93
+
94
+ const auto &info = graph.Info [proto];
95
+
96
+ for (auto *type : info.AssociatedTypes )
97
+ addAssociatedType (type, proto);
98
+
99
+ for (auto req : info.Requirements )
100
+ addRequirement (req.getCanonical (), proto);
101
+
102
+ if (Context.LangOpts .DebugRequirementMachine ) {
103
+ llvm::dbgs () << " }\n " ;
104
+ }
61
105
}
106
+
107
+ for (const auto &req : sig->getRequirements ())
108
+ addRequirement (req, /* proto=*/ nullptr );
62
109
}
63
110
64
- void RequirementMachine::addProtocolRequirementSignature (
65
- const ProtocolDecl *proto) {
66
- auto inserted = Impl-> VisitedProtocols . insert (proto) ;
67
- if (!inserted. second )
68
- return ;
111
+ void RewriteSystemBuilder::addAssociatedType ( const AssociatedTypeDecl *type,
112
+ const ProtocolDecl *proto) {
113
+ Term lhs ;
114
+ lhs. add ( Atom::forProtocol (proto));
115
+ lhs. add ( Atom::forName (type-> getName ())) ;
69
116
70
- Impl->Worklist .push_back (proto);
117
+ Term rhs;
118
+ rhs.add (Atom::forAssociatedType (type));
119
+
120
+ Rules.emplace_back (lhs, rhs);
71
121
}
72
122
73
- void RequirementMachine ::addRequirement (const Requirement &req,
74
- const ProtocolDecl *proto) {
123
+ void RewriteSystemBuilder ::addRequirement (const Requirement &req,
124
+ const ProtocolDecl *proto) {
75
125
if (Context.LangOpts .DebugRequirementMachine ) {
76
126
llvm::dbgs () << " + " ;
77
127
req.dump (llvm::dbgs ());
@@ -88,9 +138,7 @@ void RequirementMachine::addRequirement(const Requirement &req,
88
138
auto constraintTerm = subjectTerm;
89
139
constraintTerm.add (Atom::forProtocol (proto));
90
140
91
- Impl->System .addRule (subjectTerm, constraintTerm);
92
-
93
- addProtocolRequirementSignature (proto);
141
+ Rules.emplace_back (subjectTerm, constraintTerm);
94
142
break ;
95
143
}
96
144
case RequirementKind::Superclass:
@@ -106,45 +154,71 @@ void RequirementMachine::addRequirement(const Requirement &req,
106
154
107
155
auto otherTerm = getTermForType (otherType, proto);
108
156
109
- Impl-> System . addRule (subjectTerm, otherTerm);
157
+ Rules. emplace_back (subjectTerm, otherTerm);
110
158
break ;
111
159
}
112
160
}
113
161
}
114
162
115
- void RequirementMachine::addAssociatedType ( const AssociatedTypeDecl *type,
116
- const ProtocolDecl *proto ) {
117
- Term lhs ;
118
- lhs. add ( Atom::forProtocol (proto));
119
- lhs. add ( Atom::forName (type-> getName ()));
163
+ void RewriteSystemBuilder::addRulesToRewriteSystem (RewriteSystem &system) {
164
+ for ( auto rule : Rules ) {
165
+ system. addRule (rule. first , rule. second ) ;
166
+ }
167
+ }
120
168
121
- Term rhs;
122
- rhs.add (Atom::forAssociatedType (type));
169
+ Term swift::rewriting::getTermForType (CanType paramType,
170
+ const ProtocolDecl *proto) {
171
+ assert (paramType->isTypeParameter ());
123
172
124
- Impl->System .addRule (lhs, rhs);
173
+ SmallVector<Atom, 3 > atoms;
174
+ while (auto memberType = dyn_cast<DependentMemberType>(paramType)) {
175
+ atoms.push_back (Atom::forName (memberType->getName ()));
176
+ paramType = memberType.getBase ();
177
+ }
178
+
179
+ if (proto) {
180
+ assert (proto->getSelfInterfaceType ()->isEqual (paramType));
181
+ atoms.push_back (Atom::forProtocol (proto));
182
+ } else {
183
+ atoms.push_back (Atom::forGenericParam (cast<GenericTypeParamType>(paramType)));
184
+ }
185
+
186
+ std::reverse (atoms.begin (), atoms.end ());
187
+ return Term (atoms);
125
188
}
126
189
127
- void RequirementMachine::processWorklist () {
128
- while (!Impl-> Worklist . empty ()) {
129
- const auto *proto = Impl-> Worklist . back () ;
130
- Impl-> Worklist . pop_back () ;
190
+ struct RequirementMachine ::Implementation {
191
+ RewriteSystem System;
192
+ bool Complete = false ;
193
+ } ;
131
194
132
- if (Context.LangOpts .DebugRequirementMachine ) {
133
- llvm::dbgs () << " protocol "
134
- << proto->getName () << " {\n " ;
135
- }
195
+ RequirementMachine::RequirementMachine (ASTContext &ctx) : Context(ctx) {
196
+ Impl = new Implementation ();
197
+ }
136
198
137
- for ( const auto *type : proto-> getAssociatedTypeMembers () ) {
138
- addAssociatedType (type, proto) ;
139
- }
199
+ RequirementMachine::~RequirementMachine ( ) {
200
+ delete Impl ;
201
+ }
140
202
141
- for (const auto &req : proto->getRequirementSignature ()) {
142
- addRequirement (req.getCanonical (), proto);
143
- }
203
+ void RequirementMachine::addGenericSignature (CanGenericSignature sig) {
204
+ PrettyStackTraceGenericSignature debugStack (" building rewrite system for" , sig);
144
205
145
- if (Context.LangOpts .DebugRequirementMachine ) {
146
- llvm::dbgs () << " }\n " ;
147
- }
206
+ if (Context.LangOpts .DebugRequirementMachine ) {
207
+ llvm::dbgs () << " Adding generic signature " << sig << " {\n " ;
208
+ }
209
+
210
+ RewriteSystemBuilder builder (Context);
211
+ builder.addGenericSignature (sig);
212
+
213
+ builder.addRulesToRewriteSystem (Impl->System );
214
+
215
+ // FIXME: Add command line flag
216
+ Impl->System .computeConfluentCompletion (1000 );
217
+
218
+ markComplete ();
219
+
220
+ if (Context.LangOpts .DebugRequirementMachine ) {
221
+ llvm::dbgs () << " }\n " ;
148
222
}
149
223
}
150
224
@@ -158,25 +232,4 @@ void RequirementMachine::markComplete() {
158
232
}
159
233
assert (!Impl->Complete );
160
234
Impl->Complete = true ;
161
- }
162
-
163
- Term RequirementMachine::getTermForType (CanType paramType,
164
- const ProtocolDecl *proto) const {
165
- assert (paramType->isTypeParameter ());
166
-
167
- SmallVector<Atom, 3 > atoms;
168
- while (auto memberType = dyn_cast<DependentMemberType>(paramType)) {
169
- atoms.push_back (Atom::forName (memberType->getName ()));
170
- paramType = memberType.getBase ();
171
- }
172
-
173
- if (proto) {
174
- assert (proto->getSelfInterfaceType ()->isEqual (paramType));
175
- atoms.push_back (Atom::forProtocol (proto));
176
- } else {
177
- atoms.push_back (Atom::forGenericParam (cast<GenericTypeParamType>(paramType)));
178
- }
179
-
180
- std::reverse (atoms.begin (), atoms.end ());
181
- return Term (atoms);
182
235
}
0 commit comments