Skip to content

Commit e058e99

Browse files
committed
Async closures
1 parent 68efe53 commit e058e99

File tree

3 files changed

+468
-1
lines changed

3 files changed

+468
-1
lines changed

Sources/IdentifiableContinuation.swift

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,18 @@ public func withIdentifiableContinuation<T>(
8989
}
9090
}
9191

92+
@inlinable
93+
public func withIdentifiableContinuation<T>(
94+
function: String = #function,
95+
body: (IdentifiableContinuation<T, Never>) async -> Void,
96+
onCancel: (IdentifiableContinuation<T, Never>.ID) async -> Void
97+
) async -> T {
98+
return await withoutActuallyEscaping(body, onCancel, result: T.self) {
99+
let state = AsyncLockedState(body: $0, onCancel: $1)
100+
return await state.startCheckedContinuation(function: function)
101+
}
102+
}
103+
92104
@inlinable
93105
public func withThrowingIdentifiableContinuation<T>(
94106
function: String = #function,
@@ -109,6 +121,18 @@ public func withThrowingIdentifiableContinuation<T>(
109121
}
110122
}
111123

124+
@inlinable
125+
public func withThrowingIdentifiableContinuation<T>(
126+
function: String = #function,
127+
body: (IdentifiableContinuation<T, Error>) async -> Void,
128+
onCancel: (IdentifiableContinuation<T, Error>.ID) async -> Void
129+
) async throws -> T {
130+
try await withoutActuallyEscaping(body, onCancel, result: T.self) {
131+
let state = AsyncLockedState(body: $0, onCancel: $1)
132+
return try await state.startCheckedThrowingContinuation(function: function)
133+
}
134+
}
135+
112136
@inlinable
113137
public func withIdentifiableUnsafeContinuation<T>(
114138
body: (IdentifiableContinuation<T, Never>) -> Void,
@@ -128,6 +152,17 @@ public func withIdentifiableUnsafeContinuation<T>(
128152
}
129153
}
130154

155+
@inlinable
156+
public func withIdentifiableUnsafeContinuation<T>(
157+
body: (IdentifiableContinuation<T, Never>) async -> Void,
158+
onCancel: (IdentifiableContinuation<T, Never>.ID) async -> Void
159+
) async -> T {
160+
await withoutActuallyEscaping(body, onCancel, result: T.self) {
161+
let state = AsyncLockedState(body: $0, onCancel: $1)
162+
return await state.startUnsafeContinuation()
163+
}
164+
}
165+
131166
@inlinable
132167
public func withThrowingIdentifiableUnsafeContinuation<T>(
133168
body: (IdentifiableContinuation<T, Error>) -> Void,
@@ -147,6 +182,17 @@ public func withThrowingIdentifiableUnsafeContinuation<T>(
147182
}
148183
}
149184

185+
@inlinable
186+
public func withThrowingIdentifiableUnsafeContinuation<T>(
187+
body: (IdentifiableContinuation<T, Error>) async -> Void,
188+
onCancel: (IdentifiableContinuation<T, Error>.ID) async -> Void
189+
) async throws -> T {
190+
try await withoutActuallyEscaping(body, onCancel, result: T.self) {
191+
let state = AsyncLockedState(body: $0, onCancel: $1)
192+
return try await state.startUnsafeThrowingContinuation()
193+
}
194+
}
195+
150196
public struct IdentifiableContinuation<T, E>: Sendable, Identifiable where E : Error {
151197

152198
public let id: ID
@@ -261,6 +307,165 @@ final class LockedState<T, Failure: Error>: @unchecked Sendable {
261307
}
262308
}
263309

310+
@usableFromInline
311+
final class AsyncLockedState<T, Failure: Error>: @unchecked Sendable {
312+
private let lock = NSLock()
313+
private var state: State
314+
private let body: (IdentifiableContinuation<T, Failure>) async -> Void
315+
private let onCancel: (IdentifiableContinuation<T, Failure>.ID) async -> Void
316+
317+
@usableFromInline
318+
struct State {
319+
@usableFromInline
320+
var bodyTask: Task<Void, Never>?
321+
@usableFromInline
322+
var cancelTask: Task<Void, Never>?
323+
@usableFromInline
324+
var didCompleteBody: Bool = false
325+
@usableFromInline
326+
var isCancelled: Bool = false
327+
@usableFromInline
328+
var didCancel: Bool = false
329+
}
330+
331+
@usableFromInline
332+
init(body: @escaping (IdentifiableContinuation<T, Failure>) async -> Void,
333+
onCancel: @escaping (IdentifiableContinuation<T, Failure>.ID) async -> Void) {
334+
self.state = State()
335+
self.body = body
336+
self.onCancel = onCancel
337+
}
338+
339+
@usableFromInline
340+
func startCheckedContinuation(function: String) async -> T where Failure == Never {
341+
let id = IdentifiableContinuation<T, Failure>.ID()
342+
return await withTaskCancellationHandler {
343+
let result = await withCheckedContinuation(function: function) {
344+
let continuation = IdentifiableContinuation(id: id, storage: .checked($0))
345+
start(with: continuation)
346+
}
347+
await waitForTasks()
348+
return result
349+
} onCancel: {
350+
cancel(withID: id)
351+
}
352+
}
353+
354+
@usableFromInline
355+
func startCheckedThrowingContinuation(function: String) async throws -> T where Failure == Error {
356+
let id = IdentifiableContinuation<T, Failure>.ID()
357+
return try await withTaskCancellationHandler {
358+
do {
359+
let result = try await withCheckedThrowingContinuation(function: function) {
360+
let continuation = IdentifiableContinuation(id: id, storage: .checked($0))
361+
start(with: continuation)
362+
}
363+
await waitForTasks()
364+
return result
365+
} catch {
366+
await waitForTasks()
367+
throw error
368+
}
369+
} onCancel: {
370+
cancel(withID: id)
371+
}
372+
}
373+
374+
@usableFromInline
375+
func startUnsafeContinuation() async -> T where Failure == Never {
376+
let id = IdentifiableContinuation<T, Failure>.ID()
377+
return await withTaskCancellationHandler {
378+
let result = await withUnsafeContinuation {
379+
let continuation = IdentifiableContinuation(id: id, storage: .unsafe($0))
380+
start(with: continuation)
381+
}
382+
await waitForTasks()
383+
return result
384+
} onCancel: {
385+
cancel(withID: id)
386+
}
387+
}
388+
389+
@usableFromInline
390+
func startUnsafeThrowingContinuation() async throws -> T where Failure == Error {
391+
let id = IdentifiableContinuation<T, Failure>.ID()
392+
return try await withTaskCancellationHandler {
393+
do {
394+
let result = try await withUnsafeThrowingContinuation {
395+
let continuation = IdentifiableContinuation(id: id, storage: .unsafe($0))
396+
start(with: continuation)
397+
}
398+
await waitForTasks()
399+
return result
400+
} catch {
401+
await waitForTasks()
402+
throw error
403+
}
404+
} onCancel: {
405+
cancel(withID: id)
406+
}
407+
}
408+
409+
@usableFromInline
410+
func start(with continuation: IdentifiableContinuation<T, Failure>) {
411+
let task = Task {
412+
await body(continuation)
413+
let performCancel = withCriticalRegion {
414+
$0.didCompleteBody = true
415+
if $0.isCancelled && !$0.didCancel {
416+
$0.didCancel = true
417+
return true
418+
} else {
419+
return false
420+
}
421+
}
422+
if performCancel {
423+
await onCancel(continuation.id)
424+
}
425+
}
426+
withCriticalRegion {
427+
$0.bodyTask = task
428+
}
429+
}
430+
431+
@usableFromInline
432+
func cancel(withID id: IdentifiableContinuation<T, Failure>.ID) {
433+
let task = Task {
434+
let performCancel = withCriticalRegion {
435+
$0.isCancelled = true
436+
$0.bodyTask?.cancel()
437+
if $0.didCompleteBody && !$0.didCancel {
438+
$0.didCancel = true
439+
return true
440+
} else {
441+
return false
442+
}
443+
}
444+
if performCancel {
445+
await onCancel(id)
446+
}
447+
}
448+
withCriticalRegion {
449+
$0.cancelTask = task
450+
}
451+
}
452+
453+
@usableFromInline
454+
func waitForTasks() async {
455+
let (bodyTask, cancelTask) = withCriticalRegion {
456+
($0.bodyTask, $0.cancelTask)
457+
}
458+
_ = await (bodyTask?.value, cancelTask?.value)
459+
}
460+
461+
@usableFromInline
462+
func withCriticalRegion<R>(_ critical: (inout State) throws -> R) rethrows -> R {
463+
lock.lock()
464+
defer { lock.unlock() }
465+
return try critical(&state)
466+
}
467+
}
468+
264469
@usableFromInline
265470
func withoutActuallyEscaping<T, Failure: Error, U>(
266471
_ c1: (IdentifiableContinuation<T, Failure>) -> Void,
@@ -273,3 +478,16 @@ func withoutActuallyEscaping<T, Failure: Error, U>(
273478
}
274479
}
275480
}
481+
482+
@usableFromInline
483+
func withoutActuallyEscaping<T, Failure: Error, U>(
484+
_ c1: (IdentifiableContinuation<T, Failure>) async -> Void,
485+
_ c2: (IdentifiableContinuation<T, Failure>.ID) async -> Void,
486+
result: U.Type,
487+
do body: (@escaping (IdentifiableContinuation<T, Failure>) async -> Void, @escaping (IdentifiableContinuation<T, Failure>.ID) async -> Void) async throws -> U) async rethrows -> U {
488+
try await withoutActuallyEscaping(c1) { (escapingC1) -> U in
489+
try await withoutActuallyEscaping(c2) { (escapingC2) -> U in
490+
try await body(escapingC1, escapingC2)
491+
}
492+
}
493+
}

0 commit comments

Comments
 (0)