Skip to content

Commit a5bdb12

Browse files
committed
Adopt typed throws in AsyncIteratorProtocol and AsyncSequence
Introduce a new associated type `Failure` into the two protocols involved in async sequences, which represents the type thrown when the sequence fails. Introduce a defaulted `_nextElement()` operations that throws `Failure` or produces the next element of the sequence. Provide a default implementation of `_nextElement()` in terms of `next()` that force-cases the thrown error to the `Failure` type. Introduce special associated type inference logic for the `Failure` type of an `AsyncIteratorProtocol` conformance when there is no specific _nextElement()` witness. This inference logic looks at the witness for `next()`: * If `next()` throws nothing, `Failure` is inferred to `Never`. * If `next()` throws, `Failure` is inferred to `any Error`. * If `next()` rethrows, `Failure` is inferred to `T.Failure`, where `T` is the first type parameter with a conformance to either `AsyncSequence` or `AsyncIteratorProtocol`. The default implementation and the inference rule, together, allow existing async sequences to continue working as before, and set us up for changing the contract of the `async for` loop to use `_nextIterator()` rather than `next()`. Give `AsyncSequence` and `AsyncIteratorProtocol` primary associated types for the element and failure types, which will allow them to be used more generally with existential and opaque types.
1 parent 390d510 commit a5bdb12

File tree

7 files changed

+220
-7
lines changed

7 files changed

+220
-7
lines changed

lib/Sema/AssociatedTypeInference.cpp

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,14 @@ class swift::AssociatedTypeInference {
821821
llvm::Optional<AbstractTypeWitness>
822822
computeDefaultTypeWitness(AssociatedTypeDecl *assocType) const;
823823

824+
/// Compute type witnesses for the Failure type from the
825+
/// AsyncSequence or AsyncIteratorProtocol
826+
llvm::Optional<AbstractTypeWitness>
827+
computeFailureTypeWitness(
828+
AssociatedTypeDecl *assocType,
829+
ArrayRef<std::pair<ValueDecl *, ValueDecl *>> valueWitnesses
830+
) const;
831+
824832
/// Compute the "derived" type witness for an associated type that is
825833
/// known to the compiler.
826834
std::pair<Type, TypeDecl *>
@@ -1370,6 +1378,25 @@ next_witness:;
13701378
return result;
13711379
}
13721380

1381+
/// Determine whether this is AsyncIteratorProtocol.Failure associated type.
1382+
static bool isAsyncIteratorProtocolFailure(AssociatedTypeDecl *assocType) {
1383+
auto proto = assocType->getProtocol();
1384+
if (!proto->isSpecificProtocol(KnownProtocolKind::AsyncIteratorProtocol))
1385+
return false;
1386+
1387+
return assocType->getName().str().equals("Failure");
1388+
}
1389+
1390+
/// Determine whether this is AsyncIteratorProtocol.next() function.
1391+
static bool isAsyncIteratorProtocolNext(ValueDecl *req) {
1392+
auto proto = dyn_cast<ProtocolDecl>(req->getDeclContext());
1393+
if (!proto ||
1394+
!proto->isSpecificProtocol(KnownProtocolKind::AsyncIteratorProtocol))
1395+
return false;
1396+
1397+
return req->getName().getBaseName() == req->getASTContext().Id_next;
1398+
}
1399+
13731400
InferredAssociatedTypes
13741401
AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
13751402
const llvm::SetVector<AssociatedTypeDecl *> &assocTypes) {
@@ -1415,7 +1442,8 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
14151442
TinyPtrVector<AssociatedTypeDecl *>());
14161443
if (llvm::find_if(referenced, [&](AssociatedTypeDecl *const assocType) {
14171444
return assocTypes.count(assocType);
1418-
}) == referenced.end())
1445+
}) == referenced.end() &&
1446+
!isAsyncIteratorProtocolNext(req))
14191447
continue;
14201448
}
14211449

@@ -1845,9 +1873,71 @@ Type AssociatedTypeInference::computeFixedTypeWitness(
18451873
return resultType;
18461874
}
18471875

1876+
llvm::Optional<AbstractTypeWitness>
1877+
AssociatedTypeInference::computeFailureTypeWitness(
1878+
AssociatedTypeDecl *assocType,
1879+
ArrayRef<std::pair<ValueDecl *, ValueDecl *>> valueWitnesses) const {
1880+
// Inference only applies to AsyncIteratorProtocol.Failure.
1881+
if (!isAsyncIteratorProtocolFailure(assocType))
1882+
return llvm::None;
1883+
1884+
// If there is a generic parameter named Failure, don't try to use next()
1885+
// to infer Failure.
1886+
if (auto genericSig = dc->getGenericSignatureOfContext()) {
1887+
for (auto gp : genericSig.getGenericParams()) {
1888+
// Packs cannot witness associated type requirements.
1889+
if (gp->isParameterPack())
1890+
continue;
1891+
1892+
if (gp->getName() == assocType->getName())
1893+
return llvm::None;
1894+
}
1895+
}
1896+
1897+
// Look for AsyncIteratorProtocol.next() and infer the Failure type from
1898+
// it.
1899+
for (const auto &witness : valueWitnesses) {
1900+
if (isAsyncIteratorProtocolNext(witness.first)) {
1901+
if (auto witnessFunc = dyn_cast<AbstractFunctionDecl>(witness.second)) {
1902+
// If it doesn't throw, Failure == Never.
1903+
if (!witnessFunc->hasThrows())
1904+
return AbstractTypeWitness(assocType, ctx.getNeverType());
1905+
1906+
// If it isn't 'rethrows', Failure == any Error.
1907+
if (!witnessFunc->getAttrs().hasAttribute<RethrowsAttr>())
1908+
return AbstractTypeWitness(assocType, ctx.getErrorExistentialType());
1909+
1910+
// Otherwise, we need to derive the Failure type from a type parameter
1911+
// that conforms to AsyncIteratorProtocol or AsyncSequence.
1912+
for (auto req : witnessFunc->getGenericSignature().getRequirements()) {
1913+
if (req.getKind() == RequirementKind::Conformance) {
1914+
auto proto = req.getProtocolDecl();
1915+
if (proto->isSpecificProtocol(KnownProtocolKind::AsyncIteratorProtocol) ||
1916+
proto->isSpecificProtocol(KnownProtocolKind::AsyncSequence)) {
1917+
auto failureAssocType = proto->getAssociatedType(ctx.getIdentifier("Failure"));
1918+
auto failureType = DependentMemberType::get(req.getFirstType(), failureAssocType);
1919+
return AbstractTypeWitness(assocType, dc->mapTypeIntoContext(failureType));
1920+
}
1921+
}
1922+
}
1923+
1924+
return AbstractTypeWitness(assocType, ctx.getErrorExistentialType());
1925+
}
1926+
1927+
break;
1928+
}
1929+
}
1930+
1931+
return llvm::None;
1932+
}
1933+
18481934
llvm::Optional<AbstractTypeWitness>
18491935
AssociatedTypeInference::computeDefaultTypeWitness(
18501936
AssociatedTypeDecl *assocType) const {
1937+
// Ignore the default for AsyncIteratorProtocol.Failure
1938+
if (isAsyncIteratorProtocolFailure(assocType))
1939+
return llvm::None;
1940+
18511941
// Go find a default definition.
18521942
auto *const defaultedAssocType = findDefaultedAssociatedType(
18531943
dc, dc->getSelfNominalTypeDecl(), assocType);
@@ -1942,7 +2032,11 @@ AssociatedTypeInference::computeAbstractTypeWitness(
19422032

19432033
// If there is a generic parameter of the named type, use that.
19442034
if (auto genericSig = dc->getGenericSignatureOfContext()) {
1945-
for (auto gp : genericSig.getInnermostGenericParams()) {
2035+
bool wantAllGenericParams = isAsyncIteratorProtocolFailure(assocType);
2036+
auto genericParams = wantAllGenericParams
2037+
? genericSig.getGenericParams()
2038+
: genericSig.getInnermostGenericParams();
2039+
for (auto gp : genericParams) {
19462040
// Packs cannot witness associated type requirements.
19472041
if (gp->isParameterPack())
19482042
continue;
@@ -2552,7 +2646,20 @@ void AssociatedTypeInference::findSolutionsRec(
25522646
// Filter out the associated types that remain unresolved.
25532647
SmallVector<AssociatedTypeDecl *, 4> stillUnresolved;
25542648
for (auto *const assocType : unresolvedAssocTypes) {
2555-
const auto typeWitness = typeWitnesses.begin(assocType);
2649+
auto typeWitness = typeWitnesses.begin(assocType);
2650+
2651+
// If we do not have a witness for AsyncIteratorProtocol.Failure,
2652+
// look for the witness to AsyncIteratorProtocol.next(). If it throws,
2653+
// use 'any Error'. Otherwise, use 'Never'.
2654+
if (typeWitness == typeWitnesses.end()) {
2655+
if (auto failureTypeWitness =
2656+
computeFailureTypeWitness(assocType, valueWitnesses)) {
2657+
typeWitnesses.insert(assocType,
2658+
{failureTypeWitness->getType(), reqDepth});
2659+
typeWitness = typeWitnesses.begin(assocType);
2660+
}
2661+
}
2662+
25562663
if (typeWitness == typeWitnesses.end()) {
25572664
stillUnresolved.push_back(assocType);
25582665
} else {

lib/Sema/TypeCheckEffects.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,17 @@ static bool classifyWitness(ModuleDecl *module,
171171

172172
case PolymorphicEffectKind::Always:
173173
// Witness always has the effect.
174+
175+
// If the witness's thrown type is explicitly specified as a type
176+
// parameter, then check whether the substituted type is `Never`.
177+
if (kind == EffectKind::Throws) {
178+
if (Type thrownError = witnessDecl->getThrownInterfaceType()) {
179+
if (thrownError->hasTypeParameter())
180+
thrownError = thrownError.subst(declRef.getSubstitutions());
181+
if (thrownError->isNever())
182+
return false;
183+
}
184+
}
174185
return true;
175186

176187
case PolymorphicEffectKind::Invalid:

stdlib/public/Concurrency/AsyncIteratorProtocol.swift

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,38 @@ import Swift
8787
/// a reference type.
8888
@available(SwiftStdlib 5.1, *)
8989
@rethrows
90-
public protocol AsyncIteratorProtocol {
90+
public protocol AsyncIteratorProtocol<Element, Failure> {
9191
associatedtype Element
92+
93+
/// The type of failure produced by iteration.
94+
associatedtype Failure: Error = any Error
95+
9296
/// Asynchronously advances to the next element and returns it, or ends the
9397
/// sequence if there is no next element.
94-
///
98+
///
9599
/// - Returns: The next element, if it exists, or `nil` to signal the end of
96100
/// the sequence.
97101
mutating func next() async throws -> Element?
102+
103+
/// Asynchronously advances to the next element and returns it, or ends the
104+
/// sequence if there is no next element.
105+
///
106+
/// - Returns: The next element, if it exists, or `nil` to signal the end of
107+
/// the sequence.
108+
@available(SwiftStdlib 5.11, *)
109+
mutating func _nextElement() async throws(Failure) -> Element?
110+
}
111+
112+
@available(SwiftStdlib 5.1, *)
113+
extension AsyncIteratorProtocol {
114+
/// Default implementation of `_nextElement()` in terms of `next()`, which is
115+
/// required to maintain backward compatibility with existing async iterators.
116+
@available(SwiftStdlib 5.11, *)
117+
public mutating func _nextElement() async throws(Failure) -> Element? {
118+
do {
119+
return try await next()
120+
} catch {
121+
throw error as! Failure
122+
}
123+
}
98124
}

stdlib/public/Concurrency/AsyncSequence.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,17 @@ import Swift
7373
///
7474
@available(SwiftStdlib 5.1, *)
7575
@rethrows
76-
public protocol AsyncSequence {
76+
public protocol AsyncSequence<Element, Failure> {
7777
/// The type of asynchronous iterator that produces elements of this
7878
/// asynchronous sequence.
7979
associatedtype AsyncIterator: AsyncIteratorProtocol where AsyncIterator.Element == Element
8080
/// The type of element produced by this asynchronous sequence.
8181
associatedtype Element
82+
83+
/// The type of errors produced when iteration over the sequence fails.
84+
associatedtype Failure: Error = AsyncIterator.Failure
85+
where AsyncIterator.Failure == Failure
86+
8287
/// Creates the asynchronous iterator that produces elements of this
8388
/// asynchronous sequence.
8489
///

stdlib/public/Concurrency/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ add_swift_target_library(swift_Concurrency ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} I
169169
SWIFT_COMPILE_FLAGS
170170
${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS}
171171
-parse-stdlib
172-
-Xfrontend -enable-experimental-concurrency
172+
-enable-experimental-feature TypedThrows
173173
-diagnostic-style swift
174174
${SWIFT_RUNTIME_CONCURRENCY_SWIFT_FLAGS}
175175
${swift_concurrency_options}
@@ -256,6 +256,7 @@ if(SWIFT_SHOULD_BUILD_EMBEDDED_STDLIB AND SWIFT_SHOULD_BUILD_EMBEDDED_CONCURRENC
256256

257257
SWIFT_COMPILE_FLAGS
258258
${extra_swift_compile_flags} -enable-experimental-feature Embedded
259+
-enable-experimental-feature TypedThrows
259260
-parse-stdlib -DSWIFT_CONCURRENCY_EMBEDDED
260261
${SWIFT_RUNTIME_CONCURRENCY_SWIFT_FLAGS}
261262
C_COMPILE_FLAGS
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// RUN: %target-swift-frontend -strict-concurrency=complete -emit-sil -o /dev/null %s -verify
2+
// REQUIRES: concurrency
3+
4+
@available(SwiftStdlib 5.1, *)
5+
struct S: AsyncSequence {
6+
typealias Element = Int
7+
struct AsyncIterator: AsyncIteratorProtocol {
8+
mutating func next() async -> Int? { nil }
9+
}
10+
11+
func makeAsyncIterator() -> AsyncIterator { AsyncIterator() }
12+
}
13+
14+
@available(SwiftStdlib 5.1, *)
15+
struct TS: AsyncSequence {
16+
typealias Element = Int
17+
struct AsyncIterator: AsyncIteratorProtocol {
18+
mutating func next() async throws -> Int? { nil }
19+
}
20+
21+
func makeAsyncIterator() -> AsyncIterator { AsyncIterator() }
22+
}
23+
24+
@available(SwiftStdlib 5.1, *)
25+
struct GenericTS<Failure: Error>: AsyncSequence {
26+
typealias Element = Int
27+
struct AsyncIterator: AsyncIteratorProtocol {
28+
mutating func next() async throws -> Int? { nil }
29+
}
30+
31+
func makeAsyncIterator() -> AsyncIterator { AsyncIterator() }
32+
}
33+
34+
@available(SwiftStdlib 5.1, *)
35+
struct SequenceAdapter<Base: AsyncSequence>: AsyncSequence {
36+
typealias Element = Base.Element
37+
38+
struct AsyncIterator: AsyncIteratorProtocol {
39+
mutating func next() async rethrows -> Base.Element? { nil }
40+
}
41+
42+
func makeAsyncIterator() -> AsyncIterator { AsyncIterator() }
43+
}
44+
45+
enum MyError: Error {
46+
case fail
47+
}
48+
49+
@available(SwiftStdlib 5.1, *)
50+
func testAssocTypeInference(sf: S.Failure, tsf: TS.Failure, gtsf1: GenericTS<MyError>.Failure, adapter: SequenceAdapter<GenericTS<MyError>>.Failure) {
51+
let _: Int = sf // expected-error{{cannot convert value of type 'S.Failure' (aka 'Never') to specified type 'Int'}}
52+
let _: Int = tsf // expected-error{{cannot convert value of type 'TS.Failure' (aka 'any Error') to specified type 'Int'}}
53+
let _: Int = gtsf1 // expected-error{{cannot convert value of type 'GenericTS<MyError>.Failure' (aka 'MyError') to specified type 'Int'}}
54+
let _: Int = adapter // expected-error{{cannot convert value of type 'SequenceAdapter<GenericTS<MyError>>.Failure' (aka 'MyError') to specified type 'Int'}}
55+
}
56+
57+
58+
@available(SwiftStdlib 5.1, *)
59+
func test(s: S) async {
60+
for await x in s { _ = x }
61+
}

test/api-digester/stability-concurrency-abi.test

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ Func _asyncLet_get_throwing(_:_:) has mangled name changing from '_Concurrency._
6060
Func _asyncLet_get_throwing(_:_:) has return type change from Builtin.RawPointer to ()
6161
Protocol Actor has added inherited protocol AnyActor
6262
Protocol Actor has generic signature change from <Self : AnyObject, Self : Swift.Sendable> to <Self : _Concurrency.AnyActor>
63+
Protocol AsyncIteratorProtocol has generic signature change from to <Self.Failure : Swift.Error>
64+
Protocol AsyncSequence has generic signature change from <Self.AsyncIterator : _Concurrency.AsyncIteratorProtocol, Self.Element == Self.AsyncIterator.Element> to <Self.AsyncIterator : _Concurrency.AsyncIteratorProtocol, Self.Element == Self.AsyncIterator.Element, Self.Failure == Self.AsyncIterator.Failure>
6365
Struct CheckedContinuation has removed conformance to UnsafeSendable
6466

6567
// SerialExecutor gained `enqueue(_: __owned Job)`, protocol requirements got default implementations

0 commit comments

Comments
 (0)