Skip to content

Commit d011bf3

Browse files
committed
[CSBindings] Implement transtive protocol requirement inference
Implements iterative protocol requirement inference through subtype, conversion and equivalence relationships. This algorithm doesn't depend on a type variable finalization order (which is currently the order of type variable introduction). If a given type variable doesn't yet have its transitive protocol requirements inferred, algorithm would use iterative depth-first walk through its supertypes and equivalences and incrementally infer transitive protocols for each type variable involved, transferring new information down the chain e.g. T1 T3 \ / T4 T5 \ / T2 Here `T1`, `T3` are supertypes of `T4`, `T4` and `T5` are supertypes of `T2`. Let's assume that algorithm starts at `T2` and none of the involved type variables have their protocol requirements inferred yet. First, it would consider supertypes of `T2` which are `T4` and `T5`, since `T5` is the last in the chain algorithm would transfer its direct protocol requirements to `T2`. `T4` has supertypes `T1` and `T3` - they transfer their direct protocol requirements to `T4` and `T4` transfers its direct and transitive (from `T1` and `T3`) protocol requirements to `T2`. At this point all the type variables in subtype chain have their transitive protocol requirements resolved and cached so they don't have to be re-inferred later.
1 parent a9cce60 commit d011bf3

File tree

2 files changed

+101
-5
lines changed

2 files changed

+101
-5
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4708,6 +4708,10 @@ class ConstraintSystem {
47084708
/// The set of protocol requirements placed on this type variable.
47094709
llvm::SmallVector<Constraint *, 4> Protocols;
47104710

4711+
/// The set of transitive protocol requirements inferred through
4712+
/// subtype/conversion/equivalence relations with other type variables.
4713+
Optional<llvm::SmallPtrSet<Constraint *, 4>> TransitiveProtocols;
4714+
47114715
/// The set of constraints which would be used to infer default types.
47124716
llvm::TinyPtrVector<Constraint *> Defaults;
47134717

@@ -4873,6 +4877,15 @@ class ConstraintSystem {
48734877
ConstraintSystem::PotentialBindings>
48744878
&inferredBindings);
48754879

4880+
/// Detect subtype, conversion or equivalence relationship
4881+
/// between two type variables and attempt to propagate protocol
4882+
/// requirements down the subtype or equivalence chain.
4883+
void inferTransitiveProtocolRequirements(
4884+
const ConstraintSystem &cs,
4885+
llvm::SmallDenseMap<TypeVariableType *,
4886+
ConstraintSystem::PotentialBindings>
4887+
&inferredBindings);
4888+
48764889
/// Infer bindings based on any protocol conformances that have default
48774890
/// types.
48784891
void inferDefaultTypes(ConstraintSystem &cs,
@@ -4886,8 +4899,8 @@ class ConstraintSystem {
48864899
/// Finalize binding computation for this type variable by
48874900
/// inferring bindings from context e.g. transitive bindings.
48884901
void finalize(ConstraintSystem &cs,
4889-
const llvm::SmallDenseMap<TypeVariableType *,
4890-
ConstraintSystem::PotentialBindings>
4902+
llvm::SmallDenseMap<TypeVariableType *,
4903+
ConstraintSystem::PotentialBindings>
48914904
&inferredBindings);
48924905

48934906
void dump(llvm::raw_ostream &out,

lib/Sema/CSBindings.cpp

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,90 @@
2222
using namespace swift;
2323
using namespace constraints;
2424

25+
void ConstraintSystem::PotentialBindings::inferTransitiveProtocolRequirements(
26+
const ConstraintSystem &cs,
27+
llvm::SmallDenseMap<TypeVariableType *, ConstraintSystem::PotentialBindings>
28+
&inferredBindings) {
29+
if (TransitiveProtocols)
30+
return;
31+
32+
llvm::SmallVector<std::pair<TypeVariableType *, TypeVariableType *>, 4>
33+
workList;
34+
llvm::SmallPtrSet<TypeVariableType *, 4> visitedRelations;
35+
36+
llvm::SmallDenseMap<TypeVariableType *, SmallPtrSet<Constraint *, 4>, 4>
37+
protocols;
38+
39+
auto addToWorkList = [&](TypeVariableType *parent,
40+
TypeVariableType *typeVar) {
41+
if (visitedRelations.insert(typeVar).second)
42+
workList.push_back({parent, typeVar});
43+
};
44+
45+
auto propagateProtocolsTo =
46+
[&protocols](TypeVariableType *dstVar,
47+
const SmallVectorImpl<Constraint *> &direct,
48+
const SmallPtrSetImpl<Constraint *> &transitive) {
49+
auto &destination = protocols[dstVar];
50+
51+
for (auto *protocol : direct)
52+
destination.insert(protocol);
53+
54+
for (auto *protocol : transitive)
55+
destination.insert(protocol);
56+
};
57+
58+
addToWorkList(nullptr, TypeVar);
59+
60+
do {
61+
auto *currentVar = workList.back().second;
62+
63+
auto cachedBindings = inferredBindings.find(currentVar);
64+
if (cachedBindings == inferredBindings.end()) {
65+
workList.pop_back();
66+
continue;
67+
}
68+
69+
auto &bindings = cachedBindings->getSecond();
70+
71+
// If current variable already has transitive protocol
72+
// conformances inferred, there is no need to look deeper
73+
// into subtype/equivalence chain.
74+
if (bindings.TransitiveProtocols) {
75+
TypeVariableType *parent = nullptr;
76+
std::tie(parent, currentVar) = workList.pop_back_val();
77+
assert(parent);
78+
propagateProtocolsTo(parent, bindings.Protocols,
79+
*bindings.TransitiveProtocols);
80+
continue;
81+
}
82+
83+
for (const auto &entry : bindings.SubtypeOf)
84+
addToWorkList(currentVar, entry.first);
85+
86+
for (const auto &entry : bindings.EquivalentTo)
87+
addToWorkList(currentVar, entry.first);
88+
89+
// More subtype/equivalences relations have been added.
90+
if (workList.back().second != currentVar)
91+
continue;
92+
93+
TypeVariableType *parent = nullptr;
94+
std::tie(parent, currentVar) = workList.pop_back_val();
95+
96+
// At all of the protocols associated with current type variable
97+
// are transitive to its parent, propogate them down the subtype/equivalence
98+
// chain.
99+
if (parent) {
100+
propagateProtocolsTo(parent, bindings.Protocols, protocols[currentVar]);
101+
}
102+
103+
// Update the bindings associated with current type variable,
104+
// to avoid repeating this inference process.
105+
bindings.TransitiveProtocols.emplace(std::move(protocols[currentVar]));
106+
} while (!workList.empty());
107+
}
108+
25109
void ConstraintSystem::PotentialBindings::inferTransitiveBindings(
26110
ConstraintSystem &cs, llvm::SmallPtrSetImpl<CanType> &existingTypes,
27111
const llvm::SmallDenseMap<TypeVariableType *,
@@ -281,17 +365,16 @@ void ConstraintSystem::PotentialBindings::inferDefaultTypes(
281365

282366
void ConstraintSystem::PotentialBindings::finalize(
283367
ConstraintSystem &cs,
284-
const llvm::SmallDenseMap<TypeVariableType *,
285-
ConstraintSystem::PotentialBindings>
368+
llvm::SmallDenseMap<TypeVariableType *, ConstraintSystem::PotentialBindings>
286369
&inferredBindings) {
287370
// We need to make sure that there are no duplicate bindings in the
288371
// set, otherwise solver would produce multiple identical solutions.
289372
llvm::SmallPtrSet<CanType, 4> existingTypes;
290373
for (const auto &binding : Bindings)
291374
existingTypes.insert(binding.BindingType->getCanonicalType());
292375

376+
inferTransitiveProtocolRequirements(cs, inferredBindings);
293377
inferTransitiveBindings(cs, existingTypes, inferredBindings);
294-
295378
inferDefaultTypes(cs, existingTypes);
296379

297380
// Adjust optionality of existing bindings based on presence of

0 commit comments

Comments
 (0)