Skip to content

Commit 655bd67

Browse files
committed
[Concurrency] Task priority escalation handler API
1 parent 18c2584 commit 655bd67

File tree

6 files changed

+151
-6
lines changed

6 files changed

+151
-6
lines changed

include/swift/ABI/TaskStatus.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,9 @@ class CancellationNotificationStatusRecord : public TaskStatusRecord {
242242
: TaskStatusRecord(TaskStatusRecordKind::CancellationNotification),
243243
Function(fn), Argument(arg) {}
244244

245-
void run() { Function(Argument); }
245+
void run() {
246+
Function(Argument);
247+
}
246248

247249
static bool classof(const TaskStatusRecord *record) {
248250
return record->getKind() == TaskStatusRecordKind::CancellationNotification;
@@ -259,7 +261,7 @@ class CancellationNotificationStatusRecord : public TaskStatusRecord {
259261
/// subsequently used.
260262
class EscalationNotificationStatusRecord : public TaskStatusRecord {
261263
public:
262-
using FunctionType = void(void *, JobPriority);
264+
using FunctionType = SWIFT_CC(swift) void(JobPriority, SWIFT_CONTEXT void *);
263265

264266
private:
265267
FunctionType *__ptrauth_swift_escalation_notification_function Function;
@@ -268,9 +270,12 @@ class EscalationNotificationStatusRecord : public TaskStatusRecord {
268270
public:
269271
EscalationNotificationStatusRecord(FunctionType *fn, void *arg)
270272
: TaskStatusRecord(TaskStatusRecordKind::EscalationNotification),
271-
Function(fn), Argument(arg) {}
273+
Function(fn), Argument(arg) {
274+
}
272275

273-
void run(JobPriority newPriority) { Function(Argument, newPriority); }
276+
void run(JobPriority newPriority) {
277+
Function(newPriority, Argument);
278+
}
274279

275280
static bool classof(const TaskStatusRecord *record) {
276281
return record->getKind() == TaskStatusRecordKind::EscalationNotification;

include/swift/Runtime/Concurrency.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,18 @@ SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift)
609609
void swift_task_removeCancellationHandler(
610610
CancellationNotificationStatusRecord *record);
611611

612+
/// Create and add an escalation record to the task.
613+
SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift)
614+
EscalationNotificationStatusRecord*
615+
swift_task_addEscalationHandler(
616+
EscalationNotificationStatusRecord::FunctionType handler,
617+
void *handlerContext);
618+
619+
/// Remove the passed cancellation record from the task.
620+
SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift)
621+
void swift_task_removeEscalationHandler(
622+
EscalationNotificationStatusRecord *record);
623+
612624
/// Create a NullaryContinuationJob from a continuation.
613625
SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift)
614626
NullaryContinuationJob*

stdlib/public/CompatibilityOverride/CompatibilityOverrideConcurrency.def

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,17 @@ OVERRIDE_TASK(task_removeCancellationHandler, void,
206206
SWIFT_EXPORT_FROM(swift_Concurrency), SWIFT_CC(swift), swift::,
207207
(CancellationNotificationStatusRecord *record), (record))
208208

209+
OVERRIDE_TASK(task_addEscalationHandler,
210+
EscalationNotificationStatusRecord *,
211+
SWIFT_EXPORT_FROM(swift_Concurrency), SWIFT_CC(swift), swift::,
212+
(EscalationNotificationStatusRecord::FunctionType handler,
213+
void *context),
214+
(handler, context))
215+
216+
OVERRIDE_TASK(task_removeEscalationHandler, void,
217+
SWIFT_EXPORT_FROM(swift_Concurrency), SWIFT_CC(swift), swift::,
218+
(EscalationNotificationStatusRecord *record), (record))
219+
209220
OVERRIDE_TASK(task_createNullaryContinuationJob, NullaryContinuationJob *,
210221
SWIFT_EXPORT_FROM(swift_Concurrency), SWIFT_CC(swift), swift::,
211222
(size_t priority,

stdlib/public/Concurrency/Task.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1766,6 +1766,31 @@ static void swift_task_removeCancellationHandlerImpl(
17661766
swift_task_dealloc(record);
17671767
}
17681768

1769+
SWIFT_CC(swift)
1770+
static EscalationNotificationStatusRecord*
1771+
swift_task_addEscalationHandlerImpl(
1772+
EscalationNotificationStatusRecord::FunctionType handler,
1773+
void *context) {
1774+
void *allocation =
1775+
swift_task_alloc(sizeof(EscalationNotificationStatusRecord));
1776+
auto unsigned_handler = swift_auth_code(handler, 3848); // FIXME: fix this number for ptrauth
1777+
auto *record = ::new (allocation)
1778+
EscalationNotificationStatusRecord(handler, context);
1779+
1780+
addStatusRecordToSelf(record, [&](ActiveTaskStatus oldStatus, ActiveTaskStatus& newStatus) {
1781+
return true;
1782+
});
1783+
1784+
return record;
1785+
}
1786+
1787+
SWIFT_CC(swift)
1788+
static void swift_task_removeEscalationHandlerImpl(
1789+
EscalationNotificationStatusRecord *record) {
1790+
removeStatusRecordFromSelf(record);
1791+
swift_task_dealloc(record);
1792+
}
1793+
17691794
SWIFT_CC(swift)
17701795
static NullaryContinuationJob*
17711796
swift_task_createNullaryContinuationJobImpl(

stdlib/public/Concurrency/Task.swift

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,7 @@ extension Task {
929929
/// - task: the task which to escalate the priority of
930930
/// - newPriority: the new priority the task should continue executing on
931931
@available(SwiftStdlib 9999, *)
932-
public static func escalatePriority(_ task: UnsafeCurrentTask, to newPriority: TaskPriority) {
932+
public static func escalatePriority(_ task: Task, to newPriority: TaskPriority) {
933933
_taskEscalate(task._task, newPriority: newPriority.rawValue)
934934
}
935935

@@ -964,6 +964,56 @@ extension Task {
964964
}
965965
}
966966

967+
/// Runs the passed `operation` while registering a task priority escalation handler.
968+
/// The handler will be triggered concurrently to the current task if the current
969+
/// is subject to priority escalation.
970+
///
971+
/// The handler may perform additional actions upon priority escalation,
972+
/// but cannot influence how the escalation influences the task, i.e. the task's
973+
/// priority will be escalated regardless of actions performed in the handler.
974+
///
975+
/// If multiple task escalation handlers are nester they will all be triggered.
976+
///
977+
/// Task escalation propagates through structured concurrency child-tasks.
978+
///
979+
/// - Parameters:
980+
/// - operation: the operation during which to listen for priority escalation
981+
/// - handler: handler to invoke, concurrently to `operation`,
982+
/// when priority escalation happens
983+
/// - Returns: the value returned by `operation`
984+
/// - Throws: when the `operation` throws an error
985+
@available(SwiftStdlib 9999, *)
986+
public func withTaskPriorityEscalationHandler<T, E>(
987+
operation: () async throws(E) -> T,
988+
onEscalate handler: @Sendable (TaskPriority) -> Void,
989+
isolation: isolated (any Actor)? = #isolation
990+
) async throws(E) -> T {
991+
// NOTE: We have to create the closure beforehand as otherwise it seems
992+
// the task-local allocator may be used and we end up violating stack-discipline
993+
// when releasing the handler closure vs. the record.
994+
let handler0: (UInt8) -> Void = {
995+
handler(TaskPriority(rawValue: $0))
996+
}
997+
let record = _taskAddEscalationHandler(handler: handler0)
998+
defer { _taskRemoveEscalationHandler(record: record) }
999+
1000+
return try await operation()
1001+
}
1002+
1003+
@usableFromInline
1004+
@available(SwiftStdlib 9999, *)
1005+
@_silgen_name("swift_task_addEscalationHandler")
1006+
func _taskAddEscalationHandler(
1007+
handler: (UInt8) -> Void
1008+
) -> UnsafeRawPointer /*EscalationNotificationStatusRecord*/
1009+
1010+
@usableFromInline
1011+
@available(SwiftStdlib 9999, *)
1012+
@_silgen_name("swift_task_removeEscalationHandler")
1013+
func _taskRemoveEscalationHandler(
1014+
record: UnsafeRawPointer /*EscalationNotificationStatusRecord*/
1015+
)
1016+
9671017
// ==== UnsafeCurrentTask ------------------------------------------------------
9681018

9691019
/// Calls a closure with an unsafe reference to the current task.

test/Concurrency/async_task_escalate_priority.swift

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,13 @@ import Darwin
2626
@preconcurrency import Dispatch
2727
import StdlibUnittest
2828

29+
func TESTTEST_taskAddCancellationHandler(handler: () -> Void) {
30+
handler()
31+
}
32+
func TESTTEST_taskAddEscalationHandler(handler: (TaskPriority) -> Void, t: TaskPriority) { handler(t) }
33+
2934
func loopUntil(priority: TaskPriority) async {
30-
var loops = 100
35+
var loops = 10
3136
var currentPriority = Task.currentPriority
3237
while (currentPriority != priority) {
3338
print("Current priority = \(currentPriority) != \(priority)")
@@ -99,6 +104,43 @@ func testNestedTaskPriority(basePri: TaskPriority, curPri: TaskPriority) async {
99104
Task.escalatePriority(task, to: .default)
100105
sem2.wait()
101106
}
107+
108+
tests.test("Trigger task escalation handler") {
109+
let sem1 = DispatchSemaphore(value: 0)
110+
let sem2 = DispatchSemaphore(value: 0)
111+
let semEscalated = DispatchSemaphore(value: 0)
112+
113+
let task = Task(priority: .background) {
114+
let _ = expectedBasePri(priority: .background)
115+
116+
await withTaskPriorityEscalationHandler {
117+
print("in withTaskPriorityEscalationHandler, Task.currentPriority = \(Task.currentPriority)")
118+
119+
// Wait until task is running before asking to be escalated
120+
sem1.signal()
121+
sleep(1)
122+
123+
await loopUntil(priority: .default)
124+
print("in withTaskPriorityEscalationHandler, after loop, Task.currentPriority = \(Task.currentPriority)")
125+
} onEscalate: { newPriority in
126+
print("in onEscalate Task.currentPriority = \(Task.currentPriority)")
127+
print("in onEscalate newPriority = \(newPriority)")
128+
precondition(newPriority == .default)
129+
semEscalated.signal()
130+
}
131+
132+
print("Current priority = \(Task.currentPriority)")
133+
print("after withTaskPriorityEscalationHandler")
134+
sem2.signal()
135+
}
136+
137+
// Wait till child runs and asks to be escalated
138+
sem1.wait()
139+
task.cancel() // just checking the records don't stomp onto each other somehow
140+
Task.escalatePriority(task, to: .default)
141+
semEscalated.wait()
142+
sem2.wait()
143+
}
102144
}
103145

104146
await runAllTestsAsync()

0 commit comments

Comments
 (0)