Skip to content

Commit 194eb29

Browse files
committed
RequirementMachine: Move the protocol graph into the rewrite system
1 parent 85d17a6 commit 194eb29

File tree

5 files changed

+49
-38
lines changed

5 files changed

+49
-38
lines changed

include/swift/AST/ProtocolGraph.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ struct ProtocolGraph {
7777

7878
void computeInheritedAssociatedTypes();
7979

80+
int compareProtocols(const ProtocolDecl *lhs,
81+
const ProtocolDecl *rhs) const;
82+
8083
private:
8184
unsigned computeProtocolDepth(const ProtocolDecl *proto);
8285
};

include/swift/AST/RewriteSystem.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "swift/AST/Decl.h"
1717
#include "swift/AST/Identifier.h"
1818
#include "swift/AST/LayoutConstraint.h"
19+
#include "swift/AST/ProtocolGraph.h"
1920
#include "swift/AST/Types.h"
2021
#include "llvm/ADT/PointerUnion.h"
2122
#include "llvm/ADT/SmallVector.h"
@@ -30,9 +31,6 @@ namespace swift {
3031

3132
namespace rewriting {
3233

33-
using ProtocolOrder = std::function<int (const ProtocolDecl *,
34-
const ProtocolDecl *)>;
35-
3634
class Atom final {
3735
using Storage = llvm::PointerUnion<Identifier,
3836
GenericTypeParamType *,
@@ -142,7 +140,7 @@ class Atom final {
142140
return Value.get<LayoutConstraint>();
143141
}
144142

145-
int compare(Atom other, ProtocolOrder compare) const;
143+
int compare(Atom other, const ProtocolGraph &protos) const;
146144

147145
void dump(llvm::raw_ostream &out) const;
148146

@@ -174,7 +172,7 @@ class Term final {
174172
Atoms.push_back(atom);
175173
}
176174

177-
int compare(const Term &other, ProtocolOrder order) const;
175+
int compare(const Term &other, const ProtocolGraph &protos) const;
178176

179177
size_t size() const { return Atoms.size(); }
180178

@@ -239,22 +237,22 @@ class Rule final {
239237
}
240238

241239
int compare(const Rule &other,
242-
ProtocolOrder protocolOrder) const {
243-
return LHS.compare(other.LHS, protocolOrder);
240+
const ProtocolGraph &protos) const {
241+
return LHS.compare(other.LHS, protos);
244242
}
245243

246244
void dump(llvm::raw_ostream &out) const;
247245
};
248246

249247
class RewriteSystem final {
250248
std::vector<Rule> Rules;
251-
ProtocolOrder Order;
249+
ProtocolGraph Protos;
252250

253251
unsigned DebugSimplify : 1;
254252
unsigned DebugAdd : 1;
255253

256254
public:
257-
explicit RewriteSystem(ProtocolOrder order) : Order(order) {
255+
explicit RewriteSystem() {
258256
DebugSimplify = false;
259257
DebugAdd = false;
260258
}
@@ -264,6 +262,11 @@ class RewriteSystem final {
264262
RewriteSystem &operator=(const RewriteSystem &) = delete;
265263
RewriteSystem &operator=(RewriteSystem &&) = delete;
266264

265+
const ProtocolGraph &getProtocols() const { return Protos; }
266+
267+
void initialize(std::vector<std::pair<Term, Term>> &&rules,
268+
ProtocolGraph &&protos);
269+
267270
bool addRule(Term lhs, Term rhs);
268271

269272
bool simplify(Term &term) const;

lib/AST/ProtocolGraph.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,4 +130,12 @@ unsigned ProtocolGraph::computeProtocolDepth(const ProtocolDecl *proto) {
130130

131131
info.Depth = depth;
132132
return depth;
133-
}
133+
}
134+
135+
int ProtocolGraph::compareProtocols(const ProtocolDecl *lhs,
136+
const ProtocolDecl *rhs) const {
137+
const auto &infoLHS = getProtocolInfo(lhs);
138+
const auto &infoRHS = getProtocolInfo(rhs);
139+
140+
return infoLHS.Index - infoRHS.Index;
141+
}

lib/AST/RequirementMachine.cpp

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -174,20 +174,10 @@ Term swift::rewriting::getTermForType(CanType paramType,
174174
}
175175

176176
struct RequirementMachine::Implementation {
177-
ProtocolGraph Protocols;
178-
ProtocolOrder Order;
179177
RewriteSystem System;
180178
bool Complete = false;
181179

182-
Implementation()
183-
: Order([&](const ProtocolDecl *lhs,
184-
const ProtocolDecl *rhs) -> int {
185-
const auto &infoLHS = Protocols.getProtocolInfo(lhs);
186-
const auto &infoRHS = Protocols.getProtocolInfo(rhs);
187-
188-
return infoLHS.Index - infoRHS.Index;
189-
}),
190-
System(Order) {}
180+
Implementation() {}
191181
};
192182

193183
RequirementMachine::RequirementMachine(ASTContext &ctx) : Context(ctx) {
@@ -212,15 +202,8 @@ void RequirementMachine::addGenericSignature(CanGenericSignature sig) {
212202
RewriteSystemBuilder builder(Context);
213203
builder.addGenericSignature(sig);
214204

215-
Impl->Protocols = builder.Protocols;
216-
217-
std::sort(builder.Rules.begin(), builder.Rules.end(),
218-
[&](std::pair<Term, Term> lhs,
219-
std::pair<Term, Term> rhs) -> int {
220-
return lhs.first.compare(rhs.first, Impl->Order) < 0;
221-
});
222-
for (const auto &rule : builder.Rules)
223-
Impl->System.addRule(rule.first, rule.second);
205+
Impl->System.initialize(std::move(builder.Rules),
206+
std::move(builder.Protocols));
224207

225208
// FIXME: Add command line flag
226209
auto result = Impl->System.computeConfluentCompletion(1000, 10);

lib/AST/RewriteSystem.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
using namespace swift;
2020
using namespace rewriting;
2121

22-
int Atom::compare(Atom other, ProtocolOrder protocolOrder) const {
22+
int Atom::compare(Atom other, const ProtocolGraph &graph) const {
2323
auto kind = getKind();
2424
auto otherKind = other.getKind();
2525

@@ -31,17 +31,18 @@ int Atom::compare(Atom other, ProtocolOrder protocolOrder) const {
3131
return getName().compare(other.getName());
3232

3333
case Kind::Protocol:
34-
return protocolOrder(getProtocol(), other.getProtocol());
34+
return graph.compareProtocols(getProtocol(), other.getProtocol());
3535

3636
case Kind::AssociatedType: {
3737
auto protos = getProtocols();
3838
auto otherProtos = other.getProtocols();
3939

40+
// Atoms with more protocols are 'smaller' than those with fewer.
4041
if (protos.size() != otherProtos.size())
41-
return otherProtos.size() > protos.size() ? -1 : 1;
42+
return protos.size() > otherProtos.size() ? -1 : 1;
4243

4344
for (unsigned i : indices(protos)) {
44-
int result = protocolOrder(protos[i], otherProtos[i]);
45+
int result = graph.compareProtocols(protos[i], otherProtos[i]);
4546
if (result)
4647
return result;
4748
}
@@ -108,15 +109,15 @@ void Atom::dump(llvm::raw_ostream &out) const {
108109
llvm_unreachable("Bad atom kind");
109110
}
110111

111-
int Term::compare(const Term &other, ProtocolOrder protocolOrder) const {
112+
int Term::compare(const Term &other, const ProtocolGraph &graph) const {
112113
if (size() != other.size())
113114
return size() < other.size() ? -1 : 1;
114115

115116
for (unsigned i = 0, e = size(); i < e; ++i) {
116117
auto lhs = (*this)[i];
117118
auto rhs = other[i];
118119

119-
int result = lhs.compare(rhs, protocolOrder);
120+
int result = lhs.compare(rhs, graph);
120121
if (result != 0) {
121122
assert(lhs != rhs);
122123
return result;
@@ -222,17 +223,30 @@ void Rule::dump(llvm::raw_ostream &out) const {
222223
out << " [deleted]";
223224
}
224225

226+
void RewriteSystem::initialize(std::vector<std::pair<Term, Term>> &&rules,
227+
ProtocolGraph &&graph) {
228+
Protos = graph;
229+
230+
std::sort(rules.begin(), rules.end(),
231+
[&](std::pair<Term, Term> lhs,
232+
std::pair<Term, Term> rhs) -> int {
233+
return lhs.first.compare(rhs.first, graph) < 0;
234+
});
235+
for (const auto &rule : rules)
236+
addRule(rule.first, rule.second);
237+
}
238+
225239
bool RewriteSystem::addRule(Term lhs, Term rhs) {
226240
simplify(lhs);
227241
simplify(rhs);
228242

229-
int result = lhs.compare(rhs, Order);
243+
int result = lhs.compare(rhs, Protos);
230244
if (result == 0)
231245
return false;
232246
if (result < 0)
233247
std::swap(lhs, rhs);
234248

235-
assert(lhs.compare(rhs, Order) > 0);
249+
assert(lhs.compare(rhs, Protos) > 0);
236250

237251
if (DebugAdd) {
238252
llvm::dbgs() << "# Adding rule ";

0 commit comments

Comments
 (0)