Skip to content

Commit 0d0bcb2

Browse files
committed
RequirementMachine: Simplify the Symbol API for removal of merged associated types
1 parent e2e088e commit 0d0bcb2

13 files changed

+84
-236
lines changed

lib/AST/RequirementMachine/ConcreteTypeWitness.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ void PropertyMap::concretizeNestedTypesFromConcreteParent(
200200

201201
// We only infer conditional requirements in top-level generic signatures,
202202
// not in protocol requirement signatures.
203-
if (key.getRootProtocols().empty())
203+
if (key.getRootProtocol() == nullptr)
204204
inferConditionalRequirements(concrete, substitutions);
205205
}
206206
}

lib/AST/RequirementMachine/GenericSignatureQueries.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -227,12 +227,10 @@ RequirementMachine::getLongestValidPrefix(const MutableTerm &term) const {
227227

228228
auto conformsTo = props->getConformsTo();
229229

230-
for (const auto *proto : symbol.getProtocols()) {
231-
// T.[P:A] is valid iff T conforms to P.
232-
if (std::find(conformsTo.begin(), conformsTo.end(), proto)
233-
== conformsTo.end())
234-
return prefix;
235-
}
230+
// T.[P:A] is valid iff T conforms to P.
231+
if (std::find(conformsTo.begin(), conformsTo.end(), symbol.getProtocol())
232+
== conformsTo.end())
233+
return prefix;
236234

237235
break;
238236
}
@@ -720,7 +718,7 @@ void RequirementMachine::verify(const MutableTerm &term) const {
720718
continue;
721719

722720
case Symbol::Kind::AssociatedType:
723-
erased.add(Symbol::forProtocol(symbol.getProtocols()[0], Context));
721+
erased.add(Symbol::forProtocol(symbol.getProtocol(), Context));
724722
break;
725723

726724
case Symbol::Kind::Name:

lib/AST/RequirementMachine/HomotopyReduction.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -651,10 +651,7 @@ RewriteSystem::getMinimizedProtocolRules() const {
651651
continue;
652652
}
653653

654-
auto domain = rule.getLHS()[0].getProtocols();
655-
assert(domain.size() == 1);
656-
657-
const auto *proto = domain[0];
654+
const auto *proto = rule.getLHS().begin()->getProtocol();
658655
if (std::find(Protos.begin(), Protos.end(), proto) != Protos.end())
659656
rules[proto].push_back(ruleID);
660657
}
@@ -734,7 +731,7 @@ void RewriteSystem::verifyMinimizedRules(
734731
const auto &rule = getRule(ruleID);
735732

736733
// Ignore the rewrite rule if it is not part of our minimization domain.
737-
if (!isInMinimizationDomain(rule.getLHS().getRootProtocols()))
734+
if (!isInMinimizationDomain(rule.getLHS().getRootProtocol()))
738735
continue;
739736

740737
// Note that sometimes permanent rules can be simplified, but they can never

lib/AST/RequirementMachine/InterfaceType.cpp

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -177,26 +177,26 @@ AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol(Symbol symbol) {
177177
// The associated type An' is then the canonical associated type
178178
// representative of the associated type symbol [P0&...&Pn:A].
179179
//
180-
for (auto *proto : symbol.getProtocols()) {
181-
auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
182-
otherAssocType = otherAssocType->getAssociatedTypeAnchor();
180+
auto *proto = symbol.getProtocol();
183181

184-
if (otherAssocType->getName() == name &&
185-
(assocType == nullptr ||
186-
TypeDecl::compare(otherAssocType->getProtocol(),
187-
assocType->getProtocol()) < 0)) {
188-
assocType = otherAssocType;
189-
}
190-
};
182+
auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
183+
otherAssocType = otherAssocType->getAssociatedTypeAnchor();
191184

192-
for (auto *otherAssocType : proto->getAssociatedTypeMembers()) {
193-
checkOtherAssocType(otherAssocType);
185+
if (otherAssocType->getName() == name &&
186+
(assocType == nullptr ||
187+
TypeDecl::compare(otherAssocType->getProtocol(),
188+
assocType->getProtocol()) < 0)) {
189+
assocType = otherAssocType;
194190
}
191+
};
195192

196-
for (auto *inheritedProto : getInheritedProtocols(proto)) {
197-
for (auto *otherAssocType : inheritedProto->getAssociatedTypeMembers()) {
198-
checkOtherAssocType(otherAssocType);
199-
}
193+
for (auto *otherAssocType : proto->getAssociatedTypeMembers()) {
194+
checkOtherAssocType(otherAssocType);
195+
}
196+
197+
for (auto *inheritedProto : getInheritedProtocols(proto)) {
198+
for (auto *otherAssocType : inheritedProto->getAssociatedTypeMembers()) {
199+
checkOtherAssocType(otherAssocType);
200200
}
201201
}
202202

@@ -318,8 +318,8 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
318318
#ifndef NDEBUG
319319
// Ensure that the domain of the suffix contains P.
320320
if (iter + 1 < end) {
321-
auto protos = (iter + 1)->getProtocols();
322-
assert(std::find(protos.begin(), protos.end(), symbol.getProtocol()));
321+
auto proto = (iter + 1)->getProtocol();
322+
assert(proto == symbol.getProtocol());
323323
}
324324
#endif
325325
continue;
@@ -350,10 +350,9 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
350350
// of protocols that the prefix conforms to.
351351
#ifndef NDEBUG
352352
auto conformsTo = props->getConformsTo();
353-
for (auto *otherProto : symbol.getProtocols()) {
354-
assert(std::find(conformsTo.begin(), conformsTo.end(), otherProto)
355-
!= conformsTo.end());
356-
}
353+
assert(std::find(conformsTo.begin(), conformsTo.end(),
354+
symbol.getProtocol())
355+
!= conformsTo.end());
357356
#endif
358357

359358
assocType = props->getAssociatedType(symbol.getName());

lib/AST/RequirementMachine/MinimalConformances.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,7 @@ MinimalConformances::decomposeTermIntoConformanceRuleLeftHandSides(
309309

310310
// Compute domain(V).
311311
const auto &lhs = rule.getLHS();
312-
auto protocols = lhs[0].getProtocols();
313-
assert(protocols.size() == 1);
314-
auto protocol = Symbol::forProtocol(protocols[0], Context);
312+
auto protocol = Symbol::forProtocol(lhs[0].getProtocol(), Context);
315313

316314
// A same-type requirement of the form 'Self.Foo == Self' can induce a
317315
// conformance rule [P].[P] => [P], and we can end up with a minimal
@@ -356,10 +354,7 @@ static const ProtocolDecl *getParentConformanceForTerm(Term lhs) {
356354

357355
// If we have a rule of the form X.[P:Y].[Q] => X.[P:Y] wih non-empty X,
358356
// then the parent type is X.[P].
359-
const auto protos = parentSymbol.getProtocols();
360-
assert(protos.size() == 1);
361-
362-
return protos[0];
357+
return parentSymbol.getProtocol();
363358
}
364359

365360
case Symbol::Kind::GenericParam:
@@ -401,7 +396,7 @@ void MinimalConformances::collectConformanceRules() {
401396
if (!rule.isAnyConformanceRule())
402397
continue;
403398

404-
if (!System.isInMinimizationDomain(rule.getLHS().getRootProtocols()))
399+
if (!System.isInMinimizationDomain(rule.getLHS().getRootProtocol()))
405400
continue;
406401

407402
ConformanceRules.push_back(ruleID);

lib/AST/RequirementMachine/PropertyRelations.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,8 @@ unsigned RewriteSystem::recordConcreteTypeWitnessRelation(
101101
Symbol::Kind::ConcreteConformance);
102102
assert(associatedTypeSymbol.getKind() ==
103103
Symbol::Kind::AssociatedType);
104-
assert(associatedTypeSymbol.getProtocols().size() == 1);
105104
assert(concreteConformanceSymbol.getProtocol() ==
106-
associatedTypeSymbol.getProtocols()[0]);
105+
associatedTypeSymbol.getProtocol());
107106
assert(typeWitnessSymbol.getKind() == Symbol::Kind::ConcreteType);
108107

109108
MutableTerm rhsTerm;

lib/AST/RequirementMachine/RewriteContext.cpp

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -99,39 +99,10 @@ RewriteContext::getInheritedProtocols(const ProtocolDecl *proto) {
9999
return result;
100100
}
101101

102-
unsigned RewriteContext::getProtocolSupport(
103-
const ProtocolDecl *proto) {
104-
return getInheritedProtocols(proto).size() + 1;
105-
}
106-
107-
unsigned RewriteContext::getProtocolSupport(
108-
ArrayRef<const ProtocolDecl *> protos) {
109-
auto found = Support.find(protos);
110-
if (found != Support.end())
111-
return found->second;
112-
113-
unsigned result;
114-
if (protos.size() == 1) {
115-
result = getProtocolSupport(protos[0]);
116-
} else {
117-
llvm::DenseSet<const ProtocolDecl *> visited;
118-
for (const auto *proto : protos) {
119-
visited.insert(proto);
120-
for (const auto *inheritedProto : getInheritedProtocols(proto))
121-
visited.insert(inheritedProto);
122-
}
123-
124-
result = visited.size();
125-
}
126-
127-
Support[protos] = result;
128-
return result;
129-
}
130-
131102
int RewriteContext::compareProtocols(const ProtocolDecl *lhs,
132103
const ProtocolDecl *rhs) {
133-
unsigned lhsSupport = getProtocolSupport(lhs);
134-
unsigned rhsSupport = getProtocolSupport(rhs);
104+
unsigned lhsSupport = getInheritedProtocols(lhs).size();
105+
unsigned rhsSupport = getInheritedProtocols(rhs).size();
135106

136107
if (lhsSupport != rhsSupport)
137108
return rhsSupport - lhsSupport;

lib/AST/RequirementMachine/RewriteContext.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,6 @@ class RewriteContext final {
5252
llvm::DenseMap<const ProtocolDecl *,
5353
llvm::TinyPtrVector<const ProtocolDecl *>> AllInherited;
5454

55-
/// Cached support of sets of protocols, which is the number of elements in
56-
/// the transitive closure of the set under protocol inheritance.
57-
llvm::DenseMap<ArrayRef<const ProtocolDecl *>, unsigned> Support;
58-
5955
/// Cache for associated type declarations.
6056
llvm::DenseMap<Symbol, AssociatedTypeDecl *> AssocTypes;
6157

@@ -148,10 +144,6 @@ class RewriteContext final {
148144
const llvm::TinyPtrVector<const ProtocolDecl *> &
149145
getInheritedProtocols(const ProtocolDecl *proto);
150146

151-
unsigned getProtocolSupport(const ProtocolDecl *proto);
152-
153-
unsigned getProtocolSupport(ArrayRef<const ProtocolDecl *> protos);
154-
155147
int compareProtocols(const ProtocolDecl *lhs,
156148
const ProtocolDecl *rhs);
157149

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -647,15 +647,13 @@ void RewriteSystem::simplifyLeftHandSideSubstitutions() {
647647
///
648648
/// All other loops can be discarded since they do not encode redundancies
649649
/// that are relevant to us.
650-
bool RewriteSystem::isInMinimizationDomain(
651-
ArrayRef<const ProtocolDecl *> protos) const {
652-
assert(protos.size() <= 1);
653-
assert(Protos.empty() || !protos.empty());
650+
bool RewriteSystem::isInMinimizationDomain(const ProtocolDecl *proto) const {
651+
assert(Protos.empty() || proto != nullptr);
654652

655-
if (protos.empty() && Protos.empty())
653+
if (proto == nullptr && Protos.empty())
656654
return true;
657655

658-
if (std::find(Protos.begin(), Protos.end(), protos[0]) != Protos.end())
656+
if (std::find(Protos.begin(), Protos.end(), proto) != Protos.end())
659657
return true;
660658

661659
return false;
@@ -670,7 +668,7 @@ void RewriteSystem::recordRewriteLoop(MutableTerm basepoint,
670668
return;
671669

672670
// Ignore the rewrite rule if it is not part of our minimization domain.
673-
if (!isInMinimizationDomain(basepoint.getRootProtocols()))
671+
if (!isInMinimizationDomain(basepoint.getRootProtocol()))
674672
return;
675673

676674
Loops.push_back(loop);
@@ -740,8 +738,8 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
740738
}
741739
}
742740

743-
auto lhsDomain = lhs.getRootProtocols();
744-
auto rhsDomain = rhs.getRootProtocols();
741+
auto lhsDomain = lhs.getRootProtocol();
742+
auto rhsDomain = rhs.getRootProtocol();
745743

746744
ASSERT_RULE(lhsDomain == rhsDomain);
747745
}

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ class RewriteSystem final {
452452
void recordRewriteLoop(MutableTerm basepoint,
453453
RewritePath path);
454454

455-
bool isInMinimizationDomain(ArrayRef<const ProtocolDecl *> protos) const;
455+
bool isInMinimizationDomain(const ProtocolDecl *proto) const;
456456

457457
ArrayRef<RewriteLoop> getLoops() const {
458458
return Loops;

0 commit comments

Comments
 (0)