@@ -2448,38 +2448,38 @@ AssociatedTypeInference::computeFailureTypeWitness(
2448
2448
// Look for AsyncIteratorProtocol.next() and infer the Failure type from
2449
2449
// it.
2450
2450
for (const auto &witness : valueWitnesses) {
2451
- if (isAsyncIteratorProtocolNext (witness.first )) {
2452
- // We use a dyn_cast_or_null since we can get a nullptr here if we fail to
2453
- // match a witness. In such a case, we should just fail here.
2454
- if (auto witnessFunc = dyn_cast_or_null<AbstractFunctionDecl>(witness.second )) {
2455
- auto thrownError = witnessFunc->getEffectiveThrownErrorType ();
2456
-
2457
- // If it doesn't throw, Failure == Never.
2458
- if (!thrownError)
2459
- return AbstractTypeWitness (assocType, ctx.getNeverType ());
2460
-
2461
- // If it isn't 'rethrows', use the thrown error type;.
2462
- if (!witnessFunc->getAttrs ().hasAttribute <RethrowsAttr>()) {
2463
- return AbstractTypeWitness (assocType,
2464
- dc->mapTypeIntoContext (*thrownError));
2465
- }
2451
+ if (!isAsyncIteratorProtocolNext (witness.first ))
2452
+ continue ;
2466
2453
2467
- for (auto req : witnessFunc->getGenericSignature ().getRequirements ()) {
2468
- if (req.getKind () == RequirementKind::Conformance) {
2469
- auto proto = req.getProtocolDecl ();
2470
- if (proto->isSpecificProtocol (KnownProtocolKind::AsyncIteratorProtocol) ||
2471
- proto->isSpecificProtocol (KnownProtocolKind::AsyncSequence)) {
2472
- auto failureAssocType = proto->getAssociatedType (ctx.Id_Failure );
2473
- auto failureType = DependentMemberType::get (req.getFirstType (), failureAssocType);
2474
- return AbstractTypeWitness (assocType, dc->mapTypeIntoContext (failureType));
2475
- }
2454
+ if (!witness.second || witness.second ->getDeclContext () != dc)
2455
+ continue ;
2456
+
2457
+ if (auto witnessFunc = dyn_cast<AbstractFunctionDecl>(witness.second )) {
2458
+ auto thrownError = witnessFunc->getEffectiveThrownErrorType ();
2459
+
2460
+ // If it doesn't throw, Failure == Never.
2461
+ if (!thrownError)
2462
+ return AbstractTypeWitness (assocType, ctx.getNeverType ());
2463
+
2464
+ // If it isn't 'rethrows', use the thrown error type;.
2465
+ if (!witnessFunc->getAttrs ().hasAttribute <RethrowsAttr>()) {
2466
+ return AbstractTypeWitness (assocType,
2467
+ dc->mapTypeIntoContext (*thrownError));
2468
+ }
2469
+
2470
+ for (auto req : witnessFunc->getGenericSignature ().getRequirements ()) {
2471
+ if (req.getKind () == RequirementKind::Conformance) {
2472
+ auto proto = req.getProtocolDecl ();
2473
+ if (proto->isSpecificProtocol (KnownProtocolKind::AsyncIteratorProtocol) ||
2474
+ proto->isSpecificProtocol (KnownProtocolKind::AsyncSequence)) {
2475
+ auto failureAssocType = proto->getAssociatedType (ctx.Id_Failure );
2476
+ auto failureType = DependentMemberType::get (req.getFirstType (), failureAssocType);
2477
+ return AbstractTypeWitness (assocType, dc->mapTypeIntoContext (failureType));
2476
2478
}
2477
2479
}
2478
-
2479
- return AbstractTypeWitness (assocType, ctx.getErrorExistentialType ());
2480
2480
}
2481
2481
2482
- break ;
2482
+ return AbstractTypeWitness (assocType, ctx. getErrorExistentialType ()) ;
2483
2483
}
2484
2484
}
2485
2485
0 commit comments