Skip to content

Commit 44eb0cf

Browse files
committed
RequirementMachine: Experiment with merging associated types
1 parent 91ab25d commit 44eb0cf

File tree

4 files changed

+187
-1
lines changed

4 files changed

+187
-1
lines changed

include/swift/AST/ProtocolGraph.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ namespace rewriting {
2727

2828
struct ProtocolInfo {
2929
ArrayRef<ProtocolDecl *> Inherited;
30+
llvm::TinyPtrVector<const ProtocolDecl *> AllInherited;
3031
llvm::TinyPtrVector<AssociatedTypeDecl *> AssociatedTypes;
3132
ArrayRef<Requirement> Requirements;
3233

@@ -77,9 +78,14 @@ struct ProtocolGraph {
7778

7879
void computeInheritedAssociatedTypes();
7980

81+
void computeInheritedProtocols();
82+
8083
int compareProtocols(const ProtocolDecl *lhs,
8184
const ProtocolDecl *rhs) const;
8285

86+
bool inheritsFrom(const ProtocolDecl *thisProto,
87+
const ProtocolDecl *otherProto) const;
88+
8389
private:
8490
unsigned computeProtocolDepth(const ProtocolDecl *proto);
8591
};

include/swift/AST/RewriteSystem.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,22 @@ class Term final {
182182
decltype(Atoms)::iterator begin() { return Atoms.begin(); }
183183
decltype(Atoms)::iterator end() { return Atoms.end(); }
184184

185+
const Atom &back() const {
186+
return Atoms.back();
187+
}
188+
189+
Atom &back() {
190+
return Atoms.back();
191+
}
192+
185193
const Atom &operator[](size_t index) const {
186194
return Atoms[index];
187195
}
188196

197+
Atom &operator[](size_t index) {
198+
return Atoms[index];
199+
}
200+
189201
decltype(Atoms)::const_iterator findSubTerm(const Term &other) const;
190202

191203
decltype(Atoms)::iterator findSubTerm(const Term &other);
@@ -210,6 +222,9 @@ class Rule final {
210222
Rule(const Term &lhs, const Term &rhs)
211223
: LHS(lhs), RHS(rhs), deleted(false) {}
212224

225+
const Term &getLHS() const { return LHS; }
226+
const Term &getRHS() const { return RHS; }
227+
213228
bool apply(Term &term) const {
214229
assert(!deleted);
215230
return term.rewriteSubTerm(LHS, RHS);
@@ -247,15 +262,18 @@ class Rule final {
247262
class RewriteSystem final {
248263
std::vector<Rule> Rules;
249264
ProtocolGraph Protos;
265+
std::vector<std::pair<Term, Term>> MergedAssociatedTypes;
250266
std::deque<std::pair<unsigned, unsigned>> Worklist;
251267

252268
unsigned DebugSimplify : 1;
253269
unsigned DebugAdd : 1;
270+
unsigned DebugMerge : 1;
254271

255272
public:
256273
explicit RewriteSystem() {
257274
DebugSimplify = false;
258275
DebugAdd = false;
276+
DebugMerge = false;
259277
}
260278

261279
RewriteSystem(const RewriteSystem &) = delete;
@@ -283,6 +301,10 @@ class RewriteSystem final {
283301
unsigned maxDepth);
284302

285303
void dump(llvm::raw_ostream &out) const;
304+
305+
private:
306+
Atom mergeAssociatedTypes(Atom lhs, Atom rhs) const;
307+
void processMergedAssociatedTypes();
286308
};
287309

288310
} // end namespace rewriting

lib/AST/ProtocolGraph.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ void ProtocolGraph::computeLinearOrder() {
8888
}
8989

9090
void ProtocolGraph::computeInheritedAssociatedTypes() {
91-
for (const auto *proto : Protocols) {
91+
for (const auto *proto : llvm::reverse(Protocols)) {
9292
auto &info = Info[proto];
9393

9494
llvm::SmallDenseSet<const AssociatedTypeDecl *, 4> visited;
@@ -108,6 +108,28 @@ void ProtocolGraph::computeInheritedAssociatedTypes() {
108108
}
109109
}
110110

111+
void ProtocolGraph::computeInheritedProtocols() {
112+
for (const auto *proto : llvm::reverse(Protocols)) {
113+
auto &info = Info[proto];
114+
115+
llvm::SmallDenseSet<const ProtocolDecl *, 4> visited;
116+
visited.insert(proto);
117+
118+
for (const auto *inherited : info.Inherited) {
119+
if (!visited.insert(inherited).second)
120+
continue;
121+
info.AllInherited.push_back(inherited);
122+
123+
for (auto *inheritedType : getProtocolInfo(inherited).AllInherited) {
124+
if (!visited.insert(inheritedType).second)
125+
continue;
126+
127+
info.AllInherited.push_back(inheritedType);
128+
}
129+
}
130+
}
131+
}
132+
111133
unsigned ProtocolGraph::computeProtocolDepth(const ProtocolDecl *proto) {
112134
auto &info = Info[proto];
113135

@@ -139,3 +161,11 @@ int ProtocolGraph::compareProtocols(const ProtocolDecl *lhs,
139161

140162
return infoLHS.Index - infoRHS.Index;
141163
}
164+
165+
bool ProtocolGraph::inheritsFrom(const ProtocolDecl *thisProto,
166+
const ProtocolDecl *otherProto) const {
167+
const auto &info = getProtocolInfo(thisProto);
168+
return std::find(info.AllInherited.begin(),
169+
info.AllInherited.end(),
170+
otherProto) != info.AllInherited.end();
171+
}

lib/AST/RewriteSystem.cpp

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ void RewriteSystem::initialize(std::vector<std::pair<Term, Term>> &&rules,
227227
ProtocolGraph &&graph) {
228228
Protos = graph;
229229

230+
// FIXME: Probably this sort is not necessary
230231
std::sort(rules.begin(), rules.end(),
231232
[&](std::pair<Term, Term> lhs,
232233
std::pair<Term, Term> rhs) -> int {
@@ -259,6 +260,14 @@ bool RewriteSystem::addRule(Term lhs, Term rhs) {
259260
unsigned i = Rules.size();
260261
Rules.emplace_back(lhs, rhs);
261262

263+
if (lhs.size() == rhs.size() &&
264+
std::equal(lhs.begin(), lhs.end() - 1, rhs.begin()) &&
265+
lhs.back().getKind() == Atom::Kind::AssociatedType &&
266+
rhs.back().getKind() == Atom::Kind::AssociatedType &&
267+
lhs.back().getName() == rhs.back().getName()) {
268+
MergedAssociatedTypes.emplace_back(lhs, rhs);
269+
}
270+
262271
for (unsigned j : indices(Rules)) {
263272
if (i == j)
264273
continue;
@@ -310,6 +319,110 @@ bool RewriteSystem::simplify(Term &term) const {
310319
return changed;
311320
}
312321

322+
Atom RewriteSystem::mergeAssociatedTypes(Atom lhs, Atom rhs) const {
323+
assert(lhs.getKind() == Atom::Kind::AssociatedType);
324+
assert(rhs.getKind() == Atom::Kind::AssociatedType);
325+
assert(lhs.getName() == rhs.getName());
326+
assert(lhs.compare(rhs, Protos) > 0);
327+
328+
auto protos = lhs.getProtocols();
329+
auto otherProtos = rhs.getProtocols();
330+
331+
// Follows from lhs > rhs
332+
assert(protos.size() <= otherProtos.size());
333+
334+
llvm::TinyPtrVector<const ProtocolDecl *> newProtos;
335+
std::merge(protos.begin(), protos.end(),
336+
otherProtos.begin(), otherProtos.end(),
337+
std::back_inserter(newProtos),
338+
[&](const ProtocolDecl *lhs,
339+
const ProtocolDecl *rhs) -> int {
340+
return Protos.compareProtocols(lhs, rhs) < 0;
341+
});
342+
343+
llvm::TinyPtrVector<const ProtocolDecl *> minimalProtos;
344+
for (const auto *newProto : newProtos) {
345+
auto inheritsFrom = [&](const ProtocolDecl *thisProto) {
346+
return (thisProto == newProto ||
347+
Protos.inheritsFrom(thisProto, newProto));
348+
};
349+
350+
if (std::find_if(protos.begin(), protos.end(), inheritsFrom)
351+
== protos.end()) {
352+
minimalProtos.push_back(newProto);
353+
}
354+
}
355+
356+
assert(minimalProtos.size() >= protos.size());
357+
assert(minimalProtos.size() >= otherProtos.size());
358+
359+
return Atom::forAssociatedType(minimalProtos, lhs.getName());
360+
}
361+
362+
void RewriteSystem::processMergedAssociatedTypes() {
363+
if (MergedAssociatedTypes.empty())
364+
return;
365+
366+
unsigned i = 0;
367+
while (i < MergedAssociatedTypes.size()) {
368+
auto pair = MergedAssociatedTypes[i++];
369+
const auto &lhs = pair.first;
370+
const auto &rhs = pair.second;
371+
372+
// If we have ...[P1:T] => ...[P2:T], add a new pair of rules:
373+
// ...[P1:T] => ...[P1&P2:T]
374+
// ...[P2:T] => ...[P1&P2:T]
375+
if (DebugMerge) {
376+
llvm::dbgs() << "## Associated type merge candidate ";
377+
lhs.dump(llvm::dbgs());
378+
llvm::dbgs() << " => ";
379+
rhs.dump(llvm::dbgs());
380+
llvm::dbgs() << "\n";
381+
}
382+
383+
auto mergedAtom = mergeAssociatedTypes(lhs.back(), rhs.back());
384+
if (DebugMerge) {
385+
llvm::dbgs() << "### Merged atom ";
386+
mergedAtom.dump(llvm::dbgs());
387+
llvm::dbgs() << "\n";
388+
}
389+
390+
Term mergedTerm = lhs;
391+
mergedTerm.back() = mergedAtom;
392+
393+
addRule(lhs, mergedTerm);
394+
addRule(rhs, mergedTerm);
395+
396+
for (const auto &otherRule : Rules) {
397+
const auto &otherLHS = otherRule.getLHS();
398+
if (otherLHS.size() == 2 &&
399+
otherLHS[1].getKind() == Atom::Kind::Protocol) {
400+
if (otherLHS[0] == lhs.back() ||
401+
otherLHS[0] == rhs.back()) {
402+
if (DebugMerge) {
403+
llvm::dbgs() << "### Lifting conformance rule ";
404+
otherRule.dump(llvm::dbgs());
405+
llvm::dbgs() << "\n";
406+
}
407+
408+
auto otherRHS = otherRule.getRHS();
409+
assert(otherRHS.size() == 1);
410+
assert(otherRHS[0] == otherLHS[0]);
411+
412+
otherRHS.back() = mergedAtom;
413+
414+
auto newLHS = otherRHS;
415+
newLHS.add(Atom::forProtocol(otherLHS[1].getProtocol()));
416+
417+
addRule(newLHS, otherRHS);
418+
}
419+
}
420+
}
421+
}
422+
423+
MergedAssociatedTypes.clear();
424+
}
425+
313426
RewriteSystem::CompletionResult
314427
RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
315428
unsigned maxDepth) {
@@ -361,8 +474,23 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
361474
if (rule.canReduceLeftHandSide(newRule))
362475
rule.markDeleted();
363476
}
477+
478+
processMergedAssociatedTypes();
364479
}
365480

481+
// This isn't necessary for correctness, it's just an optimization.
482+
for (auto &rule : Rules) {
483+
auto rhs = rule.getRHS();
484+
simplify(rhs);
485+
rule = Rule(rule.getLHS(), rhs);
486+
}
487+
488+
// Just for aesthetics in dump().
489+
std::sort(Rules.begin(), Rules.end(),
490+
[&](Rule lhs, Rule rhs) -> int {
491+
return lhs.getLHS().compare(rhs.getLHS(), Protos) < 0;
492+
});
493+
366494
return CompletionResult::Success;
367495
}
368496

0 commit comments

Comments
 (0)