Skip to content

Commit 59e8043

Browse files
authored
Merge pull request swiftlang#34278 from xedin/transitive-protocol-inference
[CSBindings] Implement transtive protocol requirement inference
2 parents 86b7bac + 9598f19 commit 59e8043

File tree

5 files changed

+412
-50
lines changed

5 files changed

+412
-50
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ class SolutionApplicationTarget;
6666

6767
} // end namespace constraints
6868

69+
namespace unittest {
70+
71+
class SemaTest;
72+
73+
} // end namespace unittest
74+
6975
// Forward declare some TypeChecker related functions
7076
// so they could be made friends of ConstraintSystem.
7177
namespace TypeChecker {
@@ -2018,6 +2024,8 @@ enum class SolutionApplicationToFunctionResult {
20182024
class ConstraintSystem {
20192025
ASTContext &Context;
20202026

2027+
friend class swift::unittest::SemaTest;
2028+
20212029
public:
20222030
DeclContext *DC;
20232031
ConstraintSystemOptions Options;
@@ -4705,7 +4713,11 @@ class ConstraintSystem {
47054713
SmallVector<PotentialBinding, 4> Bindings;
47064714

47074715
/// The set of protocol requirements placed on this type variable.
4708-
llvm::TinyPtrVector<Constraint *> Protocols;
4716+
llvm::SmallVector<Constraint *, 4> Protocols;
4717+
4718+
/// The set of transitive protocol requirements inferred through
4719+
/// subtype/conversion/equivalence relations with other type variables.
4720+
Optional<llvm::SmallPtrSet<Constraint *, 4>> TransitiveProtocols;
47094721

47104722
/// The set of constraints which would be used to infer default types.
47114723
llvm::TinyPtrVector<Constraint *> Defaults;
@@ -4740,15 +4752,13 @@ class ConstraintSystem {
47404752
/// Tracks the position of the last known supertype in the group.
47414753
Optional<unsigned> lastSupertypeIndex;
47424754

4743-
/// A set of all constraints which contribute to pontential bindings.
4744-
llvm::SmallPtrSet<Constraint *, 8> Sources;
4745-
47464755
/// A set of all not-yet-resolved type variables this type variable
4747-
/// is a subtype of. This is used to determine ordering inside a
4748-
/// chain of subtypes because binding inference algorithm can't,
4749-
/// at the moment, determine bindings transitively through supertype
4750-
/// type variables.
4751-
llvm::SmallPtrSet<TypeVariableType *, 4> SubtypeOf;
4756+
/// is a subtype of, supertype of or is equivalent to. This is used
4757+
/// to determine ordering inside of a chain of subtypes to help infer
4758+
/// transitive bindings and protocol requirements.
4759+
llvm::SmallMapVector<TypeVariableType *, Constraint *, 4> SubtypeOf;
4760+
llvm::SmallMapVector<TypeVariableType *, Constraint *, 4> SupertypeOf;
4761+
llvm::SmallMapVector<TypeVariableType *, Constraint *, 4> EquivalentTo;
47524762

47534763
PotentialBindings(TypeVariableType *typeVar)
47544764
: TypeVar(typeVar), PotentiallyIncomplete(isGenericParameter()) {}
@@ -4793,10 +4803,10 @@ class ConstraintSystem {
47934803
// This is required because algorithm can't currently infer
47944804
// bindings for subtype transitively through superclass ones.
47954805
if (!(x.IsHole && y.IsHole)) {
4796-
if (x.SubtypeOf.count(y.TypeVar))
4806+
if (x.isSubtypeOf(y.TypeVar))
47974807
return false;
47984808

4799-
if (y.SubtypeOf.count(x.TypeVar))
4809+
if (y.isSubtypeOf(x.TypeVar))
48004810
return true;
48014811
}
48024812

@@ -4842,6 +4852,15 @@ class ConstraintSystem {
48424852
return false;
48434853
}
48444854

4855+
bool isSubtypeOf(TypeVariableType *typeVar) const {
4856+
auto result = SubtypeOf.find(typeVar);
4857+
if (result == SubtypeOf.end())
4858+
return false;
4859+
4860+
auto *constraint = result->second;
4861+
return constraint->getKind() == ConstraintKind::Subtype;
4862+
}
4863+
48454864
/// Check if this binding is favored over a disjunction e.g.
48464865
/// if it has only concrete types or would resolve a closure.
48474866
bool favoredOverDisjunction(Constraint *disjunction) const;
@@ -4865,6 +4884,15 @@ class ConstraintSystem {
48654884
ConstraintSystem::PotentialBindings>
48664885
&inferredBindings);
48674886

4887+
/// Detect subtype, conversion or equivalence relationship
4888+
/// between two type variables and attempt to propagate protocol
4889+
/// requirements down the subtype or equivalence chain.
4890+
void inferTransitiveProtocolRequirements(
4891+
const ConstraintSystem &cs,
4892+
llvm::SmallDenseMap<TypeVariableType *,
4893+
ConstraintSystem::PotentialBindings>
4894+
&inferredBindings);
4895+
48684896
/// Infer bindings based on any protocol conformances that have default
48694897
/// types.
48704898
void inferDefaultTypes(ConstraintSystem &cs,
@@ -4878,8 +4906,8 @@ class ConstraintSystem {
48784906
/// Finalize binding computation for this type variable by
48794907
/// inferring bindings from context e.g. transitive bindings.
48804908
void finalize(ConstraintSystem &cs,
4881-
const llvm::SmallDenseMap<TypeVariableType *,
4882-
ConstraintSystem::PotentialBindings>
4909+
llvm::SmallDenseMap<TypeVariableType *,
4910+
ConstraintSystem::PotentialBindings>
48834911
&inferredBindings);
48844912

48854913
void dump(llvm::raw_ostream &out,

lib/Sema/CSBindings.cpp

Lines changed: 166 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,36 +22,146 @@
2222
using namespace swift;
2323
using namespace constraints;
2424

25+
void ConstraintSystem::PotentialBindings::inferTransitiveProtocolRequirements(
26+
const ConstraintSystem &cs,
27+
llvm::SmallDenseMap<TypeVariableType *, ConstraintSystem::PotentialBindings>
28+
&inferredBindings) {
29+
if (TransitiveProtocols)
30+
return;
31+
32+
llvm::SmallVector<std::pair<TypeVariableType *, TypeVariableType *>, 4>
33+
workList;
34+
llvm::SmallPtrSet<TypeVariableType *, 4> visitedRelations;
35+
36+
llvm::SmallDenseMap<TypeVariableType *, SmallPtrSet<Constraint *, 4>, 4>
37+
protocols;
38+
39+
auto addToWorkList = [&](TypeVariableType *parent,
40+
TypeVariableType *typeVar) {
41+
if (visitedRelations.insert(typeVar).second)
42+
workList.push_back({parent, typeVar});
43+
};
44+
45+
auto propagateProtocolsTo =
46+
[&protocols](TypeVariableType *dstVar,
47+
const SmallVectorImpl<Constraint *> &direct,
48+
const SmallPtrSetImpl<Constraint *> &transitive) {
49+
auto &destination = protocols[dstVar];
50+
51+
for (auto *protocol : direct)
52+
destination.insert(protocol);
53+
54+
for (auto *protocol : transitive)
55+
destination.insert(protocol);
56+
};
57+
58+
addToWorkList(nullptr, TypeVar);
59+
60+
do {
61+
auto *currentVar = workList.back().second;
62+
63+
auto cachedBindings = inferredBindings.find(currentVar);
64+
if (cachedBindings == inferredBindings.end()) {
65+
workList.pop_back();
66+
continue;
67+
}
68+
69+
auto &bindings = cachedBindings->getSecond();
70+
71+
// If current variable already has transitive protocol
72+
// conformances inferred, there is no need to look deeper
73+
// into subtype/equivalence chain.
74+
if (bindings.TransitiveProtocols) {
75+
TypeVariableType *parent = nullptr;
76+
std::tie(parent, currentVar) = workList.pop_back_val();
77+
assert(parent);
78+
propagateProtocolsTo(parent, bindings.Protocols,
79+
*bindings.TransitiveProtocols);
80+
continue;
81+
}
82+
83+
for (const auto &entry : bindings.SubtypeOf)
84+
addToWorkList(currentVar, entry.first);
85+
86+
// If current type variable is part of an equivalence
87+
// class, make it a "representative" and let's it infer
88+
// supertypes and direct protocol requirements from
89+
// other members.
90+
for (const auto &entry : bindings.EquivalentTo) {
91+
auto eqBindings = inferredBindings.find(entry.first);
92+
if (eqBindings != inferredBindings.end()) {
93+
const auto &bindings = eqBindings->getSecond();
94+
95+
llvm::SmallPtrSet<Constraint *, 2> placeholder;
96+
// Add any direct protocols from members of the
97+
// equivalence class, so they could be propagated
98+
// to all of the members.
99+
propagateProtocolsTo(currentVar, bindings.Protocols, placeholder);
100+
101+
// Since type variables are equal, current type variable
102+
// becomes a subtype to any supertype found in the current
103+
// equivalence class.
104+
for (const auto &eqEntry : bindings.SubtypeOf)
105+
addToWorkList(currentVar, eqEntry.first);
106+
}
107+
}
108+
109+
// More subtype/equivalences relations have been added.
110+
if (workList.back().second != currentVar)
111+
continue;
112+
113+
TypeVariableType *parent = nullptr;
114+
std::tie(parent, currentVar) = workList.pop_back_val();
115+
116+
// At all of the protocols associated with current type variable
117+
// are transitive to its parent, propogate them down the subtype/equivalence
118+
// chain.
119+
if (parent) {
120+
propagateProtocolsTo(parent, bindings.Protocols, protocols[currentVar]);
121+
}
122+
123+
auto inferredProtocols = std::move(protocols[currentVar]);
124+
125+
llvm::SmallPtrSet<Constraint *, 4> protocolsForEquivalence;
126+
127+
// Equivalence class should contain both:
128+
// - direct protocol requirements of the current type
129+
// variable;
130+
// - all of the transitive protocols inferred through
131+
// the members of the equivalence class.
132+
{
133+
protocolsForEquivalence.insert(bindings.Protocols.begin(),
134+
bindings.Protocols.end());
135+
136+
protocolsForEquivalence.insert(inferredProtocols.begin(),
137+
inferredProtocols.end());
138+
}
139+
140+
// Propogate inferred protocols to all of the members of the
141+
// equivalence class.
142+
for (const auto &equivalence : bindings.EquivalentTo) {
143+
auto eqBindings = inferredBindings.find(equivalence.first);
144+
if (eqBindings != inferredBindings.end()) {
145+
auto &bindings = eqBindings->getSecond();
146+
bindings.TransitiveProtocols.emplace(protocolsForEquivalence);
147+
}
148+
}
149+
150+
// Update the bindings associated with current type variable,
151+
// to avoid repeating this inference process.
152+
bindings.TransitiveProtocols.emplace(std::move(inferredProtocols));
153+
} while (!workList.empty());
154+
}
155+
25156
void ConstraintSystem::PotentialBindings::inferTransitiveBindings(
26157
ConstraintSystem &cs, llvm::SmallPtrSetImpl<CanType> &existingTypes,
27158
const llvm::SmallDenseMap<TypeVariableType *,
28159
ConstraintSystem::PotentialBindings>
29160
&inferredBindings) {
30161
using BindingKind = ConstraintSystem::AllowedBindingKind;
31162

32-
llvm::SmallVector<Constraint *, 4> conversions;
33-
// First, let's collect all of the conversions associated
34-
// with this type variable.
35-
llvm::copy_if(
36-
Sources, std::back_inserter(conversions),
37-
[&](const Constraint *constraint) -> bool {
38-
if (constraint->getKind() != ConstraintKind::Subtype &&
39-
constraint->getKind() != ConstraintKind::Conversion &&
40-
constraint->getKind() != ConstraintKind::ArgumentConversion &&
41-
constraint->getKind() != ConstraintKind::OperatorArgumentConversion)
42-
return false;
43-
44-
auto rhs = cs.simplifyType(constraint->getSecondType());
45-
return rhs->getAs<TypeVariableType>() == TypeVar;
46-
});
47-
48-
for (auto *constraint : conversions) {
49-
auto *tv =
50-
cs.simplifyType(constraint->getFirstType())->getAs<TypeVariableType>();
51-
if (!tv || tv == TypeVar)
52-
continue;
53-
54-
auto relatedBindings = inferredBindings.find(tv);
163+
for (const auto &entry : SupertypeOf) {
164+
auto relatedBindings = inferredBindings.find(entry.first);
55165
if (relatedBindings == inferredBindings.end())
56166
continue;
57167

@@ -89,7 +199,7 @@ void ConstraintSystem::PotentialBindings::inferTransitiveBindings(
89199
llvm::copy(bindings.Defaults, std::back_inserter(Defaults));
90200

91201
// TODO: We shouldn't need this in the future.
92-
if (constraint->getKind() != ConstraintKind::Subtype)
202+
if (entry.second->getKind() != ConstraintKind::Subtype)
93203
continue;
94204

95205
for (auto &binding : bindings.Bindings) {
@@ -302,17 +412,16 @@ void ConstraintSystem::PotentialBindings::inferDefaultTypes(
302412

303413
void ConstraintSystem::PotentialBindings::finalize(
304414
ConstraintSystem &cs,
305-
const llvm::SmallDenseMap<TypeVariableType *,
306-
ConstraintSystem::PotentialBindings>
415+
llvm::SmallDenseMap<TypeVariableType *, ConstraintSystem::PotentialBindings>
307416
&inferredBindings) {
308417
// We need to make sure that there are no duplicate bindings in the
309418
// set, otherwise solver would produce multiple identical solutions.
310419
llvm::SmallPtrSet<CanType, 4> existingTypes;
311420
for (const auto &binding : Bindings)
312421
existingTypes.insert(binding.BindingType->getCanonicalType());
313422

423+
inferTransitiveProtocolRequirements(cs, inferredBindings);
314424
inferTransitiveBindings(cs, existingTypes, inferredBindings);
315-
316425
inferDefaultTypes(cs, existingTypes);
317426

318427
// Adjust optionality of existing bindings based on presence of
@@ -670,10 +779,6 @@ ConstraintSystem::getPotentialBindingForRelationalConstraint(
670779

671780
auto *typeVar = result.TypeVar;
672781

673-
// Record constraint which contributes to the
674-
// finding of potential bindings.
675-
result.Sources.insert(constraint);
676-
677782
auto first = simplifyType(constraint->getFirstType());
678783
auto second = simplifyType(constraint->getSecondType());
679784

@@ -790,9 +895,29 @@ ConstraintSystem::getPotentialBindingForRelationalConstraint(
790895
}
791896
}
792897

793-
if (constraint->getKind() == ConstraintKind::Subtype &&
794-
kind == AllowedBindingKind::Subtypes) {
795-
result.SubtypeOf.insert(bindingTypeVar);
898+
switch (constraint->getKind()) {
899+
case ConstraintKind::Subtype:
900+
case ConstraintKind::Conversion:
901+
case ConstraintKind::ArgumentConversion:
902+
case ConstraintKind::OperatorArgumentConversion: {
903+
if (kind == AllowedBindingKind::Subtypes) {
904+
result.SubtypeOf.insert({bindingTypeVar, constraint});
905+
} else {
906+
assert(kind == AllowedBindingKind::Supertypes);
907+
result.SupertypeOf.insert({bindingTypeVar, constraint});
908+
}
909+
break;
910+
}
911+
912+
case ConstraintKind::Bind:
913+
case ConstraintKind::BindParam:
914+
case ConstraintKind::Equal: {
915+
result.EquivalentTo.insert({bindingTypeVar, constraint});
916+
break;
917+
}
918+
919+
default:
920+
break;
796921
}
797922

798923
return None;
@@ -986,8 +1111,13 @@ bool ConstraintSystem::PotentialBindings::infer(
9861111
break;
9871112

9881113
case ConstraintKind::ConformsTo:
989-
case ConstraintKind::SelfObjectOfProtocol:
990-
return false;
1114+
case ConstraintKind::SelfObjectOfProtocol: {
1115+
auto protocolTy = constraint->getSecondType();
1116+
if (!protocolTy->is<ProtocolType>())
1117+
return false;
1118+
1119+
LLVM_FALLTHROUGH;
1120+
}
9911121

9921122
case ConstraintKind::LiteralConformsTo: {
9931123
// Record constraint where protocol requirement originated

0 commit comments

Comments
 (0)