Skip to content

Commit c1c68d5

Browse files
authored
Merge pull request #71494 from DougGregor/failure-inference-failed
2 parents a186a51 + e050294 commit c1c68d5

File tree

2 files changed

+54
-30
lines changed

2 files changed

+54
-30
lines changed

lib/Sema/AssociatedTypeInference.cpp

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,13 +1638,15 @@ next_witness:;
16381638
return result;
16391639
}
16401640

1641-
/// Determine whether this is AsyncIteratorProtocol.Failure associated type.
1641+
/// Determine whether this is AsyncIteratorProtocol.Failure or
1642+
/// AsyncSequenceProtoco.Failure associated type.
16421643
static bool isAsyncIteratorProtocolFailure(AssociatedTypeDecl *assocType) {
16431644
auto proto = assocType->getProtocol();
1644-
if (!proto->isSpecificProtocol(KnownProtocolKind::AsyncIteratorProtocol))
1645+
if (!proto->isSpecificProtocol(KnownProtocolKind::AsyncIteratorProtocol) &&
1646+
!proto->isSpecificProtocol(KnownProtocolKind::AsyncSequence))
16451647
return false;
16461648

1647-
return assocType->getName().str().equals("Failure");
1649+
return assocType->getName() == assocType->getASTContext().Id_Failure;
16481650
}
16491651

16501652
/// Determine whether this is AsyncIteratorProtocol.next() function.
@@ -2155,23 +2157,11 @@ llvm::Optional<AbstractTypeWitness>
21552157
AssociatedTypeInference::computeFailureTypeWitness(
21562158
AssociatedTypeDecl *assocType,
21572159
ArrayRef<std::pair<ValueDecl *, ValueDecl *>> valueWitnesses) const {
2158-
// Inference only applies to AsyncIteratorProtocol.Failure.
2160+
// Inference only applies to AsyncIteratorProtocol.Failure and
2161+
// AsyncSequence.Failure.
21592162
if (!isAsyncIteratorProtocolFailure(assocType))
21602163
return llvm::None;
21612164

2162-
// If there is a generic parameter named Failure, don't try to use next()
2163-
// to infer Failure.
2164-
if (auto genericSig = dc->getGenericSignatureOfContext()) {
2165-
for (auto gp : genericSig.getGenericParams()) {
2166-
// Packs cannot witness associated type requirements.
2167-
if (gp->isParameterPack())
2168-
continue;
2169-
2170-
if (gp->getName() == assocType->getName())
2171-
return llvm::None;
2172-
}
2173-
}
2174-
21752165
// Look for AsyncIteratorProtocol.next() and infer the Failure type from
21762166
// it.
21772167
for (const auto &witness : valueWitnesses) {
@@ -2185,8 +2175,6 @@ AssociatedTypeInference::computeFailureTypeWitness(
21852175
if (!witnessFunc->getAttrs().hasAttribute<RethrowsAttr>())
21862176
return AbstractTypeWitness(assocType, ctx.getErrorExistentialType());
21872177

2188-
// Otherwise, we need to derive the Failure type from a type parameter
2189-
// that conforms to AsyncIteratorProtocol or AsyncSequence.
21902178
for (auto req : witnessFunc->getGenericSignature().getRequirements()) {
21912179
if (req.getKind() == RequirementKind::Conformance) {
21922180
auto proto = req.getProtocolDecl();
@@ -2212,7 +2200,8 @@ AssociatedTypeInference::computeFailureTypeWitness(
22122200
llvm::Optional<AbstractTypeWitness>
22132201
AssociatedTypeInference::computeDefaultTypeWitness(
22142202
AssociatedTypeDecl *assocType) const {
2215-
// Ignore the default for AsyncIteratorProtocol.Failure
2203+
// Ignore the default for AsyncIteratorProtocol.Failure and
2204+
// AsyncSequence.Failure.
22162205
if (isAsyncIteratorProtocolFailure(assocType))
22172206
return llvm::None;
22182207

@@ -2310,13 +2299,14 @@ AssociatedTypeInference::computeAbstractTypeWitness(
23102299
if (const auto &typeWitness = computeDefaultTypeWitness(assocType))
23112300
return typeWitness;
23122301

2302+
// Ignore the default for AsyncIteratorProtocol.Failure and
2303+
// AsyncSequence.Failure. We use the next() function to do inference.
2304+
if (isAsyncIteratorProtocolFailure(assocType))
2305+
return llvm::None;
2306+
23132307
// If there is a generic parameter of the named type, use that.
23142308
if (auto genericSig = dc->getGenericSignatureOfContext()) {
2315-
bool wantAllGenericParams = isAsyncIteratorProtocolFailure(assocType);
2316-
auto genericParams = wantAllGenericParams
2317-
? genericSig.getGenericParams()
2318-
: genericSig.getInnermostGenericParams();
2319-
for (auto gp : genericParams) {
2309+
for (auto gp : genericSig.getInnermostGenericParams()) {
23202310
// Packs cannot witness associated type requirements.
23212311
if (gp->isParameterPack())
23222312
continue;
@@ -2348,6 +2338,11 @@ void AssociatedTypeInference::collectAbstractTypeWitnesses(
23482338
// through same-type requirements of protocols.
23492339
if (auto genericSig = dc->getGenericSignatureOfContext()) {
23502340
for (auto *const assocType : unresolvedAssocTypes) {
2341+
// Ignore the generic parameters for AsyncIteratorProtocol.Failure and
2342+
// AsyncSequence.Failure.
2343+
if (isAsyncIteratorProtocolFailure(assocType))
2344+
continue;
2345+
23512346
for (auto *gp : genericSig.getInnermostGenericParams()) {
23522347
// Packs cannot witness associated type requirements.
23532348
if (gp->isParameterPack())

test/Concurrency/async_iterator_inference.swift

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ struct TS: AsyncSequence {
2121
func makeAsyncIterator() -> AsyncIterator { AsyncIterator() }
2222
}
2323

24+
@available(SwiftStdlib 5.1, *)
25+
struct SpecificTS<F: Error>: AsyncSequence {
26+
typealias Element = Int
27+
typealias Failure = F
28+
struct AsyncIterator: AsyncIteratorProtocol {
29+
typealias Failure = F
30+
mutating func next() async throws -> Int? { nil }
31+
}
32+
33+
func makeAsyncIterator() -> AsyncIterator { AsyncIterator() }
34+
}
35+
2436
@available(SwiftStdlib 5.1, *)
2537
struct GenericTS<Failure: Error>: AsyncSequence {
2638
typealias Element = Int
@@ -42,16 +54,33 @@ struct SequenceAdapter<Base: AsyncSequence>: AsyncSequence {
4254
func makeAsyncIterator() -> AsyncIterator { AsyncIterator() }
4355
}
4456

57+
public struct NormalThrowingAsyncSequence<Element, Failure>: AsyncSequence {
58+
private let iteratorMaker: () -> AsyncIterator
59+
60+
public struct AsyncIterator: AsyncIteratorProtocol {
61+
let nextMaker: () async throws -> Element?
62+
public mutating func next() async throws -> Element? {
63+
try await nextMaker()
64+
}
65+
}
66+
67+
public func makeAsyncIterator() -> AsyncIterator {
68+
iteratorMaker()
69+
}
70+
}
71+
72+
4573
enum MyError: Error {
4674
case fail
4775
}
4876

4977
@available(SwiftStdlib 5.1, *)
50-
func testAssocTypeInference(sf: S.Failure, tsf: TS.Failure, gtsf1: GenericTS<MyError>.Failure, adapter: SequenceAdapter<GenericTS<MyError>>.Failure) {
78+
func testAssocTypeInference(sf: S.Failure, tsf: TS.Failure, gtsf1: GenericTS<MyError>.Failure, adapter: SequenceAdapter<SpecificTS<MyError>>.Failure, ntas: NormalThrowingAsyncSequence<String, MyError>.Failure) {
5179
let _: Int = sf // expected-error{{cannot convert value of type 'S.Failure' (aka 'Never') to specified type 'Int'}}
5280
let _: Int = tsf // expected-error{{cannot convert value of type 'TS.Failure' (aka 'any Error') to specified type 'Int'}}
53-
let _: Int = gtsf1 // expected-error{{cannot convert value of type 'GenericTS<MyError>.Failure' (aka 'MyError') to specified type 'Int'}}
54-
let _: Int = adapter // expected-error{{cannot convert value of type 'SequenceAdapter<GenericTS<MyError>>.Failure' (aka 'MyError') to specified type 'Int'}}
81+
let _: Int = gtsf1 // expected-error{{cannot convert value of type 'GenericTS<MyError>.Failure' (aka 'any Error') to specified type 'Int'}}
82+
let _: Int = adapter // expected-error{{cannot convert value of type 'SequenceAdapter<SpecificTS<MyError>>.Failure' (aka 'MyError') to specified type 'Int'}}
83+
let _: Int = ntas // expected-error{{cannot convert value of type 'NormalThrowingAsyncSequence<String, MyError>.Failure' (aka 'any Error') to specified type 'Int'}}
5584
}
5685

5786

@@ -66,9 +95,9 @@ case boom
6695

6796

6897
@available(SwiftStdlib 5.1, *)
69-
func testMyError(s: GenericTS<MyError>, so: GenericTS<OtherError>) async throws(MyError) {
98+
func testMyError(s: SpecificTS<MyError>, so: SpecificTS<OtherError>) async throws(MyError) {
7099
for try await x in s { _ = x }
71100

72101
for try await x in so { _ = x }
73-
// expected-error@-1{{thrown expression type 'OtherError' cannot be converted to error type 'MyError'}}
102+
// expected-error@-1{{thrown expression type 'SpecificTS<OtherError>.AsyncIterator.Failure' (aka 'OtherError') cannot be converted to error type 'MyError'}}
74103
}

0 commit comments

Comments
 (0)