Skip to content

Commit 69b7a64

Browse files
committed
RequirementMachine: Split off RewriteSystemBuilder
1 parent 6008e6a commit 69b7a64

File tree

3 files changed

+145
-94
lines changed

3 files changed

+145
-94
lines changed

include/swift/AST/RequirementMachine.h

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,12 @@ class ProtocolDecl;
2424
class Requirement;
2525

2626
namespace rewriting {
27-
class Term;
28-
}
27+
28+
class Term;
29+
30+
Term getTermForType(CanType paramType, const ProtocolDecl *proto);
31+
32+
} // end namespace rewriting
2933

3034
class RequirementMachine final {
3135
friend class ASTContext;
@@ -43,18 +47,10 @@ class RequirementMachine final {
4347
RequirementMachine &operator=(RequirementMachine &&) = delete;
4448

4549
void addGenericSignature(CanGenericSignature sig);
46-
void addProtocolRequirementSignature(const ProtocolDecl *proto);
47-
void addRequirement(const Requirement &req, const ProtocolDecl *proto);
48-
void addAssociatedType(const AssociatedTypeDecl *type,
49-
const ProtocolDecl *proto);
50-
void processWorklist();
5150

5251
bool isComplete() const;
5352
void markComplete();
5453

55-
rewriting::Term getTermForType(CanType paramType,
56-
const ProtocolDecl *proto) const;
57-
5854
public:
5955
~RequirementMachine();
6056
};

lib/AST/RequirementMachine.cpp

Lines changed: 135 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -18,60 +18,110 @@
1818
#include "swift/AST/Requirement.h"
1919
#include "swift/AST/RewriteSystem.h"
2020
#include "llvm/ADT/DenseSet.h"
21+
#include "llvm/ADT/TinyPtrVector.h"
2122
#include <vector>
2223

2324
using namespace swift;
2425
using namespace rewriting;
2526

26-
struct RequirementMachine::Implementation {
27-
llvm::DenseSet<const ProtocolDecl *> VisitedProtocols;
28-
std::vector<const ProtocolDecl *> Worklist;
29-
RewriteSystem System;
30-
bool Complete = false;
27+
namespace {
28+
29+
struct ProtocolInfo {
30+
ArrayRef<ProtocolDecl *> Inherited;
31+
llvm::TinyPtrVector<AssociatedTypeDecl *> AssociatedTypes;
32+
ArrayRef<Requirement> Requirements;
3133
};
3234

33-
RequirementMachine::RequirementMachine(ASTContext &ctx)
34-
: Context(ctx) {
35-
Impl = new Implementation();
36-
}
35+
struct ProtocolGraph {
36+
llvm::DenseMap<const ProtocolDecl *, ProtocolInfo> Info;
37+
std::vector<const ProtocolDecl *> Protocols;
3738

38-
RequirementMachine::~RequirementMachine() {
39-
delete Impl;
40-
}
39+
void visitRequirements(ArrayRef<Requirement> reqs) {
40+
for (auto req : reqs) {
41+
if (req.getKind() == RequirementKind::Conformance) {
42+
addProtocol(req.getProtocolDecl());
43+
}
44+
}
45+
}
4146

42-
void RequirementMachine::addGenericSignature(CanGenericSignature sig) {
43-
PrettyStackTraceGenericSignature debugStack("building rewrite system for", sig);
47+
void addProtocol(const ProtocolDecl *proto) {
48+
if (Info.count(proto) > 0)
49+
return;
4450

45-
if (Context.LangOpts.DebugRequirementMachine) {
46-
llvm::dbgs() << "Adding generic signature " << sig << " {\n";
51+
Info[proto] = {proto->getInheritedProtocols(),
52+
proto->getAssociatedTypeMembers(),
53+
proto->getRequirementSignature()};
54+
Protocols.push_back(proto);
4755
}
4856

49-
for (const auto &req : sig->getRequirements())
50-
addRequirement(req, /*proto=*/nullptr);
57+
void computeTransitiveClosure() {
58+
unsigned i = 0;
59+
while (i < Protocols.size()) {
60+
auto *proto = Protocols[i++];
61+
visitRequirements(Info[proto].Requirements);
62+
}
63+
}
64+
};
5165

52-
processWorklist();
66+
class RewriteSystemBuilder {
67+
ASTContext &Context;
5368

54-
// FIXME: Add command line flag
55-
Impl->System.computeConfluentCompletion(100);
69+
std::vector<std::pair<Term, Term>> Rules;
5670

57-
markComplete();
71+
public:
72+
RewriteSystemBuilder(ASTContext &ctx) : Context(ctx) {}
73+
void addGenericSignature(CanGenericSignature sig);
74+
void addAssociatedType(const AssociatedTypeDecl *type,
75+
const ProtocolDecl *proto);
76+
void addRequirement(const Requirement &req,
77+
const ProtocolDecl *proto);
5878

59-
if (Context.LangOpts.DebugRequirementMachine) {
60-
llvm::dbgs() << "}\n";
79+
void addRulesToRewriteSystem(RewriteSystem &system);
80+
};
81+
82+
} // end namespace
83+
84+
void RewriteSystemBuilder::addGenericSignature(CanGenericSignature sig) {
85+
ProtocolGraph graph;
86+
graph.visitRequirements(sig->getRequirements());
87+
graph.computeTransitiveClosure();
88+
89+
for (auto *proto : graph.Protocols) {
90+
if (Context.LangOpts.DebugRequirementMachine) {
91+
llvm::dbgs() << "protocol " << proto->getName() << " {\n";
92+
}
93+
94+
const auto &info = graph.Info[proto];
95+
96+
for (auto *type : info.AssociatedTypes)
97+
addAssociatedType(type, proto);
98+
99+
for (auto req : info.Requirements)
100+
addRequirement(req.getCanonical(), proto);
101+
102+
if (Context.LangOpts.DebugRequirementMachine) {
103+
llvm::dbgs() << "}\n";
104+
}
61105
}
106+
107+
for (const auto &req : sig->getRequirements())
108+
addRequirement(req, /*proto=*/nullptr);
62109
}
63110

64-
void RequirementMachine::addProtocolRequirementSignature(
65-
const ProtocolDecl *proto) {
66-
auto inserted = Impl->VisitedProtocols.insert(proto);
67-
if (!inserted.second)
68-
return;
111+
void RewriteSystemBuilder::addAssociatedType(const AssociatedTypeDecl *type,
112+
const ProtocolDecl *proto) {
113+
Term lhs;
114+
lhs.add(Atom::forProtocol(proto));
115+
lhs.add(Atom::forName(type->getName()));
69116

70-
Impl->Worklist.push_back(proto);
117+
Term rhs;
118+
rhs.add(Atom::forAssociatedType(type));
119+
120+
Rules.emplace_back(lhs, rhs);
71121
}
72122

73-
void RequirementMachine::addRequirement(const Requirement &req,
74-
const ProtocolDecl *proto) {
123+
void RewriteSystemBuilder::addRequirement(const Requirement &req,
124+
const ProtocolDecl *proto) {
75125
if (Context.LangOpts.DebugRequirementMachine) {
76126
llvm::dbgs() << "+ ";
77127
req.dump(llvm::dbgs());
@@ -88,9 +138,7 @@ void RequirementMachine::addRequirement(const Requirement &req,
88138
auto constraintTerm = subjectTerm;
89139
constraintTerm.add(Atom::forProtocol(proto));
90140

91-
Impl->System.addRule(subjectTerm, constraintTerm);
92-
93-
addProtocolRequirementSignature(proto);
141+
Rules.emplace_back(subjectTerm, constraintTerm);
94142
break;
95143
}
96144
case RequirementKind::Superclass:
@@ -106,45 +154,71 @@ void RequirementMachine::addRequirement(const Requirement &req,
106154

107155
auto otherTerm = getTermForType(otherType, proto);
108156

109-
Impl->System.addRule(subjectTerm, otherTerm);
157+
Rules.emplace_back(subjectTerm, otherTerm);
110158
break;
111159
}
112160
}
113161
}
114162

115-
void RequirementMachine::addAssociatedType(const AssociatedTypeDecl *type,
116-
const ProtocolDecl *proto) {
117-
Term lhs;
118-
lhs.add(Atom::forProtocol(proto));
119-
lhs.add(Atom::forName(type->getName()));
163+
void RewriteSystemBuilder::addRulesToRewriteSystem(RewriteSystem &system) {
164+
for (auto rule : Rules) {
165+
system.addRule(rule.first, rule.second);
166+
}
167+
}
120168

121-
Term rhs;
122-
rhs.add(Atom::forAssociatedType(type));
169+
Term swift::rewriting::getTermForType(CanType paramType,
170+
const ProtocolDecl *proto) {
171+
assert(paramType->isTypeParameter());
123172

124-
Impl->System.addRule(lhs, rhs);
173+
SmallVector<Atom, 3> atoms;
174+
while (auto memberType = dyn_cast<DependentMemberType>(paramType)) {
175+
atoms.push_back(Atom::forName(memberType->getName()));
176+
paramType = memberType.getBase();
177+
}
178+
179+
if (proto) {
180+
assert(proto->getSelfInterfaceType()->isEqual(paramType));
181+
atoms.push_back(Atom::forProtocol(proto));
182+
} else {
183+
atoms.push_back(Atom::forGenericParam(cast<GenericTypeParamType>(paramType)));
184+
}
185+
186+
std::reverse(atoms.begin(), atoms.end());
187+
return Term(atoms);
125188
}
126189

127-
void RequirementMachine::processWorklist() {
128-
while (!Impl->Worklist.empty()) {
129-
const auto *proto = Impl->Worklist.back();
130-
Impl->Worklist.pop_back();
190+
struct RequirementMachine::Implementation {
191+
RewriteSystem System;
192+
bool Complete = false;
193+
};
131194

132-
if (Context.LangOpts.DebugRequirementMachine) {
133-
llvm::dbgs() << "protocol "
134-
<< proto->getName() << " {\n";
135-
}
195+
RequirementMachine::RequirementMachine(ASTContext &ctx) : Context(ctx) {
196+
Impl = new Implementation();
197+
}
136198

137-
for (const auto *type : proto->getAssociatedTypeMembers()) {
138-
addAssociatedType(type, proto);
139-
}
199+
RequirementMachine::~RequirementMachine() {
200+
delete Impl;
201+
}
140202

141-
for (const auto &req : proto->getRequirementSignature()) {
142-
addRequirement(req.getCanonical(), proto);
143-
}
203+
void RequirementMachine::addGenericSignature(CanGenericSignature sig) {
204+
PrettyStackTraceGenericSignature debugStack("building rewrite system for", sig);
144205

145-
if (Context.LangOpts.DebugRequirementMachine) {
146-
llvm::dbgs() << "}\n";
147-
}
206+
if (Context.LangOpts.DebugRequirementMachine) {
207+
llvm::dbgs() << "Adding generic signature " << sig << " {\n";
208+
}
209+
210+
RewriteSystemBuilder builder(Context);
211+
builder.addGenericSignature(sig);
212+
213+
builder.addRulesToRewriteSystem(Impl->System);
214+
215+
// FIXME: Add command line flag
216+
Impl->System.computeConfluentCompletion(1000);
217+
218+
markComplete();
219+
220+
if (Context.LangOpts.DebugRequirementMachine) {
221+
llvm::dbgs() << "}\n";
148222
}
149223
}
150224

@@ -158,25 +232,4 @@ void RequirementMachine::markComplete() {
158232
}
159233
assert(!Impl->Complete);
160234
Impl->Complete = true;
161-
}
162-
163-
Term RequirementMachine::getTermForType(CanType paramType,
164-
const ProtocolDecl *proto) const {
165-
assert(paramType->isTypeParameter());
166-
167-
SmallVector<Atom, 3> atoms;
168-
while (auto memberType = dyn_cast<DependentMemberType>(paramType)) {
169-
atoms.push_back(Atom::forName(memberType->getName()));
170-
paramType = memberType.getBase();
171-
}
172-
173-
if (proto) {
174-
assert(proto->getSelfInterfaceType()->isEqual(paramType));
175-
atoms.push_back(Atom::forProtocol(proto));
176-
} else {
177-
atoms.push_back(Atom::forGenericParam(cast<GenericTypeParamType>(paramType)));
178-
}
179-
180-
std::reverse(atoms.begin(), atoms.end());
181-
return Term(atoms);
182235
}

lib/AST/RewriteSystem.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,10 @@ bool RewriteSystem::simplify(Term &term) const {
212212
if (rule.isDeleted())
213213
continue;
214214

215-
if (rule.apply(term))
215+
if (rule.apply(term)) {
216+
changed = true;
216217
tryAgain = true;
218+
}
217219
}
218220

219221
if (!tryAgain)
@@ -297,7 +299,7 @@ void RewriteSystem::computeConfluentCompletion(
297299
}
298300

299301
void RewriteSystem::dump(llvm::raw_ostream &out) const {
300-
out << "Rewrite system: {";
302+
out << "Rewrite system: {\n";
301303
for (const auto &rule : Rules) {
302304
out << "- ";
303305
rule.dump(out);

0 commit comments

Comments
 (0)