Skip to content

Commit 92ac06a

Browse files
committed
RequirementMachine: Rules store uniqued Terms
1 parent ed966e7 commit 92ac06a

File tree

8 files changed

+86
-73
lines changed

8 files changed

+86
-73
lines changed

lib/AST/RequirementMachine/PropertyMap.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,17 +1063,22 @@ RewriteSystem::buildPropertyMap(PropertyMap &map,
10631063
if (rule.isDeleted())
10641064
continue;
10651065

1066-
const auto &lhs = rule.getLHS();
1066+
auto lhs = rule.getLHS();
1067+
auto rhs = rule.getRHS();
10671068

10681069
// Collect all rules of the form T.[p] => T where T is canonical.
10691070
auto property = lhs.back();
10701071
if (!property.isProperty())
10711072
continue;
10721073

1073-
MutableTerm key(lhs.begin(), lhs.end() - 1);
1074-
if (key != rule.getRHS())
1074+
if (lhs.size() - 1 != rhs.size())
10751075
continue;
10761076

1077+
if (!std::equal(rhs.begin(), rhs.end(), lhs.begin()))
1078+
continue;
1079+
1080+
MutableTerm key(rhs);
1081+
10771082
#ifndef NDEBUG
10781083
assert(!simplify(key) &&
10791084
"Right hand side of a property rule should already be reduced");

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
8989
}
9090

9191
unsigned i = Rules.size();
92-
Rules.emplace_back(lhs, rhs);
92+
Rules.emplace_back(Term::get(lhs, Context), Term::get(rhs, Context));
9393

9494
// Check if we have a rule of the form
9595
//
@@ -171,9 +171,11 @@ void RewriteSystem::simplifyRightHandSides() {
171171
if (rule.isDeleted())
172172
continue;
173173

174-
auto rhs = rule.getRHS();
175-
simplify(rhs);
176-
rule = Rule(rule.getLHS(), rhs);
174+
MutableTerm rhs(rule.getRHS());
175+
if (!simplify(rhs))
176+
continue;
177+
178+
rule = Rule(rule.getLHS(), Term::get(rhs, Context));
177179
}
178180

179181
#ifndef NDEBUG

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,16 @@ class RewriteContext;
3636
///
3737
/// Out-of-line methods are documented in RewriteSystem.cpp.
3838
class Rule final {
39-
MutableTerm LHS;
40-
MutableTerm RHS;
39+
Term LHS;
40+
Term RHS;
4141
bool deleted;
4242

4343
public:
44-
Rule(const MutableTerm &lhs, const MutableTerm &rhs)
44+
Rule(Term lhs, Term rhs)
4545
: LHS(lhs), RHS(rhs), deleted(false) {}
4646

47-
const MutableTerm &getLHS() const { return LHS; }
48-
const MutableTerm &getRHS() const { return RHS; }
47+
const Term &getLHS() const { return LHS; }
48+
const Term &getRHS() const { return RHS; }
4949

5050
bool apply(MutableTerm &term) const {
5151
return term.rewriteSubTerm(LHS, RHS);
@@ -81,12 +81,6 @@ class Rule final {
8181
return LHS.size();
8282
}
8383

84-
/// Partial order on rules orders rules by their left hand side.
85-
int compare(const Rule &other,
86-
const ProtocolGraph &protos) const {
87-
return LHS.compare(other.LHS, protos);
88-
}
89-
9084
void dump(llvm::raw_ostream &out) const;
9185

9286
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &out,

lib/AST/RequirementMachine/RewriteSystemCompletion.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,9 @@ Symbol Symbol::prependPrefixToConcreteSubstitutions(
8484
/// Note that this relation is not commutative; we need to check
8585
/// for overlap between both (X and Y) and (Y and X).
8686
OverlapKind
87-
MutableTerm::checkForOverlap(const MutableTerm &other,
88-
MutableTerm &t,
89-
MutableTerm &v) const {
90-
assert(!empty());
91-
assert(!other.empty());
92-
87+
Term::checkForOverlap(Term other,
88+
MutableTerm &t,
89+
MutableTerm &v) const {
9390
if (*this == other) {
9491
// If this term is equal to the other term, we have an overlap.
9592
t = MutableTerm();
@@ -410,7 +407,7 @@ RewriteSystem::computeCriticalPair(const Rule &lhs, const Rule &rhs) const {
410407
// Compute the term TYV.
411408
t.append(rhs.getRHS());
412409
t.append(v);
413-
return std::make_pair(lhs.getRHS(), t);
410+
return std::make_pair(MutableTerm(lhs.getRHS()), t);
414411
}
415412

416413
case OverlapKind::Second: {

lib/AST/RequirementMachine/Symbol.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,34 @@ Symbol Symbol::forConcreteType(CanType type, ArrayRef<Term> substitutions,
418418
return symbol;
419419
}
420420

421+
/// Given that this symbol is the first symbol of a term, return the
422+
/// "domain" of the term.
423+
///
424+
/// - If the first symbol is a protocol symbol [P], the domain is P.
425+
/// - If the first symbol is an associated type symbol [P1&...&Pn],
426+
/// the domain is {P1, ..., Pn}.
427+
/// - If the first symbol is a generic parameter symbol, the domain is
428+
/// the empty set {}.
429+
/// - Anything else will assert.
430+
ArrayRef<const ProtocolDecl *> Symbol::getRootProtocols() const {
431+
switch (getKind()) {
432+
case Symbol::Kind::Protocol:
433+
case Symbol::Kind::AssociatedType:
434+
return getProtocols();
435+
436+
case Symbol::Kind::GenericParam:
437+
return ArrayRef<const ProtocolDecl *>();
438+
439+
case Symbol::Kind::Name:
440+
case Symbol::Kind::Layout:
441+
case Symbol::Kind::Superclass:
442+
case Symbol::Kind::ConcreteType:
443+
break;
444+
}
445+
446+
llvm_unreachable("Bad root symbol");
447+
}
448+
421449
/// Linear order on symbols.
422450
///
423451
/// First, we order different kinds as follows, from smallest to largest:

lib/AST/RequirementMachine/Symbol.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ class Symbol final {
191191
ArrayRef<Term> substitutions,
192192
RewriteContext &ctx);
193193

194+
ArrayRef<const ProtocolDecl *> getRootProtocols() const;
195+
194196
int compare(Symbol other, const ProtocolGraph &protos) const;
195197

196198
Symbol transformConcreteSubstitutions(

lib/AST/RequirementMachine/Term.cpp

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -106,42 +106,22 @@ Term Term::get(const MutableTerm &mutableTerm, RewriteContext &ctx) {
106106
return term;
107107
}
108108

109+
/// Find the start of \p other in this term, returning end() if
110+
/// \p other does not occur as a subterm of this term.
111+
ArrayRef<Symbol>::iterator Term::findSubTerm(Term other) const {
112+
if (other.size() > size())
113+
return end();
114+
115+
return std::search(begin(), end(), other.begin(), other.end());
116+
}
117+
109118
void Term::Storage::Profile(llvm::FoldingSetNodeID &id) const {
110119
id.AddInteger(Size);
111120

112121
for (auto symbol : getElements())
113122
id.AddPointer(symbol.getOpaquePointer());
114123
}
115124

116-
/// Returns the "domain" of this term by looking at the first symbol.
117-
///
118-
/// - If the first symbol is a protocol symbol [P], the domain is P.
119-
/// - If the first symbol is an associated type symbol [P1&...&Pn],
120-
/// the domain is {P1, ..., Pn}.
121-
/// - If the first symbol is a generic parameter symbol, the domain is
122-
/// the empty set {}.
123-
/// - Anything else will assert.
124-
ArrayRef<const ProtocolDecl *> MutableTerm::getRootProtocols() const {
125-
auto symbol = *begin();
126-
127-
switch (symbol.getKind()) {
128-
case Symbol::Kind::Protocol:
129-
case Symbol::Kind::AssociatedType:
130-
return symbol.getProtocols();
131-
132-
case Symbol::Kind::GenericParam:
133-
return ArrayRef<const ProtocolDecl *>();
134-
135-
case Symbol::Kind::Name:
136-
case Symbol::Kind::Layout:
137-
case Symbol::Kind::Superclass:
138-
case Symbol::Kind::ConcreteType:
139-
break;
140-
}
141-
142-
llvm_unreachable("Bad root symbol");
143-
}
144-
145125
/// Shortlex order on terms.
146126
///
147127
/// First we compare length, then perform a lexicographic comparison
@@ -170,7 +150,7 @@ int MutableTerm::compare(const MutableTerm &other,
170150
/// Find the start of \p other in this term, returning end() if
171151
/// \p other does not occur as a subterm of this term.
172152
decltype(MutableTerm::Symbols)::const_iterator
173-
MutableTerm::findSubTerm(const MutableTerm &other) const {
153+
MutableTerm::findSubTerm(Term other) const {
174154
if (other.size() > size())
175155
return end();
176156

@@ -179,7 +159,7 @@ MutableTerm::findSubTerm(const MutableTerm &other) const {
179159

180160
/// Non-const variant of the above.
181161
decltype(MutableTerm::Symbols)::iterator
182-
MutableTerm::findSubTerm(const MutableTerm &other) {
162+
MutableTerm::findSubTerm(Term other) {
183163
if (other.size() > size())
184164
return end();
185165

@@ -191,8 +171,7 @@ MutableTerm::findSubTerm(const MutableTerm &other) {
191171
/// order on terms. Returns true if the term contained \p lhs;
192172
/// otherwise returns false, in which case the term remains
193173
/// unchanged.
194-
bool MutableTerm::rewriteSubTerm(const MutableTerm &lhs,
195-
const MutableTerm &rhs) {
174+
bool MutableTerm::rewriteSubTerm(Term lhs, Term rhs) {
196175
// Find the start of lhs in this term.
197176
auto found = findSubTerm(lhs);
198177

lib/AST/RequirementMachine/Term.h

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,21 @@ class Term final {
7777

7878
static Term get(const MutableTerm &term, RewriteContext &ctx);
7979

80+
OverlapKind checkForOverlap(Term other,
81+
MutableTerm &t,
82+
MutableTerm &v) const;
83+
84+
ArrayRef<Symbol>::iterator findSubTerm(Term other) const;
85+
86+
/// Returns true if this term contains, or is equal to, \p other.
87+
bool containsSubTerm(Term other) const {
88+
return findSubTerm(other) != end();
89+
}
90+
91+
ArrayRef<const ProtocolDecl *> getRootProtocols() const {
92+
return begin()->getRootProtocols();
93+
}
94+
8095
void dump(llvm::raw_ostream &out) const;
8196

8297
friend bool operator==(Term lhs, Term rhs) {
@@ -144,7 +159,9 @@ class MutableTerm final {
144159

145160
size_t size() const { return Symbols.size(); }
146161

147-
ArrayRef<const ProtocolDecl *> getRootProtocols() const;
162+
ArrayRef<const ProtocolDecl *> getRootProtocols() const {
163+
return begin()->getRootProtocols();
164+
}
148165

149166
decltype(Symbols)::const_iterator begin() const { return Symbols.begin(); }
150167
decltype(Symbols)::const_iterator end() const { return Symbols.end(); }
@@ -174,22 +191,11 @@ class MutableTerm final {
174191
return Symbols[index];
175192
}
176193

177-
decltype(Symbols)::const_iterator findSubTerm(
178-
const MutableTerm &other) const;
194+
decltype(Symbols)::const_iterator findSubTerm(Term other) const;
179195

180-
decltype(Symbols)::iterator findSubTerm(
181-
const MutableTerm &other);
196+
decltype(Symbols)::iterator findSubTerm(Term other);
182197

183-
/// Returns true if this term contains, or is equal to, \p other.
184-
bool containsSubTerm(const MutableTerm &other) const {
185-
return findSubTerm(other) != end();
186-
}
187-
188-
bool rewriteSubTerm(const MutableTerm &lhs, const MutableTerm &rhs);
189-
190-
OverlapKind checkForOverlap(const MutableTerm &other,
191-
MutableTerm &t,
192-
MutableTerm &v) const;
198+
bool rewriteSubTerm(Term lhs, Term rhs);
193199

194200
void dump(llvm::raw_ostream &out) const;
195201

0 commit comments

Comments
 (0)