@@ -46,23 +46,41 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) {
46
46
ASSERT_TRUE (binding.hasDefaultedLiteralProtocol ());
47
47
}
48
48
49
+ // Given a set of inferred protocol requirements, make sure that
50
+ // all of the expected types are present.
51
+ static void verifyProtocolInferenceResults (
52
+ const llvm::SmallPtrSetImpl<Constraint *> &protocols,
53
+ ArrayRef<Type> expectedTypes) {
54
+ ASSERT_TRUE (protocols.size () >= expectedTypes.size ());
55
+
56
+ llvm::SmallPtrSet<Type, 2 > inferredProtocolTypes;
57
+ for (auto *protocol : protocols)
58
+ inferredProtocolTypes.insert (protocol->getSecondType ());
59
+
60
+ for (auto expectedTy : expectedTypes) {
61
+ ASSERT_TRUE (inferredProtocolTypes.count (expectedTy));
62
+ }
63
+ }
64
+
49
65
TEST_F (SemaTest, TestTransitiveProtocolInference) {
50
66
ConstraintSystemOptions options;
51
67
ConstraintSystem cs (DC, options);
52
68
53
- auto *PD1 =
54
- new (Context) ProtocolDecl (DC, SourceLoc (), SourceLoc (),
55
- Context.getIdentifier (" P1" ), /* Inherited=*/ {},
56
- /* trailingWhere=*/ nullptr );
57
- PD1->setImplicit ();
69
+ auto *protocolTy1 = createProtocol (" P1" );
70
+ auto *protocolTy2 = createProtocol (" P2" );
58
71
59
- auto *protocolTy1 = ProtocolType::get (PD1, Type (), Context);
72
+ auto *GPT1 = cs.createTypeVariable (cs.getConstraintLocator ({}),
73
+ /* options=*/ TVO_CanBindToNoEscape);
74
+ auto *GPT2 = cs.createTypeVariable (cs.getConstraintLocator ({}),
75
+ /* options=*/ TVO_CanBindToNoEscape);
60
76
61
- auto *GPT = cs.createTypeVariable (cs.getConstraintLocator ({}),
62
- /* options=*/ TVO_CanBindToNoEscape);
77
+ cs.addConstraint (
78
+ ConstraintKind::ConformsTo, GPT1, protocolTy1,
79
+ cs.getConstraintLocator ({}, LocatorPathElt::TypeParameterRequirement (
80
+ 0 , RequirementKind::Conformance)));
63
81
64
82
cs.addConstraint (
65
- ConstraintKind::ConformsTo, GPT, protocolTy1 ,
83
+ ConstraintKind::ConformsTo, GPT2, protocolTy2 ,
66
84
cs.getConstraintLocator ({}, LocatorPathElt::TypeParameterRequirement (
67
85
0 , RequirementKind::Conformance)));
68
86
@@ -73,16 +91,114 @@ TEST_F(SemaTest, TestTransitiveProtocolInference) {
73
91
/* options=*/ 0 );
74
92
75
93
cs.addConstraint (
76
- ConstraintKind::Conversion, typeVar, GPT ,
94
+ ConstraintKind::Conversion, typeVar, GPT1 ,
77
95
cs.getConstraintLocator ({}, LocatorPathElt::ContextualType ()));
78
96
79
97
auto bindings = inferBindings (cs, typeVar);
80
98
ASSERT_TRUE (bindings.Protocols .empty ());
99
+ ASSERT_TRUE (bool (bindings.TransitiveProtocols ));
100
+ verifyProtocolInferenceResults (*bindings.TransitiveProtocols ,
101
+ {protocolTy1});
102
+ }
103
+
104
+ // Now, let's make sure that protocol requirements could be propagated
105
+ // down conversion/equality chains through multiple hops.
106
+ {
107
+ // GPT1 is a subtype of GPT2 and GPT2 is convertible to a target type
108
+ // variable, target should get both protocols inferred - P1 & P2.
81
109
82
- const auto &inferredProtocols = bindings.TransitiveProtocols ;
83
- ASSERT_TRUE (bool (inferredProtocols));
84
- ASSERT_EQ (inferredProtocols->size (), (unsigned )1 );
85
- ASSERT_TRUE (
86
- (*inferredProtocols->begin ())->getSecondType ()->isEqual (protocolTy1));
110
+ auto *typeVar = cs.createTypeVariable (cs.getConstraintLocator ({}),
111
+ /* options=*/ 0 );
112
+
113
+ cs.addConstraint (ConstraintKind::Subtype, GPT1, GPT2,
114
+ cs.getConstraintLocator ({}));
115
+
116
+ cs.addConstraint (ConstraintKind::Conversion, typeVar, GPT1,
117
+ cs.getConstraintLocator ({}));
118
+
119
+ auto bindings = inferBindings (cs, typeVar);
120
+ ASSERT_TRUE (bindings.Protocols .empty ());
121
+ ASSERT_TRUE (bool (bindings.TransitiveProtocols ));
122
+ verifyProtocolInferenceResults (*bindings.TransitiveProtocols ,
123
+ {protocolTy1, protocolTy2});
87
124
}
88
125
}
126
+
127
+ // / Let's try a more complicated situation where there protocols
128
+ // / are inferred from multiple sources on different levels of
129
+ // / convertion chain.
130
+ // /
131
+ // / (P1) T0 T4 (T3) T6 (P4)
132
+ // / \ / /
133
+ // / T3 = T1 (P2) = T5
134
+ // / \ /
135
+ // / T2
136
+
137
+ TEST_F (SemaTest, TestComplexTransitiveProtocolInference) {
138
+ ConstraintSystemOptions options;
139
+ ConstraintSystem cs (DC, options);
140
+
141
+ auto *protocolTy1 = createProtocol (" P1" );
142
+ auto *protocolTy2 = createProtocol (" P2" );
143
+ auto *protocolTy3 = createProtocol (" P3" );
144
+ auto *protocolTy4 = createProtocol (" P4" );
145
+
146
+ auto *nilLocator = cs.getConstraintLocator ({});
147
+
148
+ auto typeVar0 = cs.createTypeVariable (nilLocator, /* options=*/ 0 );
149
+ auto typeVar1 = cs.createTypeVariable (nilLocator, /* options=*/ 0 );
150
+ auto typeVar2 = cs.createTypeVariable (nilLocator, /* options=*/ 0 );
151
+ // Allow this type variable to be bound to l-value type to prevent
152
+ // it from being merged with the rest of the type variables.
153
+ auto typeVar3 =
154
+ cs.createTypeVariable (nilLocator, /* options=*/ TVO_CanBindToLValue);
155
+ auto typeVar4 = cs.createTypeVariable (nilLocator, /* options=*/ 0 );
156
+ auto typeVar5 =
157
+ cs.createTypeVariable (nilLocator, /* options=*/ TVO_CanBindToLValue);
158
+ auto typeVar6 = cs.createTypeVariable (nilLocator, /* options=*/ 0 );
159
+
160
+ cs.addConstraint (ConstraintKind::ConformsTo, typeVar0, protocolTy1,
161
+ nilLocator);
162
+ cs.addConstraint (ConstraintKind::ConformsTo, typeVar1, protocolTy2,
163
+ nilLocator);
164
+ cs.addConstraint (ConstraintKind::ConformsTo, typeVar4, protocolTy3,
165
+ nilLocator);
166
+ cs.addConstraint (ConstraintKind::ConformsTo, typeVar6, protocolTy4,
167
+ nilLocator);
168
+
169
+ // T3 <: T0, T3 <: T4
170
+ cs.addConstraint (ConstraintKind::Conversion, typeVar3, typeVar0, nilLocator);
171
+ cs.addConstraint (ConstraintKind::Conversion, typeVar3, typeVar4, nilLocator);
172
+
173
+ // T2 <: T3, T2 <: T1, T3 == T1
174
+ cs.addConstraint (ConstraintKind::Subtype, typeVar2, typeVar3, nilLocator);
175
+ cs.addConstraint (ConstraintKind::Conversion, typeVar2, typeVar1, nilLocator);
176
+ cs.addConstraint (ConstraintKind::Equal, typeVar3, typeVar1, nilLocator);
177
+ // T1 == T5, T <: T6
178
+ cs.addConstraint (ConstraintKind::Equal, typeVar1, typeVar5, nilLocator);
179
+ cs.addConstraint (ConstraintKind::Conversion, typeVar5, typeVar6, nilLocator);
180
+
181
+ auto bindingsForT1 = inferBindings (cs, typeVar1);
182
+ auto bindingsForT2 = inferBindings (cs, typeVar2);
183
+ auto bindingsForT3 = inferBindings (cs, typeVar3);
184
+ auto bindingsForT5 = inferBindings (cs, typeVar5);
185
+
186
+ ASSERT_TRUE (bool (bindingsForT1.TransitiveProtocols ));
187
+ verifyProtocolInferenceResults (*bindingsForT1.TransitiveProtocols ,
188
+ {protocolTy1, protocolTy3, protocolTy4});
189
+
190
+ ASSERT_TRUE (bool (bindingsForT2.TransitiveProtocols ));
191
+ verifyProtocolInferenceResults (
192
+ *bindingsForT2.TransitiveProtocols ,
193
+ {protocolTy1, protocolTy2, protocolTy3, protocolTy4});
194
+
195
+ ASSERT_TRUE (bool (bindingsForT3.TransitiveProtocols ));
196
+ verifyProtocolInferenceResults (
197
+ *bindingsForT3.TransitiveProtocols ,
198
+ {protocolTy1, protocolTy2, protocolTy3, protocolTy4});
199
+
200
+ ASSERT_TRUE (bool (bindingsForT5.TransitiveProtocols ));
201
+ verifyProtocolInferenceResults (
202
+ *bindingsForT5.TransitiveProtocols ,
203
+ {protocolTy1, protocolTy2, protocolTy3, protocolTy4});
204
+ }
0 commit comments