Skip to content

Commit 8e401d2

Browse files
authored
Merge pull request swiftlang#36657 from xedin/rdar-75978086
[CSBindings] A couple of adjustments to transitive protocol inference
2 parents eaa2709 + d310f37 commit 8e401d2

File tree

5 files changed

+110
-25
lines changed

5 files changed

+110
-25
lines changed

lib/Sema/CSBindings.cpp

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -290,25 +290,51 @@ void BindingSet::inferTransitiveProtocolRequirements(
290290
// If current type variable is part of an equivalence
291291
// class, make it a "representative" and let it infer
292292
// supertypes and direct protocol requirements from
293-
// other members.
294-
for (const auto &entry : bindings.Info.EquivalentTo) {
295-
auto eqBindings = inferredBindings.find(entry.first);
296-
if (eqBindings != inferredBindings.end()) {
297-
const auto &bindings = eqBindings->getSecond();
298-
299-
llvm::SmallPtrSet<Constraint *, 2> placeholder;
300-
// Add any direct protocols from members of the
301-
// equivalence class, so they could be propagated
302-
// to all of the members.
303-
propagateProtocolsTo(currentVar, bindings.getConformanceRequirements(),
304-
placeholder);
305-
306-
// Since type variables are equal, current type variable
307-
// becomes a subtype to any supertype found in the current
308-
// equivalence class.
309-
for (const auto &eqEntry : bindings.Info.SubtypeOf)
310-
addToWorkList(currentVar, eqEntry.first);
311-
}
293+
// other members and their equivalence classes.
294+
SmallSetVector<TypeVariableType *, 4> equivalenceClass;
295+
{
296+
SmallVector<TypeVariableType *, 4> workList;
297+
workList.push_back(currentVar);
298+
299+
do {
300+
auto *typeVar = workList.pop_back_val();
301+
302+
if (!equivalenceClass.insert(typeVar))
303+
continue;
304+
305+
auto bindingSet = inferredBindings.find(typeVar);
306+
if (bindingSet == inferredBindings.end())
307+
continue;
308+
309+
auto &equivalences = bindingSet->getSecond().Info.EquivalentTo;
310+
for (const auto &eqVar : equivalences) {
311+
workList.push_back(eqVar.first);
312+
}
313+
} while (!workList.empty());
314+
}
315+
316+
for (const auto &memberVar : equivalenceClass) {
317+
if (memberVar == currentVar)
318+
continue;
319+
320+
auto eqBindings = inferredBindings.find(memberVar);
321+
if (eqBindings == inferredBindings.end())
322+
continue;
323+
324+
const auto &bindings = eqBindings->getSecond();
325+
326+
llvm::SmallPtrSet<Constraint *, 2> placeholder;
327+
// Add any direct protocols from members of the
328+
// equivalence class, so they could be propagated
329+
// to all of the members.
330+
propagateProtocolsTo(currentVar, bindings.getConformanceRequirements(),
331+
placeholder);
332+
333+
// Since type variables are equal, current type variable
334+
// becomes a subtype to any supertype found in the current
335+
// equivalence class.
336+
for (const auto &eqEntry : bindings.Info.SubtypeOf)
337+
addToWorkList(currentVar, eqEntry.first);
312338
}
313339

314340
// More subtype/equivalences relations have been added.
@@ -435,7 +461,6 @@ void BindingSet::inferTransitiveBindings(
435461

436462
void BindingSet::finalize(
437463
llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings) {
438-
inferTransitiveProtocolRequirements(inferredBindings);
439464
inferTransitiveBindings(inferredBindings);
440465

441466
determineLiteralCoverage();
@@ -452,11 +477,14 @@ void BindingSet::finalize(
452477
// func foo<T: P>(_: T) {}
453478
// foo(.bar) <- `.bar` should be a static member of `P`.
454479
// \endcode
455-
if (!hasViableBindings() && TransitiveProtocols.hasValue()) {
456-
for (auto *constraint : *TransitiveProtocols) {
457-
auto protocolTy = constraint->getSecondType();
458-
addBinding(
459-
{protocolTy, AllowedBindingKind::Exact, constraint});
480+
if (!hasViableBindings()) {
481+
inferTransitiveProtocolRequirements(inferredBindings);
482+
483+
if (TransitiveProtocols.hasValue()) {
484+
for (auto *constraint : *TransitiveProtocols) {
485+
auto protocolTy = constraint->getSecondType();
486+
addBinding({protocolTy, AllowedBindingKind::Exact, constraint});
487+
}
460488
}
461489
}
462490
}

lib/Sema/ConstraintSystem.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,6 +1710,7 @@ ConstraintSystem::getTypeOfMemberReference(
17101710
// Concrete type replacing `Self` could be generic, so we need
17111711
// to make sure that it's opened before use.
17121712
baseOpenedTy = openType(concreteSelf, replacements);
1713+
baseObjTy = baseOpenedTy;
17131714
}
17141715
}
17151716
} else if (baseObjTy->isExistentialType()) {
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: %target-swift-frontend -disable-availability-checking -typecheck -verify %s
2+
3+
// rdar://75978086 - static member lookup doesn't work with opaque types
4+
5+
protocol Intent {}
6+
7+
extension Intent where Self == Intents.List {
8+
static func orderedList() -> Self {
9+
return Intents.List.orderedList(nestedIn: nil)
10+
}
11+
}
12+
13+
enum Intents {
14+
enum List: Intent {
15+
case orderedList(nestedIn: Any?)
16+
}
17+
}
18+
19+
let rdar75978086: some Intent = .orderedList() // Ok

unittests/Sema/BindingInferenceTests.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,39 @@ TEST_F(SemaTest, TestComplexTransitiveProtocolInference) {
300300
*bindingsForT5.TransitiveProtocols,
301301
{protocolTy1, protocolTy2, protocolTy3, protocolTy4});
302302
}
303+
304+
/// Let's try a situation where there protocols are inferred from
305+
/// multiple sources on different levels of equivalence chain.
306+
///
307+
/// T0 = T1
308+
/// = T2 (P0)
309+
/// = T3 (P1)
310+
TEST_F(SemaTest, TestTransitiveProtocolInferenceThroughEquivalenceChains) {
311+
ConstraintSystemOptions options;
312+
ConstraintSystem cs(DC, options);
313+
314+
auto *protocolTy0 = createProtocol("P0");
315+
auto *protocolTy1 = createProtocol("P1");
316+
317+
auto *nilLocator = cs.getConstraintLocator({});
318+
319+
auto typeVar0 = cs.createTypeVariable(nilLocator, /*options=*/0);
320+
// Allow this type variable to be bound to l-value type to prevent
321+
// it from being merged with the rest of the type variables.
322+
auto typeVar1 =
323+
cs.createTypeVariable(nilLocator, /*options=*/TVO_CanBindToLValue);
324+
auto typeVar2 = cs.createTypeVariable(nilLocator, /*options=*/0);
325+
auto typeVar3 = cs.createTypeVariable(nilLocator, TVO_CanBindToLValue);
326+
327+
cs.addConstraint(ConstraintKind::Conversion, typeVar0, typeVar1, nilLocator);
328+
cs.addConstraint(ConstraintKind::Equal, typeVar1, typeVar2, nilLocator);
329+
cs.addConstraint(ConstraintKind::Equal, typeVar2, typeVar3, nilLocator);
330+
cs.addConstraint(ConstraintKind::ConformsTo, typeVar2, protocolTy0, nilLocator);
331+
cs.addConstraint(ConstraintKind::ConformsTo, typeVar3, protocolTy1, nilLocator);
332+
333+
auto bindings = inferBindings(cs, typeVar0);
334+
335+
ASSERT_TRUE(bool(bindings.TransitiveProtocols));
336+
verifyProtocolInferenceResults(*bindings.TransitiveProtocols,
337+
{protocolTy0, protocolTy1});
338+
}

unittests/Sema/SemaFixture.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ BindingSet SemaTest::inferBindings(ConstraintSystem &cs,
136136
continue;
137137

138138
auto &bindings = cachedBindings->getSecond();
139+
bindings.inferTransitiveProtocolRequirements(cache);
139140
bindings.finalize(cache);
140141
}
141142

0 commit comments

Comments
 (0)