Skip to content

Commit b2a21b3

Browse files
committed
RequirementMachine: Preliminary plumbing for computing homotopy generators
1 parent e4f6128 commit b2a21b3

File tree

3 files changed

+174
-26
lines changed

3 files changed

+174
-26
lines changed

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 94 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,70 @@
2323
using namespace swift;
2424
using namespace rewriting;
2525

26+
void Rule::dump(llvm::raw_ostream &out) const {
27+
out << LHS << " => " << RHS;
28+
if (deleted)
29+
out << " [deleted]";
30+
}
31+
32+
void RewritePath::invert() {
33+
std::reverse(Steps.begin(), Steps.end());
34+
35+
for (auto &step : Steps)
36+
step.invert();
37+
}
38+
39+
/// Dumps the rewrite step that was applied to \p term. Mutates \p term to
40+
/// reflect the application of the rule.
41+
void RewriteStep::dump(llvm::raw_ostream &out,
42+
MutableTerm &term,
43+
const RewriteSystem &system) const {
44+
const auto &rule = system.getRule(RuleID);
45+
46+
auto lhs = (Inverse ? rule.getRHS() : rule.getLHS());
47+
auto rhs = (Inverse ? rule.getLHS() : rule.getRHS());
48+
49+
assert(std::equal(term.begin() + Offset,
50+
term.begin() + Offset + lhs.size(),
51+
lhs.begin()));
52+
53+
MutableTerm prefix(term.begin(), term.begin() + Offset);
54+
MutableTerm suffix(term.begin() + Offset + lhs.size(), term.end());
55+
56+
if (!prefix.empty()) {
57+
out << prefix;
58+
out << ".";
59+
}
60+
out << "(" << rule.getLHS();
61+
out << (Inverse ? " <= " : " => ");
62+
out << rule.getRHS() << ")";
63+
if (!suffix.empty()) {
64+
out << ".";
65+
out << suffix;
66+
}
67+
68+
term = prefix;
69+
term.append(rhs);
70+
term.append(suffix);
71+
}
72+
73+
/// Dumps a series of rewrite steps applied to \p term.
74+
void RewritePath::dump(llvm::raw_ostream &out,
75+
MutableTerm term,
76+
const RewriteSystem &system) const {
77+
bool first = true;
78+
79+
for (const auto &step : Steps) {
80+
if (!first) {
81+
out << "";
82+
} else {
83+
first = false;
84+
}
85+
86+
step.dump(out, term, system);
87+
}
88+
}
89+
2690
RewriteSystem::RewriteSystem(RewriteContext &ctx)
2791
: Context(ctx), Debug(ctx.getDebugOptions()) {}
2892

@@ -31,12 +95,6 @@ RewriteSystem::~RewriteSystem() {
3195
Context.RuleTrieRootHistogram);
3296
}
3397

34-
void Rule::dump(llvm::raw_ostream &out) const {
35-
out << LHS << " => " << RHS;
36-
if (deleted)
37-
out << " [deleted]";
38-
}
39-
4098
void RewriteSystem::initialize(
4199
std::vector<std::pair<MutableTerm, MutableTerm>> &&rules,
42100
ProtocolGraph &&graph) {
@@ -89,16 +147,16 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
89147
llvm::dbgs() << "## Simplified and oriented rule " << lhs << " => " << rhs << "\n\n";
90148
}
91149

92-
unsigned i = Rules.size();
150+
unsigned newRuleID = Rules.size();
93151

94152
auto uniquedLHS = Term::get(lhs, Context);
95153
auto uniquedRHS = Term::get(rhs, Context);
96154
Rules.emplace_back(uniquedLHS, uniquedRHS);
97155

98-
auto oldRuleID = Trie.insert(lhs.begin(), lhs.end(), i);
156+
auto oldRuleID = Trie.insert(lhs.begin(), lhs.end(), newRuleID);
99157
if (oldRuleID) {
100158
llvm::errs() << "Duplicate rewrite rule!\n";
101-
const auto &oldRule = Rules[*oldRuleID];
159+
const auto &oldRule = getRule(*oldRuleID);
102160
llvm::errs() << "Old rule #" << *oldRuleID << ": ";
103161
oldRule.dump(llvm::errs());
104162
llvm::errs() << "\nTrying to replay what happened when I simplified this term:\n";
@@ -116,11 +174,18 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
116174
}
117175

118176
/// Reduce a term by applying all rewrite rules until fixed point.
119-
bool RewriteSystem::simplify(MutableTerm &term) const {
177+
///
178+
/// If \p path is non-null, records the series of rewrite steps taken.
179+
bool RewriteSystem::simplify(MutableTerm &term, RewritePath *path) const {
120180
bool changed = false;
121181

182+
MutableTerm original;
183+
RewritePath forDebug;
122184
if (Debug.contains(DebugFlags::Simplify)) {
123-
llvm::dbgs() << "= Term " << term << "\n";
185+
186+
original = term;
187+
if (!path)
188+
path = &forDebug;
124189
}
125190

126191
while (true) {
@@ -131,19 +196,16 @@ bool RewriteSystem::simplify(MutableTerm &term) const {
131196
while (from < end) {
132197
auto ruleID = Trie.find(from, end);
133198
if (ruleID) {
134-
const auto &rule = Rules[*ruleID];
199+
const auto &rule = getRule(*ruleID);
135200
if (!rule.isDeleted()) {
136-
if (Debug.contains(DebugFlags::Simplify)) {
137-
llvm::dbgs() << "== Rule #" << *ruleID << ": " << rule << "\n";
138-
}
139-
140201
auto to = from + rule.getLHS().size();
141202
assert(std::equal(from, to, rule.getLHS().begin()));
142203

143204
term.rewriteSubTerm(from, to, rule.getRHS());
144205

145-
if (Debug.contains(DebugFlags::Simplify)) {
146-
llvm::dbgs() << "=== Result " << term << "\n";
206+
if (path) {
207+
unsigned offset = (unsigned)(from - term.begin());
208+
path->add(RewriteStep(offset, *ruleID, /*inverse=*/false));
147209
}
148210

149211
changed = true;
@@ -159,6 +221,17 @@ bool RewriteSystem::simplify(MutableTerm &term) const {
159221
break;
160222
}
161223

224+
if (Debug.contains(DebugFlags::Simplify)) {
225+
if (changed) {
226+
llvm::dbgs() << "= Simplified " << term << ": ";
227+
forDebug.dump(llvm::dbgs(), original, *this);
228+
llvm::dbgs() << "\n";
229+
} else {
230+
llvm::dbgs() << "= Irreducible term: " << term << "\n";
231+
}
232+
}
233+
234+
assert(path == nullptr || changed != path->empty());
162235
return changed;
163236
}
164237

@@ -170,7 +243,7 @@ bool RewriteSystem::simplify(MutableTerm &term) const {
170243
/// rules is only valid to perform if the rewrite system is confluent.
171244
void RewriteSystem::simplifyRewriteSystem() {
172245
for (auto ruleID : indices(Rules)) {
173-
auto &rule = Rules[ruleID];
246+
auto &rule = getRule(ruleID);
174247
if (rule.isDeleted())
175248
continue;
176249

@@ -186,11 +259,11 @@ void RewriteSystem::simplifyRewriteSystem() {
186259
continue;
187260

188261
// Ignore other deleted rules.
189-
if (Rules[*otherRuleID].isDeleted())
262+
if (getRule(*otherRuleID).isDeleted())
190263
continue;
191264

192265
if (Debug.contains(DebugFlags::Completion)) {
193-
const auto &otherRule = Rules[ruleID];
266+
const auto &otherRule = getRule(ruleID);
194267
llvm::dbgs() << "$ Deleting rule " << rule << " because "
195268
<< "its left hand side contains " << otherRule
196269
<< "\n";

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ namespace rewriting {
3131

3232
class PropertyMap;
3333
class RewriteContext;
34+
class RewriteSystem;
3435

3536
/// A rewrite rule that replaces occurrences of LHS with RHS.
3637
///
@@ -78,6 +79,67 @@ class Rule final {
7879
}
7980
};
8081

82+
/// Records the application of a rewrite rule to a term.
83+
struct RewriteStep {
84+
/// The position within the term where the rule is being applied.
85+
unsigned Offset : 16;
86+
87+
/// The index of the rule in the rewrite system.
88+
unsigned RuleID : 15;
89+
90+
/// If false, the step replaces an occurrence of the rule's left hand side
91+
/// with the right hand side. If true, vice versa.
92+
unsigned Inverse : 1;
93+
94+
RewriteStep(unsigned offset, unsigned ruleID, bool inverse) {
95+
Offset = offset;
96+
assert(Offset == offset && "Overflow");
97+
RuleID = ruleID;
98+
assert(RuleID == ruleID && "Overflow");
99+
Inverse = inverse;
100+
}
101+
102+
void invert() {
103+
Inverse = !Inverse;
104+
}
105+
106+
void dump(llvm::raw_ostream &out,
107+
MutableTerm &term,
108+
const RewriteSystem &system) const;
109+
};
110+
111+
/// Records a sequence of zero or more rewrite rules applied to a term.
112+
struct RewritePath {
113+
SmallVector<RewriteStep, 3> Steps;
114+
115+
bool empty() const {
116+
return Steps.empty();
117+
}
118+
119+
void add(RewriteStep step) {
120+
Steps.push_back(step);
121+
}
122+
123+
// Horizontal composition of paths.
124+
void append(RewritePath other) {
125+
Steps.append(other.begin(), other.end());
126+
}
127+
128+
decltype(Steps)::const_iterator begin() const {
129+
return Steps.begin();
130+
}
131+
132+
decltype(Steps)::const_iterator end() const {
133+
return Steps.end();
134+
}
135+
136+
void invert();
137+
138+
void dump(llvm::raw_ostream &out,
139+
MutableTerm term,
140+
const RewriteSystem &system) const;
141+
};
142+
81143
/// A term rewrite system for working with types in a generic signature.
82144
///
83145
/// Out-of-line methods are documented in RewriteSystem.cpp.
@@ -141,9 +203,22 @@ class RewriteSystem final {
141203

142204
Symbol simplifySubstitutionsInSuperclassOrConcreteSymbol(Symbol symbol) const;
143205

206+
unsigned getRuleID(const Rule &rule) const {
207+
assert((unsigned)(&rule - &*Rules.begin()) < Rules.size());
208+
return (unsigned)(&rule - &*Rules.begin());
209+
}
210+
211+
Rule &getRule(unsigned ruleID) {
212+
return Rules[ruleID];
213+
}
214+
215+
const Rule &getRule(unsigned ruleID) const {
216+
return Rules[ruleID];
217+
}
218+
144219
bool addRule(MutableTerm lhs, MutableTerm rhs);
145220

146-
bool simplify(MutableTerm &term) const;
221+
bool simplify(MutableTerm &term, RewritePath *path=nullptr) const;
147222

148223
enum class CompletionResult {
149224
/// Confluent completion was computed successfully.

lib/AST/RequirementMachine/RewriteSystemCompletion.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ void RewriteSystem::processMergedAssociatedTypes() {
235235

236236
// Look for conformance requirements on [P1:T] and [P2:T].
237237
auto visitRule = [&](unsigned ruleID) {
238-
const auto &otherRule = Rules[ruleID];
238+
const auto &otherRule = getRule(ruleID);
239239
const auto &otherLHS = otherRule.getLHS();
240240
if (otherLHS.size() == 2 &&
241241
otherLHS[1].getKind() == Symbol::Kind::Protocol) {
@@ -440,7 +440,7 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
440440

441441
// For every rule, looking for other rules that overlap with this rule.
442442
for (unsigned i = 0, e = Rules.size(); i < e; ++i) {
443-
const auto &lhs = Rules[i];
443+
const auto &lhs = getRule(i);
444444
if (lhs.isDeleted())
445445
continue;
446446

@@ -458,7 +458,7 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
458458
if (!CheckedOverlaps.insert(std::make_pair(i, j)).second)
459459
return;
460460

461-
const auto &rhs = Rules[j];
461+
const auto &rhs = getRule(j);
462462
if (rhs.isDeleted())
463463
return;
464464

@@ -497,7 +497,7 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
497497
if (Debug.contains(DebugFlags::Completion)) {
498498
llvm::dbgs() << "$ Trivially overlapping rules: (#" << i << ") ";
499499
llvm::dbgs() << lhs << "\n";
500-
llvm::dbgs() << " -vs- (#" << j << ") ";
500+
llvm::dbgs() << " -vs- (#" << j << ") ";
501501
llvm::dbgs() << rhs << ":\n";
502502
}
503503
}

0 commit comments

Comments
 (0)