@@ -46,23 +46,41 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) {
4646 ASSERT_TRUE (binding.hasDefaultedLiteralProtocol ());
4747}
4848
49+ // Given a set of inferred protocol requirements, make sure that
50+ // all of the expected types are present.
51+ static void verifyProtocolInferenceResults (
52+ const llvm::SmallPtrSetImpl<Constraint *> &protocols,
53+ ArrayRef<Type> expectedTypes) {
54+ ASSERT_TRUE (protocols.size () >= expectedTypes.size ());
55+
56+ llvm::SmallPtrSet<Type, 2 > inferredProtocolTypes;
57+ for (auto *protocol : protocols)
58+ inferredProtocolTypes.insert (protocol->getSecondType ());
59+
60+ for (auto expectedTy : expectedTypes) {
61+ ASSERT_TRUE (inferredProtocolTypes.count (expectedTy));
62+ }
63+ }
64+
4965TEST_F (SemaTest, TestTransitiveProtocolInference) {
5066 ConstraintSystemOptions options;
5167 ConstraintSystem cs (DC, options);
5268
53- auto *PD1 =
54- new (Context) ProtocolDecl (DC, SourceLoc (), SourceLoc (),
55- Context.getIdentifier (" P1" ), /* Inherited=*/ {},
56- /* trailingWhere=*/ nullptr );
57- PD1->setImplicit ();
69+ auto *protocolTy1 = createProtocol (" P1" );
70+ auto *protocolTy2 = createProtocol (" P2" );
5871
59- auto *protocolTy1 = ProtocolType::get (PD1, Type (), Context);
72+ auto *GPT1 = cs.createTypeVariable (cs.getConstraintLocator ({}),
73+ /* options=*/ TVO_CanBindToNoEscape);
74+ auto *GPT2 = cs.createTypeVariable (cs.getConstraintLocator ({}),
75+ /* options=*/ TVO_CanBindToNoEscape);
6076
61- auto *GPT = cs.createTypeVariable (cs.getConstraintLocator ({}),
62- /* options=*/ TVO_CanBindToNoEscape);
77+ cs.addConstraint (
78+ ConstraintKind::ConformsTo, GPT1, protocolTy1,
79+ cs.getConstraintLocator ({}, LocatorPathElt::TypeParameterRequirement (
80+ 0 , RequirementKind::Conformance)));
6381
6482 cs.addConstraint (
65- ConstraintKind::ConformsTo, GPT, protocolTy1 ,
83+ ConstraintKind::ConformsTo, GPT2, protocolTy2 ,
6684 cs.getConstraintLocator ({}, LocatorPathElt::TypeParameterRequirement (
6785 0 , RequirementKind::Conformance)));
6886
@@ -73,16 +91,114 @@ TEST_F(SemaTest, TestTransitiveProtocolInference) {
7391 /* options=*/ 0 );
7492
7593 cs.addConstraint (
76- ConstraintKind::Conversion, typeVar, GPT ,
94+ ConstraintKind::Conversion, typeVar, GPT1 ,
7795 cs.getConstraintLocator ({}, LocatorPathElt::ContextualType ()));
7896
7997 auto bindings = inferBindings (cs, typeVar);
8098 ASSERT_TRUE (bindings.Protocols .empty ());
99+ ASSERT_TRUE (bool (bindings.TransitiveProtocols ));
100+ verifyProtocolInferenceResults (*bindings.TransitiveProtocols ,
101+ {protocolTy1});
102+ }
103+
104+ // Now, let's make sure that protocol requirements could be propagated
105+ // down conversion/equality chains through multiple hops.
106+ {
107+ // GPT1 is a subtype of GPT2 and GPT2 is convertible to a target type
108+ // variable, target should get both protocols inferred - P1 & P2.
81109
82- const auto &inferredProtocols = bindings.TransitiveProtocols ;
83- ASSERT_TRUE (bool (inferredProtocols));
84- ASSERT_EQ (inferredProtocols->size (), (unsigned )1 );
85- ASSERT_TRUE (
86- (*inferredProtocols->begin ())->getSecondType ()->isEqual (protocolTy1));
110+ auto *typeVar = cs.createTypeVariable (cs.getConstraintLocator ({}),
111+ /* options=*/ 0 );
112+
113+ cs.addConstraint (ConstraintKind::Subtype, GPT1, GPT2,
114+ cs.getConstraintLocator ({}));
115+
116+ cs.addConstraint (ConstraintKind::Conversion, typeVar, GPT1,
117+ cs.getConstraintLocator ({}));
118+
119+ auto bindings = inferBindings (cs, typeVar);
120+ ASSERT_TRUE (bindings.Protocols .empty ());
121+ ASSERT_TRUE (bool (bindings.TransitiveProtocols ));
122+ verifyProtocolInferenceResults (*bindings.TransitiveProtocols ,
123+ {protocolTy1, protocolTy2});
87124 }
88125}
126+
127+ // / Let's try a more complicated situation where there protocols
128+ // / are inferred from multiple sources on different levels of
129+ // / convertion chain.
130+ // /
131+ // / (P1) T0 T4 (T3) T6 (P4)
132+ // / \ / /
133+ // / T3 = T1 (P2) = T5
134+ // / \ /
135+ // / T2
136+
137+ TEST_F (SemaTest, TestComplexTransitiveProtocolInference) {
138+ ConstraintSystemOptions options;
139+ ConstraintSystem cs (DC, options);
140+
141+ auto *protocolTy1 = createProtocol (" P1" );
142+ auto *protocolTy2 = createProtocol (" P2" );
143+ auto *protocolTy3 = createProtocol (" P3" );
144+ auto *protocolTy4 = createProtocol (" P4" );
145+
146+ auto *nilLocator = cs.getConstraintLocator ({});
147+
148+ auto typeVar0 = cs.createTypeVariable (nilLocator, /* options=*/ 0 );
149+ auto typeVar1 = cs.createTypeVariable (nilLocator, /* options=*/ 0 );
150+ auto typeVar2 = cs.createTypeVariable (nilLocator, /* options=*/ 0 );
151+ // Allow this type variable to be bound to l-value type to prevent
152+ // it from being merged with the rest of the type variables.
153+ auto typeVar3 =
154+ cs.createTypeVariable (nilLocator, /* options=*/ TVO_CanBindToLValue);
155+ auto typeVar4 = cs.createTypeVariable (nilLocator, /* options=*/ 0 );
156+ auto typeVar5 =
157+ cs.createTypeVariable (nilLocator, /* options=*/ TVO_CanBindToLValue);
158+ auto typeVar6 = cs.createTypeVariable (nilLocator, /* options=*/ 0 );
159+
160+ cs.addConstraint (ConstraintKind::ConformsTo, typeVar0, protocolTy1,
161+ nilLocator);
162+ cs.addConstraint (ConstraintKind::ConformsTo, typeVar1, protocolTy2,
163+ nilLocator);
164+ cs.addConstraint (ConstraintKind::ConformsTo, typeVar4, protocolTy3,
165+ nilLocator);
166+ cs.addConstraint (ConstraintKind::ConformsTo, typeVar6, protocolTy4,
167+ nilLocator);
168+
169+ // T3 <: T0, T3 <: T4
170+ cs.addConstraint (ConstraintKind::Conversion, typeVar3, typeVar0, nilLocator);
171+ cs.addConstraint (ConstraintKind::Conversion, typeVar3, typeVar4, nilLocator);
172+
173+ // T2 <: T3, T2 <: T1, T3 == T1
174+ cs.addConstraint (ConstraintKind::Subtype, typeVar2, typeVar3, nilLocator);
175+ cs.addConstraint (ConstraintKind::Conversion, typeVar2, typeVar1, nilLocator);
176+ cs.addConstraint (ConstraintKind::Equal, typeVar3, typeVar1, nilLocator);
177+ // T1 == T5, T <: T6
178+ cs.addConstraint (ConstraintKind::Equal, typeVar1, typeVar5, nilLocator);
179+ cs.addConstraint (ConstraintKind::Conversion, typeVar5, typeVar6, nilLocator);
180+
181+ auto bindingsForT1 = inferBindings (cs, typeVar1);
182+ auto bindingsForT2 = inferBindings (cs, typeVar2);
183+ auto bindingsForT3 = inferBindings (cs, typeVar3);
184+ auto bindingsForT5 = inferBindings (cs, typeVar5);
185+
186+ ASSERT_TRUE (bool (bindingsForT1.TransitiveProtocols ));
187+ verifyProtocolInferenceResults (*bindingsForT1.TransitiveProtocols ,
188+ {protocolTy1, protocolTy3, protocolTy4});
189+
190+ ASSERT_TRUE (bool (bindingsForT2.TransitiveProtocols ));
191+ verifyProtocolInferenceResults (
192+ *bindingsForT2.TransitiveProtocols ,
193+ {protocolTy1, protocolTy2, protocolTy3, protocolTy4});
194+
195+ ASSERT_TRUE (bool (bindingsForT3.TransitiveProtocols ));
196+ verifyProtocolInferenceResults (
197+ *bindingsForT3.TransitiveProtocols ,
198+ {protocolTy1, protocolTy2, protocolTy3, protocolTy4});
199+
200+ ASSERT_TRUE (bool (bindingsForT5.TransitiveProtocols ));
201+ verifyProtocolInferenceResults (
202+ *bindingsForT5.TransitiveProtocols ,
203+ {protocolTy1, protocolTy2, protocolTy3, protocolTy4});
204+ }
0 commit comments