Skip to content

Commit 56c96c6

Browse files
committed
[CSBindings] Use all equivalence chain members while interring transitive protocols
Currently inference logic only checked direct equivalence class members associated with a "work-in-progress" type variable, but each member can have local equivalences as well that need to be accounted for. Resolves: rdar://75978086
1 parent bae76c8 commit 56c96c6

File tree

3 files changed

+91
-19
lines changed

3 files changed

+91
-19
lines changed

lib/Sema/CSBindings.cpp

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -290,25 +290,51 @@ void BindingSet::inferTransitiveProtocolRequirements(
290290
// If current type variable is part of an equivalence
291291
// class, make it a "representative" and let it infer
292292
// supertypes and direct protocol requirements from
293-
// other members.
294-
for (const auto &entry : bindings.Info.EquivalentTo) {
295-
auto eqBindings = inferredBindings.find(entry.first);
296-
if (eqBindings != inferredBindings.end()) {
297-
const auto &bindings = eqBindings->getSecond();
298-
299-
llvm::SmallPtrSet<Constraint *, 2> placeholder;
300-
// Add any direct protocols from members of the
301-
// equivalence class, so they could be propagated
302-
// to all of the members.
303-
propagateProtocolsTo(currentVar, bindings.getConformanceRequirements(),
304-
placeholder);
305-
306-
// Since type variables are equal, current type variable
307-
// becomes a subtype to any supertype found in the current
308-
// equivalence class.
309-
for (const auto &eqEntry : bindings.Info.SubtypeOf)
310-
addToWorkList(currentVar, eqEntry.first);
311-
}
293+
// other members and their equivalence classes.
294+
SmallSetVector<TypeVariableType *, 4> equivalenceClass;
295+
{
296+
SmallVector<TypeVariableType *, 4> workList;
297+
workList.push_back(currentVar);
298+
299+
do {
300+
auto *typeVar = workList.pop_back_val();
301+
302+
if (!equivalenceClass.insert(typeVar))
303+
continue;
304+
305+
auto bindingSet = inferredBindings.find(typeVar);
306+
if (bindingSet == inferredBindings.end())
307+
continue;
308+
309+
auto &equivalences = bindingSet->getSecond().Info.EquivalentTo;
310+
for (const auto &eqVar : equivalences) {
311+
workList.push_back(eqVar.first);
312+
}
313+
} while (!workList.empty());
314+
}
315+
316+
for (const auto &memberVar : equivalenceClass) {
317+
if (memberVar == currentVar)
318+
continue;
319+
320+
auto eqBindings = inferredBindings.find(memberVar);
321+
if (eqBindings == inferredBindings.end())
322+
continue;
323+
324+
const auto &bindings = eqBindings->getSecond();
325+
326+
llvm::SmallPtrSet<Constraint *, 2> placeholder;
327+
// Add any direct protocols from members of the
328+
// equivalence class, so they could be propagated
329+
// to all of the members.
330+
propagateProtocolsTo(currentVar, bindings.getConformanceRequirements(),
331+
placeholder);
332+
333+
// Since type variables are equal, current type variable
334+
// becomes a subtype to any supertype found in the current
335+
// equivalence class.
336+
for (const auto &eqEntry : bindings.Info.SubtypeOf)
337+
addToWorkList(currentVar, eqEntry.first);
312338
}
313339

314340
// More subtype/equivalences relations have been added.

lib/Sema/ConstraintSystem.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,6 +1697,7 @@ ConstraintSystem::getTypeOfMemberReference(
16971697
// Concrete type replacing `Self` could be generic, so we need
16981698
// to make sure that it's opened before use.
16991699
baseOpenedTy = openType(concreteSelf, replacements);
1700+
baseObjTy = baseOpenedTy;
17001701
}
17011702
}
17021703
} else if (baseObjTy->isExistentialType()) {

unittests/Sema/BindingInferenceTests.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,48 @@ TEST_F(SemaTest, TestComplexTransitiveProtocolInference) {
300300
*bindingsForT5.TransitiveProtocols,
301301
{protocolTy1, protocolTy2, protocolTy3, protocolTy4});
302302
}
303+
304+
/// Let's try a situation where there protocols are inferred from
305+
/// multiple sources on different levels of equivalence chain.
306+
///
307+
/// T0 = T1
308+
/// = T2 (P0)
309+
/// = T3 (P1)
310+
TEST_F(SemaTest, TestTransitiveProtocolInferenceThroughEquivalenceChains) {
311+
ConstraintSystemOptions options;
312+
ConstraintSystem cs(DC, options);
313+
314+
auto *protocolTy0 = createProtocol("P0");
315+
auto *protocolTy1 = createProtocol("P1");
316+
317+
auto *nilLocator = cs.getConstraintLocator({});
318+
319+
auto typeVar0 = cs.createTypeVariable(nilLocator, /*options=*/0);
320+
// Allow this type variable to be bound to l-value type to prevent
321+
// it from being merged with the rest of the type variables.
322+
auto typeVar1 =
323+
cs.createTypeVariable(nilLocator, /*options=*/TVO_CanBindToLValue);
324+
auto typeVar2 = cs.createTypeVariable(nilLocator, /*options=*/0);
325+
auto typeVar3 = cs.createTypeVariable(nilLocator, TVO_CanBindToLValue);
326+
327+
cs.addConstraint(ConstraintKind::Conversion, typeVar0, typeVar1, nilLocator);
328+
cs.addConstraint(ConstraintKind::Equal, typeVar1, typeVar2, nilLocator);
329+
cs.addConstraint(ConstraintKind::Equal, typeVar2, typeVar3, nilLocator);
330+
cs.addConstraint(ConstraintKind::ConformsTo, typeVar2, protocolTy0, nilLocator);
331+
cs.addConstraint(ConstraintKind::ConformsTo, typeVar3, protocolTy1, nilLocator);
332+
333+
llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
334+
for (auto *typeVar : cs.getTypeVariables()) {
335+
cache.insert({typeVar, cs.getBindingsFor(typeVar, /*finalize=*/false)});
336+
}
337+
338+
auto bindingSet = cache.find(typeVar0);
339+
assert(bindingSet != cache.end());
340+
341+
auto &bindings = bindingSet->getSecond();
342+
bindings.inferTransitiveProtocolRequirements(cache);
343+
344+
ASSERT_TRUE(bool(bindings.TransitiveProtocols));
345+
verifyProtocolInferenceResults(*bindings.TransitiveProtocols,
346+
{protocolTy0, protocolTy1});
347+
}

0 commit comments

Comments
 (0)