Skip to content

Commit 84921db

Browse files
committed
Sema: Prefer abstract witnesses from the current protocol
A minimal Sequence conformance only needs to define an Iterator type, with the Element type witness inferred from the Element of the iterator. This trick didn't always work if the conforming type conformed to other protocols with declared same-type requirements involving Self.Element. Refine the heuristic introduced in 23599b6 to prefer abstract type witnesses in the current protocol, even if there is a shorter one in another protocol. Fixes rdar://problem/122574126, rdar://problem/122588328.
1 parent 2ed50ec commit 84921db

File tree

4 files changed

+215
-127
lines changed

4 files changed

+215
-127
lines changed

lib/Sema/AssociatedTypeInference.cpp

Lines changed: 88 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -791,27 +791,41 @@ class TypeWitnessSystem final {
791791
/// The int:
792792
/// - A flag indicating whether the resolved type is ambiguous. When set,
793793
/// the resolved type is null.
794-
llvm::PointerIntPair<Type, 1, bool> ResolvedTyAndIsAmbiguous;
794+
/// - A flag indicating whether the resolved type is 'preferred', meaning
795+
/// it came from the exact protocol we're checking conformance to.
796+
/// A preferred type takes precedence over a non-preferred type.
797+
llvm::PointerIntPair<Type, 2, unsigned> ResolvedTyAndFlags;
795798

796799
public:
797-
EquivalenceClass(Type ty) : ResolvedTyAndIsAmbiguous(ty, false) {}
800+
EquivalenceClass(Type ty, bool preferred)
801+
: ResolvedTyAndFlags(ty, preferred ? 2 : 0) {}
798802

799803
EquivalenceClass(const EquivalenceClass &) = delete;
800804
EquivalenceClass(EquivalenceClass &&) = delete;
801805
EquivalenceClass &operator=(const EquivalenceClass &) = delete;
802806
EquivalenceClass &operator=(EquivalenceClass &&) = delete;
803807

804808
Type getResolvedType() const {
805-
return ResolvedTyAndIsAmbiguous.getPointer();
809+
return ResolvedTyAndFlags.getPointer();
806810
}
807-
void setResolvedType(Type ty);
811+
void setResolvedType(Type ty, bool preferred);
808812

809813
bool isAmbiguous() const {
810-
return ResolvedTyAndIsAmbiguous.getInt();
814+
return (ResolvedTyAndFlags.getInt() & 1) != 0;
811815
}
812816
void setAmbiguous() {
813-
ResolvedTyAndIsAmbiguous = {nullptr, true};
817+
ResolvedTyAndFlags.setPointerAndInt(nullptr, 1);
814818
}
819+
820+
bool isPreferred() const {
821+
return (ResolvedTyAndFlags.getInt() & 2) != 0;
822+
}
823+
void setPreferred() {
824+
assert(!isAmbiguous());
825+
ResolvedTyAndFlags.setInt(ResolvedTyAndFlags.getInt() | 2);
826+
}
827+
828+
void dump(llvm::raw_ostream &out) const;
815829
};
816830

817831
/// A type witness candidate for a name variable.
@@ -852,21 +866,22 @@ class TypeWitnessSystem final {
852866
///
853867
/// \note This need not lead to the resolution of a type witness, e.g.
854868
/// an associated type may be defaulted to another.
855-
void addTypeWitness(Identifier name, Type type);
869+
void addTypeWitness(Identifier name, Type type, bool preferred);
856870

857871
/// Record a default type witness.
858872
///
859873
/// \param defaultedAssocType The specific associated type declaration that
860874
/// defines the given default type.
861875
///
862876
/// \note This need not lead to the resolution of a type witness.
863-
void addDefaultTypeWitness(Type type, AssociatedTypeDecl *defaultedAssocType);
877+
void addDefaultTypeWitness(Type type, AssociatedTypeDecl *defaultedAssocType,
878+
bool preferred);
864879

865880
/// Record the given same-type requirement, if regarded of interest to
866881
/// the system.
867882
///
868883
/// \note This need not lead to the resolution of a type witness.
869-
void addSameTypeRequirement(const Requirement &req);
884+
void addSameTypeRequirement(const Requirement &req, bool preferred);
870885

871886
void dump(llvm::raw_ostream &out,
872887
const NormalProtocolConformance *conformance) const;
@@ -898,7 +913,8 @@ class TypeWitnessSystem final {
898913

899914
/// Compare the given resolved types as targeting a single equivalence class,
900915
/// in terms of the their relative impact on solving the system.
901-
static ResolvedTypeComparisonResult compareResolvedTypes(Type ty1, Type ty2);
916+
static ResolvedTypeComparisonResult compareResolvedTypes(
917+
Type ty1, bool preferred1, Type ty2, bool preferred2);
902918
};
903919

904920
/// Captures the state needed to infer associated types.
@@ -2374,7 +2390,8 @@ void AssociatedTypeInference::collectAbstractTypeWitnesses(
23742390

23752391
if (gp->getName() == assocType->getName()) {
23762392
system.addTypeWitness(assocType->getName(),
2377-
dc->mapTypeIntoContext(gp));
2393+
dc->mapTypeIntoContext(gp),
2394+
/*preferred=*/true);
23782395
}
23792396
}
23802397
}
@@ -2394,10 +2411,14 @@ void AssociatedTypeInference::collectAbstractTypeWitnesses(
23942411

23952412
LLVM_DEBUG(llvm::dbgs() << "Collecting same-type requirements from "
23962413
<< conformedProto->getName() << "\n");
2414+
2415+
// Prefer abstract witnesses from the protocol of the current conformance;
2416+
// these are less likely to lead to request cycles.
2417+
bool preferred = (conformedProto == conformance->getProtocol());
23972418
for (const auto &req :
23982419
conformedProto->getRequirementSignature().getRequirements()) {
23992420
if (req.getKind() == RequirementKind::SameType)
2400-
system.addSameTypeRequirement(req);
2421+
system.addSameTypeRequirement(req, preferred);
24012422
}
24022423
};
24032424

@@ -2423,8 +2444,11 @@ void AssociatedTypeInference::collectAbstractTypeWitnesses(
24232444

24242445
// If we find a default type definition, feed it to the system.
24252446
if (const auto &typeWitness = computeDefaultTypeWitness(assocType)) {
2447+
bool preferred = (typeWitness->getDefaultedAssocType()->getDeclContext()
2448+
== conformance->getProtocol());
24262449
system.addDefaultTypeWitness(typeWitness->getType(),
2427-
typeWitness->getDefaultedAssocType());
2450+
typeWitness->getDefaultedAssocType(),
2451+
preferred);
24282452
}
24292453
}
24302454
}
@@ -3926,11 +3950,24 @@ auto AssociatedTypeInference::solve()
39263950
return llvm::None;
39273951
}
39283952

3929-
void TypeWitnessSystem::EquivalenceClass::setResolvedType(Type ty) {
3953+
void TypeWitnessSystem::EquivalenceClass::setResolvedType(Type ty, bool preferred) {
39303954
assert(ty && "cannot resolve to a null type");
39313955
assert(!isAmbiguous() && "must not set resolved type when ambiguous");
3956+
ResolvedTyAndFlags.setPointer(ty);
3957+
if (preferred)
3958+
setPreferred();
3959+
}
39323960

3933-
ResolvedTyAndIsAmbiguous.setPointer(ty);
3961+
void TypeWitnessSystem::EquivalenceClass::dump(llvm::raw_ostream &out) const {
3962+
if (auto resolvedType = getResolvedType()) {
3963+
out << resolvedType;
3964+
if (isPreferred())
3965+
out << " (preferred)";
3966+
} else if (isAmbiguous()) {
3967+
out << "(ambiguous)";
3968+
} else {
3969+
out << "(unresolved)";
3970+
}
39343971
}
39353972

39363973
TypeWitnessSystem::TypeWitnessSystem(
@@ -3967,7 +4004,8 @@ TypeWitnessSystem::getDefaultedAssocType(Identifier name) const {
39674004
return this->TypeWitnesses.lookup(name).DefaultedAssocType;
39684005
}
39694006

3970-
void TypeWitnessSystem::addTypeWitness(Identifier name, Type type) {
4007+
void TypeWitnessSystem::addTypeWitness(Identifier name, Type type,
4008+
bool preferred) {
39714009
assert(this->TypeWitnesses.count(name));
39724010

39734011
if (const auto *depTy = type->getAs<DependentMemberType>()) {
@@ -3999,11 +4037,13 @@ void TypeWitnessSystem::addTypeWitness(Identifier name, Type type) {
39994037
return;
40004038
}
40014039

4002-
// If we already have a resolved type, keep going only if the new one is
4003-
// a better choice.
40044040
const Type currResolvedTy = tyWitness.EquivClass->getResolvedType();
40054041
if (currResolvedTy) {
4006-
switch (compareResolvedTypes(type, currResolvedTy)) {
4042+
// If we already have a resolved type, keep going only if the new one is
4043+
// a better choice.
4044+
switch (compareResolvedTypes(type, preferred,
4045+
tyWitness.EquivClass->getResolvedType(),
4046+
tyWitness.EquivClass->isPreferred())) {
40074047
case ResolvedTypeComparisonResult::Better:
40084048
break;
40094049
case ResolvedTypeComparisonResult::EquivalentOrWorse:
@@ -4031,17 +4071,18 @@ void TypeWitnessSystem::addTypeWitness(Identifier name, Type type) {
40314071
}
40324072

40334073
if (tyWitness.EquivClass) {
4034-
tyWitness.EquivClass->setResolvedType(type);
4074+
tyWitness.EquivClass->setResolvedType(type, preferred);
40354075
} else {
4036-
auto *equivClass = new EquivalenceClass(type);
4076+
auto *equivClass = new EquivalenceClass(type, preferred);
40374077
this->EquivalenceClasses.insert(equivClass);
40384078

40394079
tyWitness.EquivClass = equivClass;
40404080
}
40414081
}
40424082

40434083
void TypeWitnessSystem::addDefaultTypeWitness(
4044-
Type type, AssociatedTypeDecl *defaultedAssocType) {
4084+
Type type, AssociatedTypeDecl *defaultedAssocType,
4085+
bool preferred) {
40454086
const auto name = defaultedAssocType->getName();
40464087
assert(this->TypeWitnesses.count(name));
40474088

@@ -4054,10 +4095,11 @@ void TypeWitnessSystem::addDefaultTypeWitness(
40544095
tyWitness.DefaultedAssocType = defaultedAssocType;
40554096

40564097
// Record the type witness.
4057-
addTypeWitness(name, type);
4098+
addTypeWitness(name, type, preferred);
40584099
}
40594100

4060-
void TypeWitnessSystem::addSameTypeRequirement(const Requirement &req) {
4101+
void TypeWitnessSystem::addSameTypeRequirement(const Requirement &req,
4102+
bool preferred) {
40614103
assert(req.getKind() == RequirementKind::SameType);
40624104

40634105
auto *const depTy1 = req.getFirstType()->getAs<DependentMemberType>();
@@ -4068,10 +4110,10 @@ void TypeWitnessSystem::addSameTypeRequirement(const Requirement &req) {
40684110
// the system.
40694111
if (depTy1 && depTy1->getBase()->is<GenericTypeParamType>() &&
40704112
this->TypeWitnesses.count(depTy1->getName())) {
4071-
addTypeWitness(depTy1->getName(), req.getSecondType());
4113+
addTypeWitness(depTy1->getName(), req.getSecondType(), preferred);
40724114
} else if (depTy2 && depTy2->getBase()->is<GenericTypeParamType>() &&
40734115
this->TypeWitnesses.count(depTy2->getName())) {
4074-
addTypeWitness(depTy2->getName(), req.getFirstType());
4116+
addTypeWitness(depTy2->getName(), req.getFirstType(), preferred);
40754117
}
40764118
}
40774119

@@ -4100,13 +4142,7 @@ void TypeWitnessSystem::dump(
41004142

41014143
const auto *eqClass = this->TypeWitnesses.lookup(name).EquivClass;
41024144
if (eqClass) {
4103-
if (eqClass->getResolvedType()) {
4104-
out << eqClass->getResolvedType();
4105-
} else if (eqClass->isAmbiguous()) {
4106-
out << "(ambiguous)";
4107-
} else {
4108-
out << "(unresolved)";
4109-
}
4145+
eqClass->dump(out);
41104146
} else {
41114147
out << "(unresolved)";
41124148
}
@@ -4143,7 +4179,7 @@ void TypeWitnessSystem::addEquivalence(Identifier name1, Identifier name2) {
41434179
tyWitness1.EquivClass = tyWitness2.EquivClass;
41444180
} else {
41454181
// Neither has an associated equivalence class.
4146-
auto *equivClass = new EquivalenceClass(nullptr);
4182+
auto *equivClass = new EquivalenceClass(nullptr, /*preferred=*/false);
41474183
this->EquivalenceClasses.insert(equivClass);
41484184

41494185
tyWitness1.EquivClass = equivClass;
@@ -4154,16 +4190,20 @@ void TypeWitnessSystem::addEquivalence(Identifier name1, Identifier name2) {
41544190
void TypeWitnessSystem::mergeEquivalenceClasses(
41554191
EquivalenceClass *equivClass1, const EquivalenceClass *equivClass2) {
41564192
assert(equivClass1 && equivClass2);
4193+
41574194
if (equivClass1 == equivClass2) {
41584195
return;
41594196
}
41604197

41614198
// Merge the second equivalence class into the first.
41624199
if (equivClass1->getResolvedType() && equivClass2->getResolvedType()) {
41634200
switch (compareResolvedTypes(equivClass2->getResolvedType(),
4164-
equivClass1->getResolvedType())) {
4201+
equivClass2->isPreferred(),
4202+
equivClass1->getResolvedType(),
4203+
equivClass1->isPreferred())) {
41654204
case ResolvedTypeComparisonResult::Better:
4166-
equivClass1->setResolvedType(equivClass2->getResolvedType());
4205+
equivClass1->setResolvedType(equivClass2->getResolvedType(),
4206+
equivClass2->isPreferred());
41674207
break;
41684208
case ResolvedTypeComparisonResult::EquivalentOrWorse:
41694209
break;
@@ -4175,7 +4215,8 @@ void TypeWitnessSystem::mergeEquivalenceClasses(
41754215
// Ambiguity is retained.
41764216
} else if (equivClass2->getResolvedType()) {
41774217
// Carry over the resolved type.
4178-
equivClass1->setResolvedType(equivClass2->getResolvedType());
4218+
equivClass1->setResolvedType(equivClass2->getResolvedType(),
4219+
equivClass2->isPreferred());
41794220
} else if (equivClass2->isAmbiguous()) {
41804221
// Carry over ambiguity.
41814222
equivClass1->setAmbiguous();
@@ -4194,12 +4235,20 @@ void TypeWitnessSystem::mergeEquivalenceClasses(
41944235
}
41954236

41964237
TypeWitnessSystem::ResolvedTypeComparisonResult
4197-
TypeWitnessSystem::compareResolvedTypes(Type ty1, Type ty2) {
4238+
TypeWitnessSystem::compareResolvedTypes(Type ty1, bool preferred1,
4239+
Type ty2, bool preferred2) {
41984240
assert(ty1 && ty2);
41994241

4200-
// Prefer shorter type parameters. This is just a heuristic and has no
4242+
// Prefer type parameters from our current protocol, then break a tie by
4243+
// applying the type parameter order. This is just a heuristic and has no
42014244
// theoretical basis at all.
42024245
if (ty1->isTypeParameter() && ty2->isTypeParameter()) {
4246+
if (preferred1 && !preferred2)
4247+
return ResolvedTypeComparisonResult::Better;
4248+
4249+
if (preferred2 && !preferred1)
4250+
return ResolvedTypeComparisonResult::EquivalentOrWorse;
4251+
42034252
return compareDependentTypes(ty1, ty2) < 0
42044253
? ResolvedTypeComparisonResult::Better
42054254
: ResolvedTypeComparisonResult::EquivalentOrWorse;

0 commit comments

Comments
 (0)