Skip to content

Commit 2c49700

Browse files
authored
Merge pull request #64054 from etcwilde/ewilde/startOnMainActor
Concurrency: swift_task_startOnMainActor
2 parents e969ff7 + f8e1272 commit 2c49700

File tree

6 files changed

+156
-0
lines changed

6 files changed

+156
-0
lines changed

include/swift/Runtime/Concurrency.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,9 @@ void swift_task_reportUnexpectedExecutor(
899899
SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift)
900900
JobPriority swift_task_getCurrentThreadPriority(void);
901901

902+
SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift)
903+
void swift_task_startOnMainActor(AsyncTask* job);
904+
902905
#if SWIFT_CONCURRENCY_COOPERATIVE_GLOBAL_EXECUTOR
903906

904907
/// Donate this thread to the global executor until either the

stdlib/public/CompatibilityOverride/CompatibilityOverrideConcurrency.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,10 @@ OVERRIDE_TASK_STATUS(task_escalate, JobPriority,
382382
swift::, (AsyncTask *task, JobPriority newPriority),
383383
(task, newPriority))
384384

385+
OVERRIDE_TASK(task_startOnMainActor, void,
386+
SWIFT_EXPORT_FROM(swift_Concurrency), SWIFT_CC(swift),
387+
swift::, (AsyncTask *task), (task))
388+
385389
#undef OVERRIDE
386390
#undef OVERRIDE_ACTOR
387391
#undef OVERRIDE_TASK

stdlib/public/Concurrency/Task.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1594,6 +1594,17 @@ void (*swift::swift_task_asyncMainDrainQueue_hook)(
15941594
swift_task_asyncMainDrainQueue_original original,
15951595
swift_task_asyncMainDrainQueue_override compatOverride) = nullptr;
15961596

1597+
SWIFT_CC(swift)
1598+
static void swift_task_startOnMainActorImpl(AsyncTask* task) {
1599+
AsyncTask * originalTask = _swift_task_clearCurrent();
1600+
ExecutorRef mainExecutor = swift_task_getMainExecutor();
1601+
if (swift_task_getCurrentExecutor() != swift_task_getMainExecutor())
1602+
swift_Concurrency_fatalError(0, "Not on the main executor");
1603+
swift_retain(task);
1604+
swift_job_run(task, mainExecutor);
1605+
_swift_task_setCurrent(originalTask);
1606+
}
1607+
15971608
#define OVERRIDE_TASK COMPATIBILITY_OVERRIDE
15981609

15991610
#ifdef SWIFT_STDLIB_SUPPORT_BACK_DEPLOYMENT

stdlib/public/Concurrency/Task.swift

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,44 @@ extension Task: Equatable {
186186
}
187187
}
188188

189+
@available(SwiftStdlib 5.9, *)
190+
extension Task where Failure == Error {
191+
@_spi(MainActorUtilities)
192+
@MainActor
193+
@available(SwiftStdlib 5.9, *)
194+
public static func startOnMainActor(
195+
priority: TaskPriority? = nil,
196+
@_inheritActorContext @_implicitSelfCapture _ work: __owned @Sendable @escaping @MainActor() async throws -> Success
197+
) -> Task<Success, Error> {
198+
let flags = taskCreateFlags(priority: priority, isChildTask: false,
199+
copyTaskLocals: true, inheritContext: true,
200+
enqueueJob: false,
201+
addPendingGroupTaskUnconditionally: false)
202+
let (task, _) = Builtin.createAsyncTask(flags, work)
203+
_startTaskOnMainActor(task)
204+
return Task<Success, Error>(task)
205+
}
206+
}
207+
208+
@available(SwiftStdlib 5.9, *)
209+
extension Task where Failure == Never {
210+
@_spi(MainActorUtilities)
211+
@MainActor
212+
@available(SwiftStdlib 5.9, *)
213+
public static func startOnMainActor(
214+
priority: TaskPriority? = nil,
215+
@_inheritActorContext @_implicitSelfCapture _ work: __owned @Sendable @escaping @MainActor() async -> Success
216+
) -> Task<Success, Never> {
217+
let flags = taskCreateFlags(priority: priority, isChildTask: false,
218+
copyTaskLocals: true, inheritContext: true,
219+
enqueueJob: false,
220+
addPendingGroupTaskUnconditionally: false)
221+
let (task, _) = Builtin.createAsyncTask(flags, work)
222+
_startTaskOnMainActor(task)
223+
return Task(task)
224+
}
225+
}
226+
189227
// ==== Task Priority ----------------------------------------------------------
190228

191229
/// The priority of a task.
@@ -880,6 +918,9 @@ extension UnsafeCurrentTask: Equatable {
880918
@_silgen_name("swift_task_getCurrent")
881919
func _getCurrentAsyncTask() -> Builtin.NativeObject?
882920

921+
@_silgen_name("swift_task_startOnMainActor")
922+
fileprivate func _startTaskOnMainActor(_ task: Builtin.NativeObject) -> Builtin.NativeObject?
923+
883924
@available(SwiftStdlib 5.1, *)
884925
@_silgen_name("swift_task_getJobFlags")
885926
func getJobFlags(_ task: Builtin.NativeObject) -> JobFlags
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-build-swift -Xfrontend -disable-availability-checking %s -o %t/a.out
3+
// RUN: %target-codesign %t/a.out
4+
// RUN: %target-run %t/a.out
5+
6+
// REQUIRES: executable_test
7+
// REQUIRES: concurrency
8+
// REQUIRES: concurrency_runtime
9+
// UNSUPPORTED: back_deployment_runtime
10+
// UNSUPPORTED: back_deploy_concurrency
11+
// UNSUPPORTED: use_os_stdlib
12+
// UNSUPPORTED: freestanding
13+
14+
import StdlibUnittest
15+
@_spi(MainActorUtilities) import _Concurrency
16+
17+
func doStuffAsync() async {
18+
await Task.sleep(500)
19+
}
20+
21+
let tests = TestSuite("StartOnMainActor")
22+
23+
tests.test("startOnMainActor") {
24+
// "global" variables for this test
25+
struct Globals {
26+
@MainActor
27+
static var ran = false
28+
}
29+
30+
@MainActor
31+
func run() async {
32+
Globals.ran = true
33+
await doStuffAsync()
34+
}
35+
36+
// enqueue item on the MainActor
37+
let t1 = Task { @MainActor in
38+
await Task.sleep(1000)
39+
}
40+
41+
expectFalse(Globals.ran)
42+
43+
// Run something with side-effects on the main actor
44+
let t2 = Task.startOnMainActor {
45+
return await run()
46+
}
47+
48+
expectTrue(Globals.ran)
49+
await t1.value
50+
await t2.value
51+
}
52+
53+
tests.test("throwing startOnMainActor") {
54+
// "global" variables for this test
55+
struct Globals {
56+
@MainActor
57+
static var ran = false
58+
}
59+
60+
struct StringError: Error {
61+
let message: String
62+
}
63+
64+
@MainActor
65+
func run() async throws {
66+
Globals.ran = true
67+
await doStuffAsync()
68+
throw StringError(message: "kablamo!")
69+
}
70+
71+
// enqueue item on the MainActor
72+
let t1 = Task { @MainActor in
73+
await Task.sleep(1000)
74+
}
75+
76+
expectFalse(Globals.ran)
77+
78+
// Run something with side-effects on the main actor
79+
let t2 = Task.startOnMainActor {
80+
return try await run()
81+
}
82+
83+
expectTrue(Globals.ran)
84+
await t1.value
85+
expectNil(try? await t2.value)
86+
}
87+
88+
await runAllTestsAsync()

unittests/runtime/CompatibilityOverrideConcurrency.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ static void swift_task_enqueueMainExecutor_override(
9797
Ran = true;
9898
}
9999

100+
SWIFT_CC(swift)
101+
static void swift_task_startOnMainActor_override(AsyncTask* task) {
102+
Ran = true;
103+
}
104+
100105
#ifdef RUN_ASYNC_MAIN_DRAIN_QUEUE_TEST
101106
[[noreturn]] SWIFT_CC(swift)
102107
static void swift_task_asyncMainDrainQueue_override_fn(
@@ -284,6 +289,10 @@ TEST_F(CompatibilityOverrideConcurrencyTest, test_swift_task_escalate) {
284289
swift_task_escalate(nullptr, {});
285290
}
286291

292+
TEST_F(CompatibilityOverrideConcurrencyTest, test_swift_startOnMainActorImpl) {
293+
swift_task_startOnMainActor(nullptr);
294+
}
295+
287296
#if RUN_ASYNC_MAIN_DRAIN_QUEUE_TEST
288297
TEST_F(CompatibilityOverrideConcurrencyTest, test_swift_task_asyncMainDrainQueue) {
289298

0 commit comments

Comments
 (0)