Skip to content

Commit e050294

Browse files
committed
[Associated type inference] Limit Failure inference to rethrows next()
When inferring a type witness for `AsyncIteratorProtocol` or `AsyncSequence`'s `Failure` associated type, don't infer from a generic parameter named `Failure`. Instead, use `next()` as a cue: if it `rethrows`, use `Failure` from one of the conformances; if it `throws`, use `any Error`. This is a more conservative inference rule, and addresses a failure to infer a `Failure` type witness for some fairly-obvious cases. Fixes rdar://122514816.
1 parent f9e6478 commit e050294

File tree

2 files changed

+54
-30
lines changed

2 files changed

+54
-30
lines changed

lib/Sema/AssociatedTypeInference.cpp

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,13 +1632,15 @@ next_witness:;
16321632
return result;
16331633
}
16341634

1635-
/// Determine whether this is AsyncIteratorProtocol.Failure associated type.
1635+
/// Determine whether this is AsyncIteratorProtocol.Failure or
1636+
/// AsyncSequenceProtoco.Failure associated type.
16361637
static bool isAsyncIteratorProtocolFailure(AssociatedTypeDecl *assocType) {
16371638
auto proto = assocType->getProtocol();
1638-
if (!proto->isSpecificProtocol(KnownProtocolKind::AsyncIteratorProtocol))
1639+
if (!proto->isSpecificProtocol(KnownProtocolKind::AsyncIteratorProtocol) &&
1640+
!proto->isSpecificProtocol(KnownProtocolKind::AsyncSequence))
16391641
return false;
16401642

1641-
return assocType->getName().str().equals("Failure");
1643+
return assocType->getName() == assocType->getASTContext().Id_Failure;
16421644
}
16431645

16441646
/// Determine whether this is AsyncIteratorProtocol.next() function.
@@ -2149,23 +2151,11 @@ llvm::Optional<AbstractTypeWitness>
21492151
AssociatedTypeInference::computeFailureTypeWitness(
21502152
AssociatedTypeDecl *assocType,
21512153
ArrayRef<std::pair<ValueDecl *, ValueDecl *>> valueWitnesses) const {
2152-
// Inference only applies to AsyncIteratorProtocol.Failure.
2154+
// Inference only applies to AsyncIteratorProtocol.Failure and
2155+
// AsyncSequence.Failure.
21532156
if (!isAsyncIteratorProtocolFailure(assocType))
21542157
return llvm::None;
21552158

2156-
// If there is a generic parameter named Failure, don't try to use next()
2157-
// to infer Failure.
2158-
if (auto genericSig = dc->getGenericSignatureOfContext()) {
2159-
for (auto gp : genericSig.getGenericParams()) {
2160-
// Packs cannot witness associated type requirements.
2161-
if (gp->isParameterPack())
2162-
continue;
2163-
2164-
if (gp->getName() == assocType->getName())
2165-
return llvm::None;
2166-
}
2167-
}
2168-
21692159
// Look for AsyncIteratorProtocol.next() and infer the Failure type from
21702160
// it.
21712161
for (const auto &witness : valueWitnesses) {
@@ -2179,8 +2169,6 @@ AssociatedTypeInference::computeFailureTypeWitness(
21792169
if (!witnessFunc->getAttrs().hasAttribute<RethrowsAttr>())
21802170
return AbstractTypeWitness(assocType, ctx.getErrorExistentialType());
21812171

2182-
// Otherwise, we need to derive the Failure type from a type parameter
2183-
// that conforms to AsyncIteratorProtocol or AsyncSequence.
21842172
for (auto req : witnessFunc->getGenericSignature().getRequirements()) {
21852173
if (req.getKind() == RequirementKind::Conformance) {
21862174
auto proto = req.getProtocolDecl();
@@ -2206,7 +2194,8 @@ AssociatedTypeInference::computeFailureTypeWitness(
22062194
llvm::Optional<AbstractTypeWitness>
22072195
AssociatedTypeInference::computeDefaultTypeWitness(
22082196
AssociatedTypeDecl *assocType) const {
2209-
// Ignore the default for AsyncIteratorProtocol.Failure
2197+
// Ignore the default for AsyncIteratorProtocol.Failure and
2198+
// AsyncSequence.Failure.
22102199
if (isAsyncIteratorProtocolFailure(assocType))
22112200
return llvm::None;
22122201

@@ -2304,13 +2293,14 @@ AssociatedTypeInference::computeAbstractTypeWitness(
23042293
if (const auto &typeWitness = computeDefaultTypeWitness(assocType))
23052294
return typeWitness;
23062295

2296+
// Ignore the default for AsyncIteratorProtocol.Failure and
2297+
// AsyncSequence.Failure. We use the next() function to do inference.
2298+
if (isAsyncIteratorProtocolFailure(assocType))
2299+
return llvm::None;
2300+
23072301
// If there is a generic parameter of the named type, use that.
23082302
if (auto genericSig = dc->getGenericSignatureOfContext()) {
2309-
bool wantAllGenericParams = isAsyncIteratorProtocolFailure(assocType);
2310-
auto genericParams = wantAllGenericParams
2311-
? genericSig.getGenericParams()
2312-
: genericSig.getInnermostGenericParams();
2313-
for (auto gp : genericParams) {
2303+
for (auto gp : genericSig.getInnermostGenericParams()) {
23142304
// Packs cannot witness associated type requirements.
23152305
if (gp->isParameterPack())
23162306
continue;
@@ -2342,6 +2332,11 @@ void AssociatedTypeInference::collectAbstractTypeWitnesses(
23422332
// through same-type requirements of protocols.
23432333
if (auto genericSig = dc->getGenericSignatureOfContext()) {
23442334
for (auto *const assocType : unresolvedAssocTypes) {
2335+
// Ignore the generic parameters for AsyncIteratorProtocol.Failure and
2336+
// AsyncSequence.Failure.
2337+
if (isAsyncIteratorProtocolFailure(assocType))
2338+
continue;
2339+
23452340
for (auto *gp : genericSig.getInnermostGenericParams()) {
23462341
// Packs cannot witness associated type requirements.
23472342
if (gp->isParameterPack())

test/Concurrency/async_iterator_inference.swift

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ struct TS: AsyncSequence {
2121
func makeAsyncIterator() -> AsyncIterator { AsyncIterator() }
2222
}
2323

24+
@available(SwiftStdlib 5.1, *)
25+
struct SpecificTS<F: Error>: AsyncSequence {
26+
typealias Element = Int
27+
typealias Failure = F
28+
struct AsyncIterator: AsyncIteratorProtocol {
29+
typealias Failure = F
30+
mutating func next() async throws -> Int? { nil }
31+
}
32+
33+
func makeAsyncIterator() -> AsyncIterator { AsyncIterator() }
34+
}
35+
2436
@available(SwiftStdlib 5.1, *)
2537
struct GenericTS<Failure: Error>: AsyncSequence {
2638
typealias Element = Int
@@ -42,16 +54,33 @@ struct SequenceAdapter<Base: AsyncSequence>: AsyncSequence {
4254
func makeAsyncIterator() -> AsyncIterator { AsyncIterator() }
4355
}
4456

57+
public struct NormalThrowingAsyncSequence<Element, Failure>: AsyncSequence {
58+
private let iteratorMaker: () -> AsyncIterator
59+
60+
public struct AsyncIterator: AsyncIteratorProtocol {
61+
let nextMaker: () async throws -> Element?
62+
public mutating func next() async throws -> Element? {
63+
try await nextMaker()
64+
}
65+
}
66+
67+
public func makeAsyncIterator() -> AsyncIterator {
68+
iteratorMaker()
69+
}
70+
}
71+
72+
4573
enum MyError: Error {
4674
case fail
4775
}
4876

4977
@available(SwiftStdlib 5.1, *)
50-
func testAssocTypeInference(sf: S.Failure, tsf: TS.Failure, gtsf1: GenericTS<MyError>.Failure, adapter: SequenceAdapter<GenericTS<MyError>>.Failure) {
78+
func testAssocTypeInference(sf: S.Failure, tsf: TS.Failure, gtsf1: GenericTS<MyError>.Failure, adapter: SequenceAdapter<SpecificTS<MyError>>.Failure, ntas: NormalThrowingAsyncSequence<String, MyError>.Failure) {
5179
let _: Int = sf // expected-error{{cannot convert value of type 'S.Failure' (aka 'Never') to specified type 'Int'}}
5280
let _: Int = tsf // expected-error{{cannot convert value of type 'TS.Failure' (aka 'any Error') to specified type 'Int'}}
53-
let _: Int = gtsf1 // expected-error{{cannot convert value of type 'GenericTS<MyError>.Failure' (aka 'MyError') to specified type 'Int'}}
54-
let _: Int = adapter // expected-error{{cannot convert value of type 'SequenceAdapter<GenericTS<MyError>>.Failure' (aka 'MyError') to specified type 'Int'}}
81+
let _: Int = gtsf1 // expected-error{{cannot convert value of type 'GenericTS<MyError>.Failure' (aka 'any Error') to specified type 'Int'}}
82+
let _: Int = adapter // expected-error{{cannot convert value of type 'SequenceAdapter<SpecificTS<MyError>>.Failure' (aka 'MyError') to specified type 'Int'}}
83+
let _: Int = ntas // expected-error{{cannot convert value of type 'NormalThrowingAsyncSequence<String, MyError>.Failure' (aka 'any Error') to specified type 'Int'}}
5584
}
5685

5786

@@ -66,9 +95,9 @@ case boom
6695

6796

6897
@available(SwiftStdlib 5.1, *)
69-
func testMyError(s: GenericTS<MyError>, so: GenericTS<OtherError>) async throws(MyError) {
98+
func testMyError(s: SpecificTS<MyError>, so: SpecificTS<OtherError>) async throws(MyError) {
7099
for try await x in s { _ = x }
71100

72101
for try await x in so { _ = x }
73-
// expected-error@-1{{thrown expression type 'OtherError' cannot be converted to error type 'MyError'}}
102+
// expected-error@-1{{thrown expression type 'SpecificTS<OtherError>.AsyncIterator.Failure' (aka 'OtherError') cannot be converted to error type 'MyError'}}
74103
}

0 commit comments

Comments
 (0)