Skip to content

Commit f69b883

Browse files
committed
RequirementMachine: Topological order for protocol declarations
1 parent 793474f commit f69b883

File tree

3 files changed

+60
-40
lines changed

3 files changed

+60
-40
lines changed

include/swift/AST/RewriteSystem.h

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ namespace swift {
2727

2828
namespace rewriting {
2929

30+
using ProtocolOrder = std::function<int (const ProtocolDecl *,
31+
const ProtocolDecl *)>;
32+
3033
class Atom final {
3134
using Storage = llvm::PointerUnion<Identifier, GenericTypeParamType *>;
3235

@@ -103,7 +106,7 @@ class Atom final {
103106
return Value.get<GenericTypeParamType *>();
104107
}
105108

106-
int compare(Atom other) const;
109+
int compare(Atom other, ProtocolOrder compare) const;
107110

108111
void dump(llvm::raw_ostream &out) const;
109112

@@ -129,7 +132,7 @@ class Term final {
129132
Atoms.push_back(atom);
130133
}
131134

132-
int compare(const Term &other) const;
135+
int compare(const Term &other, ProtocolOrder order) const;
133136

134137
size_t size() const { return Atoms.size(); }
135138

@@ -165,9 +168,7 @@ class Rule final {
165168

166169
public:
167170
Rule(const Term &lhs, const Term &rhs)
168-
: LHS(lhs), RHS(rhs), deleted(false) {
169-
assert(LHS.compare(RHS) > 0);
170-
}
171+
: LHS(lhs), RHS(rhs), deleted(false) {}
171172

172173
bool apply(Term &term) const {
173174
assert(!deleted);
@@ -196,9 +197,17 @@ class Rule final {
196197

197198
class RewriteSystem final {
198199
std::vector<Rule> Rules;
200+
ProtocolOrder Order;
199201
bool Debug = false;
200202

201203
public:
204+
explicit RewriteSystem(ProtocolOrder order) : Order(order) {}
205+
206+
RewriteSystem(const RewriteSystem &) = delete;
207+
RewriteSystem(RewriteSystem &&) = delete;
208+
RewriteSystem &operator=(const RewriteSystem &) = delete;
209+
RewriteSystem &operator=(RewriteSystem &&) = delete;
210+
202211
bool addRule(Term lhs, Term rhs);
203212

204213
bool simplify(Term &term) const;

lib/AST/RequirementMachine.cpp

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ struct ProtocolGraph {
7272
}
7373
}
7474

75+
const ProtocolInfo &getProtocolInfo(
76+
const ProtocolDecl *proto) const {
77+
auto found = Info.find(proto);
78+
assert(found != Info.end());
79+
return found->second;
80+
}
81+
7582
void addProtocol(const ProtocolDecl *proto) {
7683
if (Info.count(proto) > 0)
7784
return;
@@ -86,7 +93,7 @@ struct ProtocolGraph {
8693
unsigned i = 0;
8794
while (i < Protocols.size()) {
8895
auto *proto = Protocols[i++];
89-
visitRequirements(Info[proto].Requirements);
96+
visitRequirements(getProtocolInfo(proto).Requirements);
9097
}
9198
}
9299

@@ -99,8 +106,8 @@ struct ProtocolGraph {
99106
Protocols.begin(), Protocols.end(),
100107
[&](const ProtocolDecl *lhs,
101108
const ProtocolDecl *rhs) -> int {
102-
const auto &lhsInfo = Info[lhs];
103-
const auto &rhsInfo = Info[rhs];
109+
const auto &lhsInfo = getProtocolInfo(lhs);
110+
const auto &rhsInfo = getProtocolInfo(rhs);
104111

105112
// protocol Base {} // depth 1
106113
// protocol Derived : Base {} // depth 2
@@ -126,7 +133,7 @@ struct ProtocolGraph {
126133
if (inherited == proto)
127134
continue;
128135

129-
for (auto *inheritedType : Info[inherited].AssociatedTypes) {
136+
for (auto *inheritedType : getProtocolInfo(inherited).AssociatedTypes) {
130137
if (!visited.insert(inheritedType).second)
131138
continue;
132139

@@ -164,12 +171,12 @@ struct ProtocolGraph {
164171
}
165172
};
166173

167-
class RewriteSystemBuilder {
174+
struct RewriteSystemBuilder {
168175
ASTContext &Context;
169176

177+
ProtocolGraph Protocols;
170178
std::vector<std::pair<Term, Term>> Rules;
171179

172-
public:
173180
RewriteSystemBuilder(ASTContext &ctx) : Context(ctx) {}
174181
void addGenericSignature(CanGenericSignature sig);
175182
void addAssociatedType(const AssociatedTypeDecl *type,
@@ -179,31 +186,25 @@ class RewriteSystemBuilder {
179186
const ProtocolDecl *proto);
180187
void addRequirement(const Requirement &req,
181188
const ProtocolDecl *proto);
182-
183-
void addRulesToRewriteSystem(RewriteSystem &system);
184189
};
185190

186191
} // end namespace
187192

188193
void RewriteSystemBuilder::addGenericSignature(CanGenericSignature sig) {
189-
ProtocolGraph graph;
190-
graph.visitRequirements(sig->getRequirements());
191-
graph.computeTransitiveClosure();
192-
graph.computeLinearOrder();
193-
graph.computeInheritedAssociatedTypes();
194-
195-
for (auto *proto : graph.Protocols) {
196-
if (Context.LangOpts.DebugRequirementMachine) {
197-
llvm::dbgs() << "protocol " << proto->getName() << " {\n";
198-
}
194+
Protocols.visitRequirements(sig->getRequirements());
195+
Protocols.computeTransitiveClosure();
196+
Protocols.computeLinearOrder();
197+
Protocols.computeInheritedAssociatedTypes();
199198

200-
const auto &info = graph.Info[proto];
199+
for (auto *proto : Protocols.Protocols) {
200+
const auto &info = Protocols.getProtocolInfo(proto);
201201

202202
for (auto *type : info.AssociatedTypes)
203203
addAssociatedType(type, proto);
204204

205205
for (auto *inherited : info.Inherited) {
206-
for (auto *inheritedType : graph.Info[inherited].AssociatedTypes) {
206+
auto inheritedTypes = Protocols.getProtocolInfo(inherited).AssociatedTypes;
207+
for (auto *inheritedType : inheritedTypes) {
207208
addInheritedAssociatedType(inheritedType, inherited, proto);
208209
}
209210
}
@@ -286,12 +287,6 @@ void RewriteSystemBuilder::addRequirement(const Requirement &req,
286287
}
287288
}
288289

289-
void RewriteSystemBuilder::addRulesToRewriteSystem(RewriteSystem &system) {
290-
for (auto rule : Rules) {
291-
system.addRule(rule.first, rule.second);
292-
}
293-
}
294-
295290
Term swift::rewriting::getTermForType(CanType paramType,
296291
const ProtocolDecl *proto) {
297292
assert(paramType->isTypeParameter());
@@ -314,8 +309,22 @@ Term swift::rewriting::getTermForType(CanType paramType,
314309
}
315310

316311
struct RequirementMachine::Implementation {
312+
ProtocolGraph Protocols;
313+
ProtocolOrder Order;
317314
RewriteSystem System;
318315
bool Complete = false;
316+
317+
Implementation()
318+
: Order([&](const ProtocolDecl *lhs,
319+
const ProtocolDecl *rhs) -> int {
320+
auto infoLHS = Protocols.Info.find(lhs);
321+
assert(infoLHS != Protocols.Info.end());
322+
auto infoRHS = Protocols.Info.find(rhs);
323+
assert(infoRHS != Protocols.Info.end());
324+
325+
return infoRHS->second.Index - infoLHS->second.Index;
326+
}),
327+
System(Order) {}
319328
};
320329

321330
RequirementMachine::RequirementMachine(ASTContext &ctx) : Context(ctx) {
@@ -336,10 +345,13 @@ void RequirementMachine::addGenericSignature(CanGenericSignature sig) {
336345
RewriteSystemBuilder builder(Context);
337346
builder.addGenericSignature(sig);
338347

339-
builder.addRulesToRewriteSystem(Impl->System);
348+
Impl->Protocols = builder.Protocols;
349+
350+
for (const auto &rule : builder.Rules)
351+
Impl->System.addRule(rule.first, rule.second);
340352

341353
// FIXME: Add command line flag
342-
Impl->System.computeConfluentCompletion(1000);
354+
Impl->System.computeConfluentCompletion(10000);
343355

344356
markComplete();
345357

lib/AST/RewriteSystem.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
using namespace swift;
1919
using namespace rewriting;
2020

21-
int Atom::compare(Atom other) const {
21+
int Atom::compare(Atom other, ProtocolOrder protocolOrder) const {
2222
auto kind = getKind();
2323
auto otherKind = other.getKind();
2424

@@ -30,10 +30,10 @@ int Atom::compare(Atom other) const {
3030
return getName().compare(other.getName());
3131

3232
case Kind::Protocol:
33-
return TypeDecl::compare(getProtocol(), other.getProtocol());
33+
return protocolOrder(getProtocol(), other.getProtocol());
3434

3535
case Kind::AssociatedType: {
36-
int result = TypeDecl::compare(getProtocol(), other.getProtocol());
36+
int result = protocolOrder(getProtocol(), other.getProtocol());
3737
if (result)
3838
return result;
3939

@@ -81,15 +81,15 @@ void Atom::dump(llvm::raw_ostream &out) const {
8181
llvm_unreachable("Bad atom kind");
8282
}
8383

84-
int Term::compare(const Term &other) const {
84+
int Term::compare(const Term &other, ProtocolOrder protocolOrder) const {
8585
if (size() != other.size())
8686
return size() < other.size() ? -1 : 1;
8787

8888
for (unsigned i = 0, e = size(); i < e; ++i) {
8989
auto lhs = (*this)[i];
9090
auto rhs = other[i];
9191

92-
int result = lhs.compare(rhs);
92+
int result = lhs.compare(rhs, protocolOrder);
9393
if (result != 0)
9494
return result;
9595
}
@@ -120,7 +120,6 @@ bool Term::rewriteSubTerm(const Term &lhs, const Term &rhs) {
120120

121121
auto oldSize = size();
122122

123-
assert(rhs.compare(lhs) < 0);
124123
assert(rhs.size() <= lhs.size());
125124

126125
auto newIter = std::copy(rhs.begin(), rhs.end(), found);
@@ -196,7 +195,7 @@ bool RewriteSystem::addRule(Term lhs, Term rhs) {
196195
simplify(lhs);
197196
simplify(rhs);
198197

199-
int result = lhs.compare(rhs);
198+
int result = lhs.compare(rhs, Order);
200199
if (result == 0)
201200
return false;
202201
if (result < 0)

0 commit comments

Comments
 (0)