19
19
20
20
#include " swift/AST/Decl.h"
21
21
#include " swift/AST/GenericSignature.h"
22
+ #include " swift/AST/NameLookupRequests.h"
22
23
#include " swift/AST/ProtocolConformance.h"
23
24
#include " swift/AST/SubstitutionMap.h"
24
25
#include " swift/AST/TypeMatcher.h"
@@ -206,20 +207,22 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
206
207
if (extendedNominal == nullptr )
207
208
return true ;
208
209
209
- // FIXME: The extension may not have a generic signature set up yet as
210
- // resolving signatures may trigger associated type inference. This cycle
211
- // is now detectable and we should look into untangling it
212
- // - see rdar://55263708
213
- if (!extension->hasComputedGenericSignature ())
214
- return true ;
215
-
216
- // Retrieve the generic signature of the extension.
217
- const auto extensionSig = extension->getGenericSignature ();
210
+ auto *proto = dyn_cast<ProtocolDecl>(extendedNominal);
218
211
219
212
// If the extension is bound to the nominal the conformance is
220
213
// declared on, it is viable for inference when its conditional
221
214
// requirements are satisfied by those of the conformance context.
222
- if (!isa<ProtocolDecl>(extendedNominal)) {
215
+ if (!proto) {
216
+ // FIXME: The extension may not have a generic signature set up yet as
217
+ // resolving signatures may trigger associated type inference. This cycle
218
+ // is now detectable and we should look into untangling it
219
+ // - see rdar://55263708
220
+ if (!extension->hasComputedGenericSignature ())
221
+ return true ;
222
+
223
+ // Retrieve the generic signature of the extension.
224
+ const auto extensionSig = extension->getGenericSignature ();
225
+
223
226
// Extensions of non-generic nominals are always viable for inference.
224
227
if (!extensionSig)
225
228
return true ;
@@ -235,30 +238,29 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
235
238
// in the first place. Only check conformances on the `Self` type,
236
239
// because those have to be explicitly declared on the type somewhere
237
240
// so won't be affected by whatever answer inference comes up with.
238
- auto selfTy = extension->getSelfInterfaceType ();
239
- for (const Requirement &reqt : extensionSig->getRequirements ()) {
240
- switch (reqt.getKind ()) {
241
- case RequirementKind::Conformance:
242
- case RequirementKind::Superclass:
243
- // FIXME: This is the wrong check
244
- if (selfTy->isEqual (reqt.getFirstType ()) &&
245
- !TypeChecker::isSubtypeOf (conformance->getType (),
246
- reqt.getSecondType (), dc))
247
- return false ;
248
- break ;
241
+ auto *module = dc->getParentModule ();
242
+ auto checkConformance = [&](ProtocolDecl *proto) {
243
+ auto otherConf = module ->lookupConformance (conformance->getType (),
244
+ proto);
245
+ return (otherConf && otherConf.getConditionalRequirements ().empty ());
246
+ };
249
247
250
- case RequirementKind::Layout:
251
- case RequirementKind::SameType:
252
- break ;
248
+ // First check the extended protocol itself.
249
+ if (!checkConformance (proto))
250
+ return false ;
251
+
252
+ // Now check any additional bounds on 'Self' from the where clause.
253
+ auto bounds = getSelfBoundsFromWhereClause (extension);
254
+ for (auto *decl : bounds.decls ) {
255
+ if (auto *proto = dyn_cast<ProtocolDecl>(decl)) {
256
+ if (!checkConformance (proto))
257
+ return false ;
253
258
}
254
259
}
255
260
256
261
return true ;
257
262
};
258
263
259
- auto typeInContext =
260
- conformance->getDeclContext ()->mapTypeIntoContext (conformance->getType ());
261
-
262
264
for (auto witness :
263
265
checker.lookupValueWitnesses (req, /* ignoringNames=*/ nullptr )) {
264
266
LLVM_DEBUG (llvm::dbgs () << " Inferring associated types from decl:\n " ;
@@ -314,6 +316,10 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
314
316
if (!associatedTypesAreSameEquivalenceClass (dmt->getAssocType (),
315
317
result.first ))
316
318
return false ;
319
+
320
+ auto typeInContext =
321
+ conformance->getDeclContext ()->mapTypeIntoContext (conformance->getType ());
322
+
317
323
if (!dmt->getBase ()->isEqual (typeInContext))
318
324
return false ;
319
325
0 commit comments