Skip to content

Commit 5cbfbae

Browse files
committed
RequirementMachine: Represent concrete type adjustment in a rewrite path
1 parent b317a5c commit 5cbfbae

File tree

3 files changed

+128
-33
lines changed

3 files changed

+128
-33
lines changed

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 76 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,11 @@ void RewritePath::invert() {
3636
step.invert();
3737
}
3838

39-
AppliedRewriteStep RewriteStep::apply(MutableTerm &term,
40-
const RewriteSystem &system) const {
39+
AppliedRewriteStep
40+
RewriteStep::applyRewriteRule(MutableTerm &term,
41+
const RewriteSystem &system) const {
42+
assert(Kind == ApplyRewriteRule);
43+
4144
const auto &rule = system.getRule(RuleID);
4245

4346
auto lhs = (Inverse ? rule.getRHS() : rule.getLHS());
@@ -63,21 +66,69 @@ AppliedRewriteStep RewriteStep::apply(MutableTerm &term,
6366
return {lhs, rhs, prefix, suffix};
6467
}
6568

69+
MutableTerm RewriteStep::applyAdjustment(MutableTerm &term,
70+
const RewriteSystem &system) const {
71+
assert(Kind == AdjustConcreteType);
72+
73+
auto &ctx = system.getRewriteContext();
74+
MutableTerm prefix(term.begin(), term.begin() + Offset);
75+
76+
// We're either adding or removing the prefix to each concrete substitution.
77+
term.back() = term.back().transformConcreteSubstitutions(
78+
[&](Term t) -> Term {
79+
if (Inverse) {
80+
if (!std::equal(t.begin(),
81+
t.begin() + Offset,
82+
prefix.begin())) {
83+
llvm::errs() << "Invalid rewrite path\n";
84+
llvm::errs() << "- Term: " << term << "\n";
85+
llvm::errs() << "- Offset: " << Offset << "\n";
86+
llvm::errs() << "- Expected subterm: " << prefix << "\n";
87+
abort();
88+
}
89+
90+
MutableTerm mutTerm(t.begin() + Offset, t.end());
91+
return Term::get(mutTerm, ctx);
92+
} else {
93+
MutableTerm mutTerm(prefix);
94+
mutTerm.append(t);
95+
return Term::get(mutTerm, ctx);
96+
}
97+
}, ctx);
98+
99+
return prefix;
100+
}
101+
66102
/// Dumps the rewrite step that was applied to \p term. Mutates \p term to
67103
/// reflect the application of the rule.
68104
void RewriteStep::dump(llvm::raw_ostream &out,
69105
MutableTerm &term,
70106
const RewriteSystem &system) const {
71-
auto result = apply(term, system);
107+
switch (Kind) {
108+
case ApplyRewriteRule: {
109+
auto result = applyRewriteRule(term, system);
72110

73-
if (!result.prefix.empty()) {
74-
out << result.prefix;
75-
out << ".";
111+
if (!result.prefix.empty()) {
112+
out << result.prefix;
113+
out << ".";
114+
}
115+
out << "(" << result.lhs << " => " << result.rhs << ")";
116+
if (!result.suffix.empty()) {
117+
out << ".";
118+
out << result.suffix;
119+
}
120+
121+
break;
122+
}
123+
case AdjustConcreteType: {
124+
auto result = applyAdjustment(term, system);
125+
126+
out << "";
127+
out << (Inverse ? " - " : " + ");
128+
out << result << ")";
129+
130+
break;
76131
}
77-
out << "(" << result.lhs << " => " << result.rhs << ")";
78-
if (!result.suffix.empty()) {
79-
out << ".";
80-
out << result.suffix;
81132
}
82133
}
83134

@@ -213,7 +264,7 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs,
213264
if (path) {
214265
// We have a rewrite path from the simplified lhs to the simplified rhs;
215266
// add a rewrite step applying the new rule in reverse to close the loop.
216-
loop.add(RewriteStep(/*offset=*/0, newRuleID, /*inverse=*/true));
267+
loop.add(RewriteStep::forRewriteRule(/*offset=*/0, newRuleID, /*inverse=*/true));
217268
HomotopyGenerators.emplace_back(lhs, loop);
218269

219270
if (Debug.contains(DebugFlags::Add)) {
@@ -275,7 +326,8 @@ bool RewriteSystem::simplify(MutableTerm &term, RewritePath *path) const {
275326

276327
if (path) {
277328
unsigned offset = (unsigned)(from - term.begin());
278-
path->add(RewriteStep(offset, *ruleID, /*inverse=*/false));
329+
path->add(RewriteStep::forRewriteRule(offset, *ruleID,
330+
/*inverse=*/false));
279331
}
280332

281333
changed = true;
@@ -374,10 +426,12 @@ void RewriteSystem::simplifyRewriteSystem() {
374426

375427
// (2) Next, apply the original rule in reverse to produce the
376428
// original lhs.
377-
loop.add(RewriteStep(/*offset=*/0, ruleID, /*inverse=*/true));
429+
loop.add(RewriteStep::forRewriteRule(/*offset=*/0, ruleID,
430+
/*inverse=*/true));
378431

379432
// (3) Finally, apply the new rule to produce the simplified rhs.
380-
loop.add(RewriteStep(/*offset=*/0, newRuleID, /*inverse=*/false));
433+
loop.add(RewriteStep::forRewriteRule(/*offset=*/0, newRuleID,
434+
/*inverse=*/false));
381435

382436
if (Debug.contains(DebugFlags::Completion)) {
383437
llvm::dbgs() << "$ Right hand side simplification recorded a loop: ";
@@ -456,7 +510,14 @@ void RewriteSystem::verifyHomotopyGenerators() const {
456510
auto term = loop.first;
457511

458512
for (const auto &step : loop.second) {
459-
(void) step.apply(term, *this);
513+
switch (step.Kind) {
514+
case RewriteStep::ApplyRewriteRule:
515+
(void) step.applyRewriteRule(term, *this);
516+
break;
517+
case RewriteStep::AdjustConcreteType:
518+
(void) step.applyAdjustment(term, *this);
519+
break;
520+
}
460521
}
461522

462523
if (term != loop.first) {

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,19 @@ struct AppliedRewriteStep {
9696
/// Similarly, going in the other direction, if we start from A.Y.B and apply
9797
/// the inverse rule, we get A.(Y => X).B.
9898
struct RewriteStep {
99+
enum StepKind {
100+
/// Apply a rewrite rule at the stored offset.
101+
ApplyRewriteRule,
102+
103+
/// Prepend the prefix to each concrete substitution.
104+
AdjustConcreteType
105+
};
106+
107+
/// The rewrite step kind.
108+
unsigned Kind : 1;
109+
99110
/// The position within the term where the rule is being applied.
100-
unsigned Offset : 16;
111+
unsigned Offset : 15;
101112

102113
/// The index of the rule in the rewrite system.
103114
unsigned RuleID : 15;
@@ -106,20 +117,33 @@ struct RewriteStep {
106117
/// with the right hand side. If true, vice versa.
107118
unsigned Inverse : 1;
108119

109-
RewriteStep(unsigned offset, unsigned ruleID, bool inverse) {
120+
RewriteStep(StepKind kind, unsigned offset, unsigned ruleID, bool inverse) {
121+
Kind = unsigned(kind);
122+
110123
Offset = offset;
111124
assert(Offset == offset && "Overflow");
112125
RuleID = ruleID;
113126
assert(RuleID == ruleID && "Overflow");
114127
Inverse = inverse;
115128
}
116129

130+
static RewriteStep forRewriteRule(unsigned offset, unsigned ruleID, bool inverse) {
131+
return RewriteStep(ApplyRewriteRule, offset, ruleID, inverse);
132+
}
133+
134+
static RewriteStep forAdjustment(unsigned offset, bool inverse) {
135+
return RewriteStep(AdjustConcreteType, offset, /*ruleID=*/0, inverse);
136+
}
137+
117138
void invert() {
118139
Inverse = !Inverse;
119140
}
120141

121-
AppliedRewriteStep apply(MutableTerm &term,
122-
const RewriteSystem &system) const;
142+
AppliedRewriteStep applyRewriteRule(MutableTerm &term,
143+
const RewriteSystem &system) const;
144+
145+
MutableTerm applyAdjustment(MutableTerm &term,
146+
const RewriteSystem &system) const;
123147

124148
void dump(llvm::raw_ostream &out,
125149
MutableTerm &term,

lib/AST/RequirementMachine/RewriteSystemCompletion.cpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -391,13 +391,13 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
391391
RewritePath path;
392392

393393
// (1) First, apply the left hand side rule in the reverse direction.
394-
path.add(RewriteStep(/*offset=*/0,
395-
getRuleID(lhs),
396-
/*inverse=*/true));
394+
path.add(RewriteStep::forRewriteRule(/*offset=*/0,
395+
getRuleID(lhs),
396+
/*inverse=*/true));
397397
// (2) Now, apply the right hand side in the forward direction.
398-
path.add(RewriteStep(from - lhs.getLHS().begin(),
399-
getRuleID(rhs),
400-
/*inverse=*/false));
398+
path.add(RewriteStep::forRewriteRule(from - lhs.getLHS().begin(),
399+
getRuleID(rhs),
400+
/*inverse=*/false));
401401

402402
// If X == TYV, we have a trivial overlap.
403403
if (x == tyv) {
@@ -419,7 +419,8 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
419419
xv.append(rhs.getLHS().begin() + (lhs.getLHS().end() - from),
420420
rhs.getLHS().end());
421421

422-
if (xv.back().isSuperclassOrConcreteType()) {
422+
if (xv.back().isSuperclassOrConcreteType() &&
423+
lhs.getLHS().begin() != from) {
423424
xv.back() = xv.back().prependPrefixToConcreteSubstitutions(
424425
t, Context);
425426
}
@@ -431,13 +432,22 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
431432
RewritePath path;
432433

433434
// (1) First, apply the left hand side rule in the reverse direction.
434-
path.add(RewriteStep(/*offset=*/0,
435-
getRuleID(lhs),
436-
/*inverse=*/true));
437-
// (2) Now, apply the right hand side in the forward direction.
438-
path.add(RewriteStep(from - lhs.getLHS().begin(),
439-
getRuleID(rhs),
440-
/*inverse=*/false));
435+
path.add(RewriteStep::forRewriteRule(/*offset=*/0,
436+
getRuleID(lhs),
437+
/*inverse=*/true));
438+
439+
// (2) Next, if the right hand side rule ends with a concrete type symbol,
440+
// perform the concrete type adjustment.
441+
if (xv.back().isSuperclassOrConcreteType() &&
442+
lhs.getLHS().begin() != from) {
443+
path.add(RewriteStep::forAdjustment(from - lhs.getLHS().begin(),
444+
/*inverse=*/true));
445+
}
446+
447+
// (3) Finally, apply the right hand side in the forward direction.
448+
path.add(RewriteStep::forRewriteRule(from - lhs.getLHS().begin(),
449+
getRuleID(rhs),
450+
/*inverse=*/false));
441451

442452
// If XV == TY, we have a trivial overlap.
443453
if (xv == t) {

0 commit comments

Comments
 (0)