Skip to content

Commit a75327e

Browse files
committed
RequirementMachine: Record homotopy generators when adding new rules during completion
1 parent 65e9dd1 commit a75327e

File tree

3 files changed

+187
-28
lines changed

3 files changed

+187
-28
lines changed

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 95 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,30 +116,76 @@ Symbol RewriteSystem::simplifySubstitutionsInSuperclassOrConcreteSymbol(
116116
}, Context);
117117
}
118118

119-
bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
119+
/// Adds a rewrite rule, returning true if the new rule was non-trivial.
120+
///
121+
/// If both sides simplify to the same term, the rule is trivial and discarded,
122+
/// and this method returns false.
123+
///
124+
/// If \p path is non-null, the new rule is derived from existing rules in the
125+
/// rewrite system; the path records a series of rewrite steps which transform
126+
/// \p lhs to \p rhs.
127+
bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs,
128+
const RewritePath *path) {
120129
assert(!lhs.empty());
121130
assert(!rhs.empty());
122131

123132
if (Debug.contains(DebugFlags::Add)) {
124-
llvm::dbgs() << "# Adding rule " << lhs << " == " << rhs << "\n";
133+
llvm::dbgs() << "# Adding rule " << lhs << " == " << rhs << "\n\n";
125134
}
126135

127136
// Now simplify both sides as much as possible with the rules we have so far.
128137
//
129138
// This avoids unnecessary work in the completion algorithm.
130-
simplify(lhs);
131-
simplify(rhs);
139+
RewritePath lhsPath;
140+
RewritePath rhsPath;
141+
142+
simplify(lhs, &lhsPath);
143+
simplify(rhs, &rhsPath);
144+
145+
RewritePath loop;
146+
if (path) {
147+
// Produce a path from the simplified lhs to the simplified rhs.
148+
149+
// (1) First, apply lhsPath in reverse to produce the original lhs.
150+
lhsPath.invert();
151+
loop.append(lhsPath);
152+
153+
// (2) Now, apply the path from the original lhs to the original rhs
154+
// given to us by the completion procedure.
155+
loop.append(*path);
156+
157+
// (3) Finally, apply rhsPath to produce the simplified rhs, which
158+
// is the same as the simplified lhs.
159+
loop.append(rhsPath);
160+
}
132161

133162
// If the left hand side and right hand side are already equivalent, we're
134163
// done.
135164
int result = lhs.compare(rhs, Protos);
136-
if (result == 0)
165+
if (result == 0) {
166+
// If this rule is a consequence of existing rules, add a homotopy
167+
// generator.
168+
if (path) {
169+
// We already have a loop, since the simplified lhs is identical to the
170+
// simplified rhs.
171+
HomotopyGenerators.emplace_back(lhs, loop);
172+
173+
if (Debug.contains(DebugFlags::Add)) {
174+
llvm::dbgs() << "## Recorded trivial loop at " << lhs << ": ";
175+
loop.dump(llvm::dbgs(), lhs, *this);
176+
llvm::dbgs() << "\n\n";
177+
}
178+
}
179+
137180
return false;
181+
}
138182

139183
// Orient the two terms so that the left hand side is greater than the
140184
// right hand side.
141-
if (result < 0)
185+
if (result < 0) {
142186
std::swap(lhs, rhs);
187+
loop.invert();
188+
}
143189

144190
assert(lhs.compare(rhs, Protos) > 0);
145191

@@ -153,6 +199,19 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
153199
auto uniquedRHS = Term::get(rhs, Context);
154200
Rules.emplace_back(uniquedLHS, uniquedRHS);
155201

202+
if (path) {
203+
// We have a rewrite path from the simplified lhs to the simplified rhs;
204+
// add a rewrite step applying the new rule in reverse to close the loop.
205+
loop.add(RewriteStep(/*offset=*/0, newRuleID, /*inverse=*/true));
206+
HomotopyGenerators.emplace_back(lhs, loop);
207+
208+
if (Debug.contains(DebugFlags::Add)) {
209+
llvm::dbgs() << "## Recorded non-trivial loop at " << lhs << ": ";
210+
loop.dump(llvm::dbgs(), lhs, *this);
211+
llvm::dbgs() << "\n\n";
212+
}
213+
}
214+
156215
auto oldRuleID = Trie.insert(lhs.begin(), lhs.end(), newRuleID);
157216
if (oldRuleID) {
158217
llvm::errs() << "Duplicate rewrite rule!\n";
@@ -279,8 +338,9 @@ void RewriteSystem::simplifyRewriteSystem() {
279338
continue;
280339

281340
// Now, try to reduce the right hand side.
341+
RewritePath rhsPath;
282342
MutableTerm rhs(rule.getRHS());
283-
if (!simplify(rhs))
343+
if (!simplify(rhs, &rhsPath))
284344
continue;
285345

286346
// We're adding a new rule, so the old rule won't apply anymore.
@@ -293,6 +353,27 @@ void RewriteSystem::simplifyRewriteSystem() {
293353
auto oldRuleID = Trie.insert(lhs.begin(), lhs.end(), newRuleID);
294354
assert(oldRuleID == ruleID);
295355
(void) oldRuleID;
356+
357+
// Produce a loop at the simplified rhs.
358+
RewritePath loop;
359+
360+
// (1) First, apply rhsPath in reverse to produce the original rhs.
361+
rhsPath.invert();
362+
loop.append(rhsPath);
363+
364+
// (2) Next, apply the original rule in reverse to produce the
365+
// original lhs.
366+
loop.add(RewriteStep(/*offset=*/0, ruleID, /*inverse=*/true));
367+
368+
// (3) Finally, apply the new rule to produce the simplified rhs.
369+
loop.add(RewriteStep(/*offset=*/0, newRuleID, /*inverse=*/false));
370+
371+
if (Debug.contains(DebugFlags::Completion)) {
372+
llvm::dbgs() << "$ Right hand side simplification recorded a loop: ";
373+
loop.dump(llvm::dbgs(), rhs, *this);
374+
}
375+
376+
HomotopyGenerators.emplace_back(rhs, loop);
296377
}
297378
}
298379

@@ -364,4 +445,11 @@ void RewriteSystem::dump(llvm::raw_ostream &out) const {
364445
out << "- " << rule << "\n";
365446
}
366447
out << "}\n";
448+
out << "Homotopy generators: {\n";
449+
for (const auto &loop : HomotopyGenerators) {
450+
out << "- " << loop.first << ": ";
451+
loop.second.dump(out, loop.first, *this);
452+
out << "\n";
453+
}
454+
out << "}\n";
367455
}

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,11 @@ class RewriteSystem final {
181181
/// Pairs of rules which have already been checked for overlap.
182182
llvm::DenseSet<std::pair<unsigned, unsigned>> CheckedOverlaps;
183183

184+
/// Homotopy generators (2-cells) for this rewrite system. These are the
185+
/// cyclic rewrite paths which rewrite a term back to itself. This
186+
/// data informs the generic signature minimization algorithm.
187+
std::vector<std::pair<MutableTerm, RewritePath>> HomotopyGenerators;
188+
184189
DebugOptions Debug;
185190

186191
public:
@@ -216,7 +221,8 @@ class RewriteSystem final {
216221
return Rules[ruleID];
217222
}
218223

219-
bool addRule(MutableTerm lhs, MutableTerm rhs);
224+
bool addRule(MutableTerm lhs, MutableTerm rhs,
225+
const RewritePath *path=nullptr);
220226

221227
bool simplify(MutableTerm &term, RewritePath *path=nullptr) const;
222228

@@ -249,10 +255,12 @@ class RewriteSystem final {
249255

250256
private:
251257
bool
252-
computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
253-
const Rule &lhs, const Rule &rhs,
254-
std::vector<std::pair<MutableTerm,
255-
MutableTerm>> &result) const;
258+
computeCriticalPair(
259+
ArrayRef<Symbol>::const_iterator from,
260+
const Rule &lhs, const Rule &rhs,
261+
std::vector<std::pair<MutableTerm, MutableTerm>> &pairs,
262+
std::vector<RewritePath> &paths,
263+
std::vector<std::pair<MutableTerm, RewritePath>> &loops) const;
256264

257265
void processMergedAssociatedTypes();
258266

lib/AST/RequirementMachine/RewriteSystemCompletion.cpp

Lines changed: 79 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,10 @@ bool
367367
RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
368368
const Rule &lhs, const Rule &rhs,
369369
std::vector<std::pair<MutableTerm,
370-
MutableTerm>> &result) const {
370+
MutableTerm>> &pairs,
371+
std::vector<RewritePath> &paths,
372+
std::vector<std::pair<MutableTerm,
373+
RewritePath>> &loops) const {
371374
auto end = lhs.getLHS().end();
372375
if (from + rhs.getLHS().size() < end) {
373376
// lhs == TUV -> X, rhs == U -> Y.
@@ -378,17 +381,33 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
378381
// In this case, T and V are both empty.
379382

380383
// Compute the term TYV.
381-
MutableTerm t(lhs.getLHS().begin(), from);
382-
t.append(rhs.getRHS());
383-
t.append(from + rhs.getLHS().size(), lhs.getLHS().end());
384-
385-
if (lhs.getRHS().size() == t.size() &&
386-
std::equal(lhs.getRHS().begin(), lhs.getRHS().end(),
387-
t.begin())) {
384+
MutableTerm tyv(lhs.getLHS().begin(), from);
385+
tyv.append(rhs.getRHS());
386+
tyv.append(from + rhs.getLHS().size(), lhs.getLHS().end());
387+
388+
MutableTerm x(lhs.getRHS());
389+
390+
// Compute a path from X to TYV.
391+
RewritePath path;
392+
393+
// (1) First, apply the left hand side rule in the reverse direction.
394+
path.add(RewriteStep(/*offset=*/0,
395+
getRuleID(lhs),
396+
/*inverse=*/true));
397+
// (2) Now, apply the right hand side in the forward direction.
398+
path.add(RewriteStep(from - lhs.getLHS().begin(),
399+
getRuleID(rhs),
400+
/*inverse=*/false));
401+
402+
// If X == TYV, we have a trivial overlap.
403+
if (x == tyv) {
404+
loops.emplace_back(x, path);
388405
return false;
389406
}
390407

391-
result.emplace_back(MutableTerm(lhs.getRHS()), t);
408+
// Add the pair (X, TYV).
409+
pairs.emplace_back(x, tyv);
410+
paths.push_back(path);
392411
} else {
393412
// lhs == TU -> X, rhs == UV -> Y.
394413

@@ -408,10 +427,27 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
408427
// Compute the term TY.
409428
t.append(rhs.getRHS());
410429

411-
if (xv == t)
430+
// Compute a path from XV to TY.
431+
RewritePath path;
432+
433+
// (1) First, apply the left hand side rule in the reverse direction.
434+
path.add(RewriteStep(/*offset=*/0,
435+
getRuleID(lhs),
436+
/*inverse=*/true));
437+
// (2) Now, apply the right hand side in the forward direction.
438+
path.add(RewriteStep(from - lhs.getLHS().begin(),
439+
getRuleID(rhs),
440+
/*inverse=*/false));
441+
442+
// If XV == TY, we have a trivial overlap.
443+
if (xv == t) {
444+
loops.emplace_back(xv, path);
412445
return false;
446+
}
413447

414-
result.emplace_back(xv, t);
448+
// Add the pair (XV, TY).
449+
pairs.emplace_back(xv, t);
450+
paths.push_back(path);
415451
}
416452

417453
return true;
@@ -435,9 +471,11 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
435471

436472
bool again = false;
437473

438-
do {
439-
std::vector<std::pair<MutableTerm, MutableTerm>> resolvedCriticalPairs;
474+
std::vector<std::pair<MutableTerm, MutableTerm>> resolvedCriticalPairs;
475+
std::vector<RewritePath> resolvedPaths;
476+
std::vector<std::pair<MutableTerm, RewritePath>> resolvedLoops;
440477

478+
do {
441479
// For every rule, looking for other rules that overlap with this rule.
442480
for (unsigned i = 0, e = Rules.size(); i < e; ++i) {
443481
const auto &lhs = getRule(i);
@@ -480,9 +518,13 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
480518
}
481519

482520
// Try to repair the confluence violation by adding a new rule.
483-
if (computeCriticalPair(from, lhs, rhs, resolvedCriticalPairs)) {
521+
if (computeCriticalPair(from, lhs, rhs,
522+
resolvedCriticalPairs,
523+
resolvedPaths,
524+
resolvedLoops)) {
484525
if (Debug.contains(DebugFlags::Completion)) {
485526
const auto &pair = resolvedCriticalPairs.back();
527+
const auto &path = resolvedPaths.back();
486528

487529
llvm::dbgs() << "$ Overlapping rules: (#" << i << ") ";
488530
llvm::dbgs() << lhs << "\n";
@@ -492,13 +534,26 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
492534
<< pair.first << "\n";
493535
llvm::dbgs() << "$$ Second term of critical pair is "
494536
<< pair.second << "\n\n";
537+
538+
llvm::dbgs() << "$$ Resolved via path: ";
539+
path.dump(llvm::dbgs(), pair.first, *this);
540+
llvm::dbgs() << "\n\n";
495541
}
496542
} else {
497543
if (Debug.contains(DebugFlags::Completion)) {
544+
const auto &loop = resolvedLoops.back();
545+
498546
llvm::dbgs() << "$ Trivially overlapping rules: (#" << i << ") ";
499547
llvm::dbgs() << lhs << "\n";
500548
llvm::dbgs() << " -vs- (#" << j << ") ";
501549
llvm::dbgs() << rhs << ":\n";
550+
551+
llvm::dbgs() << "$$ Loop: ";
552+
loop.second.dump(llvm::dbgs(), loop.first, *this);
553+
llvm::dbgs() << "\n\n";
554+
555+
// Record the trivial loop.
556+
HomotopyGenerators.push_back(loop);
502557
}
503558
}
504559
});
@@ -509,13 +564,18 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
509564

510565
simplifyRewriteSystem();
511566

567+
assert(resolvedCriticalPairs.size() == resolvedPaths.size());
568+
512569
again = false;
513-
for (const auto &pair : resolvedCriticalPairs) {
570+
for (unsigned index : indices(resolvedCriticalPairs)) {
571+
const auto &pair = resolvedCriticalPairs[index];
572+
const auto &path = resolvedPaths[index];
573+
514574
// Check if we've already done too much work.
515575
if (Rules.size() > maxIterations)
516576
return std::make_pair(CompletionResult::MaxIterations, steps);
517577

518-
if (!addRule(pair.first, pair.second))
578+
if (!addRule(pair.first, pair.second, &path))
519579
continue;
520580

521581
// Check if the new rule is too long.
@@ -527,6 +587,9 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
527587
again = true;
528588
}
529589

590+
resolvedCriticalPairs.clear();
591+
resolvedPaths.clear();
592+
530593
// If the added rules merged any associated types, process the merges now
531594
// before we continue with the completion procedure. This is important
532595
// to perform incrementally since merging is required to repair confluence

0 commit comments

Comments
 (0)