Skip to content

Commit 4b92467

Browse files
committed
[Concurrency] Use a single atomic for future wait queue.
Use a single atomic for the wait queue that combines the status with the first task in the queue. Address race conditions in waiting and completing the future. Thanks to John for setting the direction here for me.
1 parent 0a07f18 commit 4b92467

File tree

3 files changed

+86
-57
lines changed

3 files changed

+86
-57
lines changed

include/swift/ABI/Task.h

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,34 @@ class AsyncTask : public HeapObject, public Job {
261261
Error,
262262
};
263263

264-
private:
265-
/// Status of the future.
266-
std::atomic<Status> status;
264+
/// An item within the wait queue, which includes the status and the
265+
/// head of the list of tasks.
266+
struct WaitQueueItem {
267+
/// Mask used for the low status bits in a wait queue item.
268+
const uintptr_t statusMask = 0x03;
269+
270+
uintptr_t storage;
271+
272+
Status getStatus() const {
273+
return static_cast<Status>(storage & statusMask);
274+
}
275+
276+
AsyncTask *getTask() const {
277+
return reinterpret_cast<AsyncTask *>(storage & ~statusMask);
278+
}
279+
280+
static WaitQueueItem get(Status status, AsyncTask *task) {
281+
return WaitQueueItem{
282+
reinterpret_cast<uintptr_t>(task) | static_cast<uintptr_t>(status)};
283+
}
284+
};
267285

286+
private:
268287
/// Queue containing all of the tasks that are waiting in `get()`.
269-
std::atomic<AsyncTask*> waitQueue;
288+
///
289+
/// The low bits contain the status, the rest of the pointer is the
290+
/// AsyncTask.
291+
std::atomic<WaitQueueItem> waitQueue;
270292

271293
/// The type of the result that will be produced by the future.
272294
const Metadata *resultType;
@@ -285,8 +307,9 @@ class AsyncTask : public HeapObject, public Job {
285307
public:
286308
FutureFragment(
287309
const Metadata *resultType, size_t resultOffset, size_t errorOffset)
288-
: status(Status::Success), waitQueue(nullptr), resultType(resultType),
289-
resultOffset(resultOffset), errorOffset(errorOffset) { }
310+
: waitQueue(WaitQueueItem::get(Status::Executing, nullptr)),
311+
resultType(resultType), resultOffset(resultOffset),
312+
errorOffset(errorOffset) { }
290313

291314
/// Destroy the storage associated with the future.
292315
void destroy();
@@ -331,10 +354,10 @@ class AsyncTask : public HeapObject, public Job {
331354
FutureFragment::Status waitFuture(AsyncTask *waitingTask);
332355

333356
/// Complete this future.
334-
void completeFuture(AsyncContext *context);
335-
336-
/// Schedule waiting tasks now that the future has completed.
337-
void scheduleWaitingTasks(ExecutorRef executor);
357+
///
358+
/// Upon completion, any waiting tasks will be scheduled on the given
359+
/// executor.
360+
void completeFuture(AsyncContext *context, ExecutorRef executor);
338361

339362
static bool classof(const Job *job) {
340363
return job->isAsyncTask();

include/swift/Runtime/Concurrency.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ AsyncTaskAndContext swift_task_create_future(
6868
const AsyncFunctionPointer<void()> *function,
6969
size_t resultOffset, size_t errorOffset);
7070

71-
/// Create a task object with no future which will run the given
71+
/// Create a task object with a future which will run the given
7272
/// function.
7373
SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift)
7474
AsyncTaskAndContext swift_task_create_future_f(

stdlib/public/Concurrency/Task.cpp

Lines changed: 52 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ size_t FutureFragment::fragmentSize(const Metadata *resultType) {
2828
}
2929

3030
void FutureFragment::destroy() {
31-
switch (status.load()) {
31+
auto queueHead = waitQueue.load(std::memory_order_acquire);
32+
switch (queueHead.getStatus()) {
3233
case Status::Executing:
3334
assert(false && "destroying a task that never completed");
3435

@@ -43,41 +44,49 @@ void FutureFragment::destroy() {
4344
}
4445

4546
FutureFragment::Status AsyncTask::waitFuture(AsyncTask *waitingTask) {
47+
using Status = FutureFragment::Status;
48+
using WaitQueueItem = FutureFragment::WaitQueueItem;
49+
4650
assert(isFuture());
4751
auto fragment = futureFragment();
4852

49-
auto currentStatus = fragment->status.load();
50-
switch (currentStatus) {
51-
case FutureFragment::Status::Error:
52-
case FutureFragment::Status::Success:
53-
// The task is done; we don't need to wait.
54-
return currentStatus;
55-
56-
case FutureFragment::Status::Executing:
57-
break;
58-
}
59-
60-
// Put the waiting task at the beginning of the wait queue.
53+
auto queueHead = fragment->waitQueue.load(std::memory_order_acquire);
6154
while (true) {
62-
waitingTask->NextWaitingTask = fragment->waitQueue.load();
63-
if (fragment->waitQueue.compare_exchange_strong(
64-
waitingTask->NextWaitingTask, waitingTask)) {
55+
switch (queueHead.getStatus()) {
56+
case Status::Error:
57+
case Status::Success:
58+
// The task is done; we don't need to wait.
59+
return queueHead.getStatus();
60+
61+
case Status::Executing:
62+
// Task is now complete. We'll need to add ourselves to the queue.
63+
break;
64+
}
65+
66+
// Put the waiting task at the beginning of the wait queue.
67+
waitingTask->NextWaitingTask = queueHead.getTask();
68+
auto newQueueHead = WaitQueueItem::get(Status::Executing, waitingTask);
69+
if (fragment->waitQueue.compare_exchange_weak(
70+
queueHead, newQueueHead, std::memory_order_release,
71+
std::memory_order_acquire)) {
6572
// Escalate the priority of this task based on the priority
6673
// of the waiting task.
6774
swift_task_escalate(this, waitingTask->Flags.getPriority());
68-
break;
75+
return FutureFragment::Status::Executing;
6976
}
7077
}
71-
72-
return FutureFragment::Status::Executing;
7378
}
7479

75-
void AsyncTask::completeFuture(AsyncContext *context) {
80+
void AsyncTask::completeFuture(AsyncContext *context, ExecutorRef executor) {
81+
using Status = FutureFragment::Status;
82+
using WaitQueueItem = FutureFragment::WaitQueueItem;
83+
7684
assert(isFuture());
7785
auto fragment = futureFragment();
7886
auto storagePtr = fragment->getStoragePtr();
7987

8088
// Check for an error.
89+
bool hadErrorResult = false;
8190
if (unsigned errorOffset = fragment->errorOffset) {
8291
// Find the error object in the context.
8392
auto errorPtrPtr = reinterpret_cast<char *>(context) + errorOffset;
@@ -86,42 +95,42 @@ void AsyncTask::completeFuture(AsyncContext *context) {
8695
// If there is an error, take it and we're done.
8796
if (errorObject) {
8897
*reinterpret_cast<OpaqueValue **>(storagePtr) = errorObject;
89-
fragment->status = FutureFragment::Status::Error;
90-
return;
98+
hadErrorResult = true;
9199
}
92100
}
93101

94-
// Take the success value.
95-
auto resultPtr = reinterpret_cast<OpaqueValue *>(
96-
reinterpret_cast<char *>(context) + fragment->resultOffset);
97-
fragment->resultType->vw_initializeWithTake(storagePtr, resultPtr);
98-
fragment->status = FutureFragment::Status::Success;
99-
}
100-
101-
void AsyncTask::scheduleWaitingTasks(ExecutorRef executor) {
102-
assert(isFuture());
103-
auto fragment = futureFragment();
104-
105-
auto waitingTask = fragment->waitQueue.load();
106-
107-
// Unnecessary, but useful for the assertion at the end.
108-
fragment->waitQueue = nullptr;
102+
if (!hadErrorResult) {
103+
// Take the success value.
104+
auto resultPtr = reinterpret_cast<OpaqueValue *>(
105+
reinterpret_cast<char *>(context) + fragment->resultOffset);
106+
fragment->resultType->vw_initializeWithTake(storagePtr, resultPtr);
107+
}
109108

109+
// Update the status to signal completion.
110+
auto newQueueHead = WaitQueueItem::get(
111+
hadErrorResult ? Status::Error : Status::Success,
112+
nullptr
113+
);
114+
auto queueHead = fragment->waitQueue.exchange(
115+
newQueueHead, std::memory_order_acquire);
116+
assert(queueHead.getStatus() == Status::Executing);
117+
118+
// Notify every
119+
auto waitingTask = queueHead.getTask();
110120
while (waitingTask) {
111121
// Find the next waiting task.
112122
auto nextWaitingTask = waitingTask->NextWaitingTask;
113123

114124
// Remove this task from the list.
115125
waitingTask->NextWaitingTask = nullptr;
116126

117-
// TODO: schedule this task on the executor.
127+
// TODO: schedule this task on the executor rather than running it
128+
// directly.
129+
waitingTask->run(executor);
118130

119131
// Move to the next task.
120132
waitingTask = nextWaitingTask;
121133
}
122-
123-
// Nobody should be able to add anything to the queue in this function.
124-
assert(!fragment->waitQueue.load());
125134
}
126135

127136
SWIFT_CC(swift)
@@ -171,12 +180,9 @@ static void completeTask(AsyncTask *task, ExecutorRef executor,
171180
// to wait for the object to be destroyed.
172181
_swift_task_alloc_destroy(task);
173182

183+
// Complete the future.
174184
if (task->isFuture()) {
175-
// Complete the future, taking the result from the context.
176-
task->completeFuture(context);
177-
178-
// Schedule tasks that are waiting on the future.
179-
task->scheduleWaitingTasks(executor);
185+
task->completeFuture(context, executor);
180186
}
181187

182188
// TODO: set something in the status?

0 commit comments

Comments
 (0)