Skip to content

Commit c32ec95

Browse files
committed
RequirementMachine: Enforce that rewrite rules preserve the term's 'domain'
Intuitively, the first token of a rewrite rule determines if the rule applies to a protocol requirement signature 'Self', or a generic parameter in the top-level generic signature. A rewrite rule can never take a type starting with the protocol 'Self' to a type starting with a generic parameter, or vice versa. Enforce this by defining the notion of a 'domain', which is the set of protocols to which the first atom in a term applies to. This set can be empty (if we have a generic parameter atom), or it may contain more than one element (if we have an associated type atom for a merged associated type).
1 parent 6077aca commit c32ec95

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,16 @@ const ProtocolDecl *Atom::getProtocol() const {
139139
return Ptr->Proto;
140140
}
141141

142-
/// Get the list of protocols associated with an associated type atom.
142+
/// Get the list of protocols associated with a protocol or associated type
143+
/// atom. Note that if this is a protocol atom, the return value will have
144+
/// exactly one element.
143145
ArrayRef<const ProtocolDecl *> Atom::getProtocols() const {
144-
assert(getKind() == Kind::AssociatedType);
145146
auto protos = Ptr->getProtocols();
146-
assert(!protos.empty());
147+
if (protos.empty()) {
148+
assert(getKind() == Kind::Protocol);
149+
return {&Ptr->Proto, 1};
150+
}
151+
assert(getKind() == Kind::AssociatedType);
147152
return protos;
148153
}
149154

@@ -747,6 +752,35 @@ void Term::Storage::Profile(llvm::FoldingSetNodeID &id) const {
747752
id.AddPointer(atom.getOpaquePointer());
748753
}
749754

755+
/// Returns the "domain" of this term by looking at the first atom.
756+
///
757+
/// - If the first atom is a protocol atom [P], the domain is P.
758+
/// - If the first atom is an associated type atom [P1&...&Pn],
759+
/// the domain is {P1, ..., Pn}.
760+
/// - If the first atom is a generic parameter atom, the domain is
761+
/// the empty set {}.
762+
/// - Anything else will assert.
763+
ArrayRef<const ProtocolDecl *> MutableTerm::getRootProtocols() const {
764+
auto atom = *begin();
765+
766+
switch (atom.getKind()) {
767+
case Atom::Kind::Protocol:
768+
case Atom::Kind::AssociatedType:
769+
return atom.getProtocols();
770+
771+
case Atom::Kind::GenericParam:
772+
return ArrayRef<const ProtocolDecl *>();
773+
774+
case Atom::Kind::Name:
775+
case Atom::Kind::Layout:
776+
case Atom::Kind::Superclass:
777+
case Atom::Kind::ConcreteType:
778+
break;
779+
}
780+
781+
llvm_unreachable("Bad root atom");
782+
}
783+
750784
/// Linear order on terms.
751785
///
752786
/// First we compare length, then perform a lexicographic comparison
@@ -1253,7 +1287,8 @@ void RewriteSystem::simplifyRightHandSides() {
12531287

12541288
#define ASSERT_RULE(expr) \
12551289
if (!(expr)) { \
1256-
llvm::errs() << "&&& Malformed rewrite rule: " << rule << "\n\n"; \
1290+
llvm::errs() << "&&& Malformed rewrite rule: " << rule << "\n"; \
1291+
llvm::errs() << "&&& " << #expr << "\n\n"; \
12571292
dump(llvm::errs()); \
12581293
assert(expr); \
12591294
}
@@ -1298,6 +1333,11 @@ void RewriteSystem::simplifyRightHandSides() {
12981333
ASSERT_RULE(atom.getKind() != Atom::Kind::Protocol);
12991334
}
13001335
}
1336+
1337+
auto lhsDomain = lhs.getRootProtocols();
1338+
auto rhsDomain = rhs.getRootProtocols();
1339+
1340+
ASSERT_RULE(lhsDomain == rhsDomain);
13011341
}
13021342

13031343
#undef ASSERT_RULE

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,8 @@ class MutableTerm final {
340340

341341
size_t size() const { return Atoms.size(); }
342342

343+
ArrayRef<const ProtocolDecl *> getRootProtocols() const;
344+
343345
decltype(Atoms)::const_iterator begin() const { return Atoms.begin(); }
344346
decltype(Atoms)::const_iterator end() const { return Atoms.end(); }
345347

0 commit comments

Comments
 (0)