Skip to content

Commit 9598f19

Browse files
committed
[unittest/Sema] Cover transitive protocol inference with unit tests
1 parent a3c3981 commit 9598f19

File tree

1 file changed

+131
-15
lines changed

1 file changed

+131
-15
lines changed

unittests/Sema/BindingInferenceTests.cpp

Lines changed: 131 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,41 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) {
4646
ASSERT_TRUE(binding.hasDefaultedLiteralProtocol());
4747
}
4848

49+
// Given a set of inferred protocol requirements, make sure that
50+
// all of the expected types are present.
51+
static void verifyProtocolInferenceResults(
52+
const llvm::SmallPtrSetImpl<Constraint *> &protocols,
53+
ArrayRef<Type> expectedTypes) {
54+
ASSERT_TRUE(protocols.size() >= expectedTypes.size());
55+
56+
llvm::SmallPtrSet<Type, 2> inferredProtocolTypes;
57+
for (auto *protocol : protocols)
58+
inferredProtocolTypes.insert(protocol->getSecondType());
59+
60+
for (auto expectedTy : expectedTypes) {
61+
ASSERT_TRUE(inferredProtocolTypes.count(expectedTy));
62+
}
63+
}
64+
4965
TEST_F(SemaTest, TestTransitiveProtocolInference) {
5066
ConstraintSystemOptions options;
5167
ConstraintSystem cs(DC, options);
5268

53-
auto *PD1 =
54-
new (Context) ProtocolDecl(DC, SourceLoc(), SourceLoc(),
55-
Context.getIdentifier("P1"), /*Inherited=*/{},
56-
/*trailingWhere=*/nullptr);
57-
PD1->setImplicit();
69+
auto *protocolTy1 = createProtocol("P1");
70+
auto *protocolTy2 = createProtocol("P2");
5871

59-
auto *protocolTy1 = ProtocolType::get(PD1, Type(), Context);
72+
auto *GPT1 = cs.createTypeVariable(cs.getConstraintLocator({}),
73+
/*options=*/TVO_CanBindToNoEscape);
74+
auto *GPT2 = cs.createTypeVariable(cs.getConstraintLocator({}),
75+
/*options=*/TVO_CanBindToNoEscape);
6076

61-
auto *GPT = cs.createTypeVariable(cs.getConstraintLocator({}),
62-
/*options=*/TVO_CanBindToNoEscape);
77+
cs.addConstraint(
78+
ConstraintKind::ConformsTo, GPT1, protocolTy1,
79+
cs.getConstraintLocator({}, LocatorPathElt::TypeParameterRequirement(
80+
0, RequirementKind::Conformance)));
6381

6482
cs.addConstraint(
65-
ConstraintKind::ConformsTo, GPT, protocolTy1,
83+
ConstraintKind::ConformsTo, GPT2, protocolTy2,
6684
cs.getConstraintLocator({}, LocatorPathElt::TypeParameterRequirement(
6785
0, RequirementKind::Conformance)));
6886

@@ -73,16 +91,114 @@ TEST_F(SemaTest, TestTransitiveProtocolInference) {
7391
/*options=*/0);
7492

7593
cs.addConstraint(
76-
ConstraintKind::Conversion, typeVar, GPT,
94+
ConstraintKind::Conversion, typeVar, GPT1,
7795
cs.getConstraintLocator({}, LocatorPathElt::ContextualType()));
7896

7997
auto bindings = inferBindings(cs, typeVar);
8098
ASSERT_TRUE(bindings.Protocols.empty());
99+
ASSERT_TRUE(bool(bindings.TransitiveProtocols));
100+
verifyProtocolInferenceResults(*bindings.TransitiveProtocols,
101+
{protocolTy1});
102+
}
103+
104+
// Now, let's make sure that protocol requirements could be propagated
105+
// down conversion/equality chains through multiple hops.
106+
{
107+
// GPT1 is a subtype of GPT2 and GPT2 is convertible to a target type
108+
// variable, target should get both protocols inferred - P1 & P2.
81109

82-
const auto &inferredProtocols = bindings.TransitiveProtocols;
83-
ASSERT_TRUE(bool(inferredProtocols));
84-
ASSERT_EQ(inferredProtocols->size(), (unsigned)1);
85-
ASSERT_TRUE(
86-
(*inferredProtocols->begin())->getSecondType()->isEqual(protocolTy1));
110+
auto *typeVar = cs.createTypeVariable(cs.getConstraintLocator({}),
111+
/*options=*/0);
112+
113+
cs.addConstraint(ConstraintKind::Subtype, GPT1, GPT2,
114+
cs.getConstraintLocator({}));
115+
116+
cs.addConstraint(ConstraintKind::Conversion, typeVar, GPT1,
117+
cs.getConstraintLocator({}));
118+
119+
auto bindings = inferBindings(cs, typeVar);
120+
ASSERT_TRUE(bindings.Protocols.empty());
121+
ASSERT_TRUE(bool(bindings.TransitiveProtocols));
122+
verifyProtocolInferenceResults(*bindings.TransitiveProtocols,
123+
{protocolTy1, protocolTy2});
87124
}
88125
}
126+
127+
/// Let's try a more complicated situation where there protocols
128+
/// are inferred from multiple sources on different levels of
129+
/// convertion chain.
130+
///
131+
/// (P1) T0 T4 (T3) T6 (P4)
132+
/// \ / /
133+
/// T3 = T1 (P2) = T5
134+
/// \ /
135+
/// T2
136+
137+
TEST_F(SemaTest, TestComplexTransitiveProtocolInference) {
138+
ConstraintSystemOptions options;
139+
ConstraintSystem cs(DC, options);
140+
141+
auto *protocolTy1 = createProtocol("P1");
142+
auto *protocolTy2 = createProtocol("P2");
143+
auto *protocolTy3 = createProtocol("P3");
144+
auto *protocolTy4 = createProtocol("P4");
145+
146+
auto *nilLocator = cs.getConstraintLocator({});
147+
148+
auto typeVar0 = cs.createTypeVariable(nilLocator, /*options=*/0);
149+
auto typeVar1 = cs.createTypeVariable(nilLocator, /*options=*/0);
150+
auto typeVar2 = cs.createTypeVariable(nilLocator, /*options=*/0);
151+
// Allow this type variable to be bound to l-value type to prevent
152+
// it from being merged with the rest of the type variables.
153+
auto typeVar3 =
154+
cs.createTypeVariable(nilLocator, /*options=*/TVO_CanBindToLValue);
155+
auto typeVar4 = cs.createTypeVariable(nilLocator, /*options=*/0);
156+
auto typeVar5 =
157+
cs.createTypeVariable(nilLocator, /*options=*/TVO_CanBindToLValue);
158+
auto typeVar6 = cs.createTypeVariable(nilLocator, /*options=*/0);
159+
160+
cs.addConstraint(ConstraintKind::ConformsTo, typeVar0, protocolTy1,
161+
nilLocator);
162+
cs.addConstraint(ConstraintKind::ConformsTo, typeVar1, protocolTy2,
163+
nilLocator);
164+
cs.addConstraint(ConstraintKind::ConformsTo, typeVar4, protocolTy3,
165+
nilLocator);
166+
cs.addConstraint(ConstraintKind::ConformsTo, typeVar6, protocolTy4,
167+
nilLocator);
168+
169+
// T3 <: T0, T3 <: T4
170+
cs.addConstraint(ConstraintKind::Conversion, typeVar3, typeVar0, nilLocator);
171+
cs.addConstraint(ConstraintKind::Conversion, typeVar3, typeVar4, nilLocator);
172+
173+
// T2 <: T3, T2 <: T1, T3 == T1
174+
cs.addConstraint(ConstraintKind::Subtype, typeVar2, typeVar3, nilLocator);
175+
cs.addConstraint(ConstraintKind::Conversion, typeVar2, typeVar1, nilLocator);
176+
cs.addConstraint(ConstraintKind::Equal, typeVar3, typeVar1, nilLocator);
177+
// T1 == T5, T <: T6
178+
cs.addConstraint(ConstraintKind::Equal, typeVar1, typeVar5, nilLocator);
179+
cs.addConstraint(ConstraintKind::Conversion, typeVar5, typeVar6, nilLocator);
180+
181+
auto bindingsForT1 = inferBindings(cs, typeVar1);
182+
auto bindingsForT2 = inferBindings(cs, typeVar2);
183+
auto bindingsForT3 = inferBindings(cs, typeVar3);
184+
auto bindingsForT5 = inferBindings(cs, typeVar5);
185+
186+
ASSERT_TRUE(bool(bindingsForT1.TransitiveProtocols));
187+
verifyProtocolInferenceResults(*bindingsForT1.TransitiveProtocols,
188+
{protocolTy1, protocolTy3, protocolTy4});
189+
190+
ASSERT_TRUE(bool(bindingsForT2.TransitiveProtocols));
191+
verifyProtocolInferenceResults(
192+
*bindingsForT2.TransitiveProtocols,
193+
{protocolTy1, protocolTy2, protocolTy3, protocolTy4});
194+
195+
ASSERT_TRUE(bool(bindingsForT3.TransitiveProtocols));
196+
verifyProtocolInferenceResults(
197+
*bindingsForT3.TransitiveProtocols,
198+
{protocolTy1, protocolTy2, protocolTy3, protocolTy4});
199+
200+
ASSERT_TRUE(bool(bindingsForT5.TransitiveProtocols));
201+
verifyProtocolInferenceResults(
202+
*bindingsForT5.TransitiveProtocols,
203+
{protocolTy1, protocolTy2, protocolTy3, protocolTy4});
204+
}

0 commit comments

Comments
 (0)