Skip to content

Commit 1b6a18f

Browse files
committed
Adopt nextElement() in AsyncFlatMapSequence
`AsyncFlatMapSequence` is somewhat troublesome for typed throws, because it can produce errors from two different sources: the `Base` async sequence and the `SegmentOfResult` async sequence. However, we need to pick one `Failure` type for the `AsyncFlatMapSequence`, and there is no surface-language equivalent to the `errorUnion` operation described in SE-0413. So, we do something a little bit sneaky. We effectively require that `SegmentOfResult.Failure` either be equivalent to `Base.Failure` or be `Never`, such so that when the async sequence retruned by the closure throws, it throws the same thing as the base sequence. Therefore, the `Failure` type is defined to be `Base.Failure`. This property isn't enforced at the type level, but instead in the `AsyncSequence.flatMap` signatures: we replace the one signature that returned `AsyncFlatMapSequence` with three overloads that differ only in their generic requirements, adding: 1. `where SegmentOfResult.Failure == Failure` 2. `where SegmentOfResult.Failure == Never` 3. `where SegmentOfResult.Failure == Never, Failure == Never` (a tiebreaker between the two above) For cases where `SegmentOfResult.Failure` is neither `Never` nor `Failure`, overloading will choose the `flatMap` function that returns an `AsyncThrowingFlatMapSequence`. This can mean that existing code will choose a different overload and get a different type, but other than the type identity changing, the resulting sequence will behave the same way.
1 parent 3e726eb commit 1b6a18f

File tree

1 file changed

+169
-1
lines changed

1 file changed

+169
-1
lines changed

stdlib/public/Concurrency/AsyncFlatMapSequence.swift

Lines changed: 169 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,38 @@ import Swift
1414

1515
@available(SwiftStdlib 5.1, *)
1616
extension AsyncSequence {
17+
/// Creates an asynchronous sequence that concatenates the results of calling
18+
/// the given transformation with each element of this sequence.
19+
///
20+
/// Use this method to receive a single-level asynchronous sequence when your
21+
/// transformation produces an asynchronous sequence for each element.
22+
///
23+
/// In this example, an asynchronous sequence called `Counter` produces `Int`
24+
/// values from `1` to `5`. The transforming closure takes the received `Int`
25+
/// and returns a new `Counter` that counts that high. For example, when the
26+
/// transform receives `3` from the base sequence, it creates a new `Counter`
27+
/// that produces the values `1`, `2`, and `3`. The `flatMap(_:)` method
28+
/// "flattens" the resulting sequence-of-sequences into a single
29+
/// `AsyncSequence`.
30+
///
31+
/// let stream = Counter(howHigh: 5)
32+
/// .flatMap { Counter(howHigh: $0) }
33+
/// for await number in stream {
34+
/// print(number, terminator: " ")
35+
/// }
36+
/// // Prints "1 1 2 1 2 3 1 2 3 4 1 2 3 4 5 "
37+
///
38+
/// - Parameter transform: A mapping closure. `transform` accepts an element
39+
/// of this sequence as its parameter and returns an `AsyncSequence`.
40+
/// - Returns: A single, flattened asynchronous sequence that contains all
41+
/// elements in all the asynchronous sequences produced by `transform`.
42+
@usableFromInline
43+
__consuming func flatMap<SegmentOfResult: AsyncSequence>(
44+
_ transform: @Sendable @escaping (Element) async -> SegmentOfResult
45+
) -> AsyncFlatMapSequence<Self, SegmentOfResult> {
46+
return AsyncFlatMapSequence(self, transform: transform)
47+
}
48+
1749
/// Creates an asynchronous sequence that concatenates the results of calling
1850
/// the given transformation with each element of this sequence.
1951
///
@@ -40,12 +72,88 @@ extension AsyncSequence {
4072
/// - Returns: A single, flattened asynchronous sequence that contains all
4173
/// elements in all the asynchronous sequences produced by `transform`.
4274
@preconcurrency
75+
@_alwaysEmitIntoClient
4376
@inlinable
4477
public __consuming func flatMap<SegmentOfResult: AsyncSequence>(
4578
_ transform: @Sendable @escaping (Element) async -> SegmentOfResult
46-
) -> AsyncFlatMapSequence<Self, SegmentOfResult> {
79+
) -> AsyncFlatMapSequence<Self, SegmentOfResult>
80+
where SegmentOfResult.Failure == Failure
81+
{
82+
return AsyncFlatMapSequence(self, transform: transform)
83+
}
84+
85+
/// Creates an asynchronous sequence that concatenates the results of calling
86+
/// the given transformation with each element of this sequence.
87+
///
88+
/// Use this method to receive a single-level asynchronous sequence when your
89+
/// transformation produces an asynchronous sequence for each element.
90+
///
91+
/// In this example, an asynchronous sequence called `Counter` produces `Int`
92+
/// values from `1` to `5`. The transforming closure takes the received `Int`
93+
/// and returns a new `Counter` that counts that high. For example, when the
94+
/// transform receives `3` from the base sequence, it creates a new `Counter`
95+
/// that produces the values `1`, `2`, and `3`. The `flatMap(_:)` method
96+
/// "flattens" the resulting sequence-of-sequences into a single
97+
/// `AsyncSequence`.
98+
///
99+
/// let stream = Counter(howHigh: 5)
100+
/// .flatMap { Counter(howHigh: $0) }
101+
/// for await number in stream {
102+
/// print(number, terminator: " ")
103+
/// }
104+
/// // Prints "1 1 2 1 2 3 1 2 3 4 1 2 3 4 5 "
105+
///
106+
/// - Parameter transform: A mapping closure. `transform` accepts an element
107+
/// of this sequence as its parameter and returns an `AsyncSequence`.
108+
/// - Returns: A single, flattened asynchronous sequence that contains all
109+
/// elements in all the asynchronous sequences produced by `transform`.
110+
@preconcurrency
111+
@_alwaysEmitIntoClient
112+
@inlinable
113+
public __consuming func flatMap<SegmentOfResult: AsyncSequence>(
114+
_ transform: @Sendable @escaping (Element) async -> SegmentOfResult
115+
) -> AsyncFlatMapSequence<Self, SegmentOfResult>
116+
where SegmentOfResult.Failure == Never
117+
{
47118
return AsyncFlatMapSequence(self, transform: transform)
48119
}
120+
121+
/// Creates an asynchronous sequence that concatenates the results of calling
122+
/// the given transformation with each element of this sequence.
123+
///
124+
/// Use this method to receive a single-level asynchronous sequence when your
125+
/// transformation produces an asynchronous sequence for each element.
126+
///
127+
/// In this example, an asynchronous sequence called `Counter` produces `Int`
128+
/// values from `1` to `5`. The transforming closure takes the received `Int`
129+
/// and returns a new `Counter` that counts that high. For example, when the
130+
/// transform receives `3` from the base sequence, it creates a new `Counter`
131+
/// that produces the values `1`, `2`, and `3`. The `flatMap(_:)` method
132+
/// "flattens" the resulting sequence-of-sequences into a single
133+
/// `AsyncSequence`.
134+
///
135+
/// let stream = Counter(howHigh: 5)
136+
/// .flatMap { Counter(howHigh: $0) }
137+
/// for await number in stream {
138+
/// print(number, terminator: " ")
139+
/// }
140+
/// // Prints "1 1 2 1 2 3 1 2 3 4 1 2 3 4 5 "
141+
///
142+
/// - Parameter transform: A mapping closure. `transform` accepts an element
143+
/// of this sequence as its parameter and returns an `AsyncSequence`.
144+
/// - Returns: A single, flattened asynchronous sequence that contains all
145+
/// elements in all the asynchronous sequences produced by `transform`.
146+
@preconcurrency
147+
@_alwaysEmitIntoClient
148+
@inlinable
149+
public __consuming func flatMap<SegmentOfResult: AsyncSequence>(
150+
_ transform: @Sendable @escaping (Element) async -> SegmentOfResult
151+
) -> AsyncFlatMapSequence<Self, SegmentOfResult>
152+
where SegmentOfResult.Failure == Never, Failure == Never
153+
{
154+
return AsyncFlatMapSequence(self, transform: transform)
155+
}
156+
49157
}
50158

51159
/// An asynchronous sequence that concatenates the results of calling a given
@@ -75,6 +183,14 @@ extension AsyncFlatMapSequence: AsyncSequence {
75183
/// The flat map sequence produces the type of element in the asynchronous
76184
/// sequence produced by the `transform` closure.
77185
public typealias Element = SegmentOfResult.Element
186+
/// The type of error produced by this asynchronous sequence.
187+
///
188+
/// The flat map sequence produces the type of error in the base asynchronous
189+
/// sequence. By construction, the sequence produced by the `transform`
190+
/// closure must either produce this type of error or not produce errors
191+
/// at all.
192+
@available(SwiftStdlib 5.11, *)
193+
public typealias Failure = Base.Failure
78194
/// The type of iterator that produces elements of the sequence.
79195
public typealias AsyncIterator = Iterator
80196

@@ -148,6 +264,58 @@ extension AsyncFlatMapSequence: AsyncSequence {
148264
}
149265
return nil
150266
}
267+
268+
/// Produces the next element in the flat map sequence.
269+
///
270+
/// This iterator calls `nextElement()` on its base iterator; if this call
271+
/// returns `nil`, `nextElement()` returns `nil`. Otherwise, `nextElement()`
272+
/// calls the transforming closure on the received element, takes the
273+
/// resulting asynchronous sequence, and creates an asynchronous iterator
274+
/// from it. `nextElement()` then consumes values from this iterator until
275+
/// it terminates. At this point, `nextElement()` is ready to receive the
276+
/// next value from the base sequence.
277+
@available(SwiftStdlib 5.11, *)
278+
@inlinable
279+
public mutating func nextElement() async throws(Failure) -> SegmentOfResult.Element? {
280+
while !finished {
281+
if var iterator = currentIterator {
282+
do throws(any Error) {
283+
let optElement = try await iterator.nextElement()
284+
guard let element = optElement else {
285+
currentIterator = nil
286+
continue
287+
}
288+
// restore the iterator since we just mutated it with next
289+
currentIterator = iterator
290+
return element
291+
} catch {
292+
finished = true
293+
throw error as! Failure
294+
}
295+
} else {
296+
let optItem = try await baseIterator.nextElement()
297+
guard let item = optItem else {
298+
finished = true
299+
return nil
300+
}
301+
do throws(any Error) {
302+
let segment = await transform(item)
303+
var iterator = segment.makeAsyncIterator()
304+
let optElement = try await iterator.nextElement()
305+
guard let element = optElement else {
306+
currentIterator = nil
307+
continue
308+
}
309+
currentIterator = iterator
310+
return element
311+
} catch {
312+
finished = true
313+
throw error as! Failure
314+
}
315+
}
316+
}
317+
return nil
318+
}
151319
}
152320

153321
@inlinable

0 commit comments

Comments
 (0)