Skip to content

Commit 009f64b

Browse files
committed
Must not modify waitingTask context outside lock
1 parent a747dc6 commit 009f64b

File tree

1 file changed

+50
-39
lines changed

1 file changed

+50
-39
lines changed

stdlib/public/Concurrency/TaskGroup.cpp

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,9 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord {
366366
///
367367
/// \param bodyError error thrown by the body of a with...TaskGroup method
368368
/// \param waitingTask the task waiting on the group
369+
/// \param rawContext used to resume the waiting task
369370
/// \return how the waiting task should be handled, e.g. must wait or can be completed immediately
370-
PollResult waitAll(SwiftError* bodyError, AsyncTask *waitingTask);
371+
PollResult waitAll(SwiftError* bodyError, AsyncTask *waitingTask, AsyncContext* rawContext);
371372

372373
// Enqueue the completed task onto ready queue if there are no waiting tasks yet
373374
virtual void enqueueCompletedTask(AsyncTask *completedTask, bool hadErrorResult) = 0;
@@ -378,6 +379,7 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord {
378379
// ==== Status manipulation -------------------------------------------------
379380

380381
TaskGroupStatus statusLoadRelaxed() const;
382+
TaskGroupStatus statusLoadAcquire() const;
381383

382384
std::string statusString() const;
383385

@@ -409,6 +411,10 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord {
409411
/// Remove waiting status bit.
410412
TaskGroupStatus statusRemoveWaitingRelease();
411413

414+
/// Mark the waiting status bit.
415+
/// A waiting task MUST have been already enqueued in the `waitQueue`.
416+
TaskGroupStatus statusMarkWaitingAssumeRelease();
417+
412418
/// Cancels the group and returns true if was already cancelled before.
413419
/// After this function returns, the group is guaranteed to be cancelled.
414420
///
@@ -521,7 +527,7 @@ struct TaskGroupStatus {
521527
/// TaskGroupStatus{ C:{cancelled} W:{waiting task} R:{ready tasks} P:{pending tasks} {binary repr} }
522528
/// If discarding results:
523529
/// TaskGroupStatus{ C:{cancelled} W:{waiting task} P:{pending tasks} {binary repr} }
524-
std::string to_string(const TaskGroupBase* _Nonnull group) {
530+
std::string to_string(const TaskGroupBase* group) {
525531
std::string str;
526532
str.append("TaskGroupStatus{ ");
527533
str.append("C:"); // cancelled
@@ -548,7 +554,7 @@ struct TaskGroupStatus {
548554
bool TaskGroupBase::statusCompletePendingReadyWaiting(TaskGroupStatus &old) {
549555
return status.compare_exchange_strong(
550556
old.status, old.completingPendingReadyWaiting(this).status,
551-
/*success*/ std::memory_order_relaxed,
557+
/*success*/ std::memory_order_release,
552558
/*failure*/ std::memory_order_relaxed);
553559
}
554560

@@ -561,6 +567,10 @@ TaskGroupStatus TaskGroupBase::statusLoadRelaxed() const {
561567
return TaskGroupStatus{status.load(std::memory_order_relaxed)};
562568
}
563569

570+
TaskGroupStatus TaskGroupBase::statusLoadAcquire() const {
571+
return TaskGroupStatus{status.load(std::memory_order_acquire)};
572+
}
573+
564574
std::string TaskGroupBase::statusString() const {
565575
return statusLoadRelaxed().to_string(this);
566576
}
@@ -580,6 +590,12 @@ TaskGroupStatus TaskGroupBase::statusMarkWaitingAssumeAcquire() {
580590
return TaskGroupStatus{old | TaskGroupStatus::waiting};
581591
}
582592

593+
TaskGroupStatus TaskGroupBase::statusMarkWaitingAssumeRelease() {
594+
auto old = status.fetch_or(TaskGroupStatus::waiting,
595+
std::memory_order_release);
596+
return TaskGroupStatus{old | TaskGroupStatus::waiting};
597+
}
598+
583599
TaskGroupStatus TaskGroupBase::statusRemoveWaitingRelease() {
584600
auto old = status.fetch_and(~TaskGroupStatus::waiting,
585601
std::memory_order_release);
@@ -709,18 +725,6 @@ class DiscardingTaskGroup: public TaskGroupBase {
709725
return true;
710726
}
711727

712-
/// Returns *assumed* new status, including the just performed +1.
713-
TaskGroupStatus statusMarkWaitingAssumeAcquire() {
714-
auto old = status.fetch_or(TaskGroupStatus::waiting, std::memory_order_acquire);
715-
return TaskGroupStatus{old | TaskGroupStatus::waiting};
716-
}
717-
718-
TaskGroupStatus statusRemoveWaitingRelease() {
719-
auto old = status.fetch_and(~TaskGroupStatus::waiting,
720-
std::memory_order_release);
721-
return TaskGroupStatus{old};
722-
}
723-
724728
/// Returns *assumed* new status.
725729
TaskGroupStatus statusAddReadyAssumeAcquire(const DiscardingTaskGroup *group) {
726730
assert(group->isDiscardingResults());
@@ -1152,7 +1156,7 @@ void AccumulatingTaskGroup::offer(AsyncTask *completedTask, AsyncContext *contex
11521156
hadErrorResult = true;
11531157
}
11541158

1155-
SWIFT_TASK_GROUP_DEBUG_LOG(this, "ready: %d, pending: %u",
1159+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "ready: %d, pending: %llu",
11561160
assumed.readyTasks(this), assumed.pendingTasks(this));
11571161

11581162
// ==== a) has waiting task, so let us complete it right away
@@ -1205,13 +1209,14 @@ void DiscardingTaskGroup::offer(AsyncTask *completedTask, AsyncContext *context)
12051209

12061210
/// If we're the last task we've been waiting for, and there is a waiting task on the group
12071211
bool lastPendingTaskAndWaitingTask =
1208-
assumed.pendingTasks(this) == 1 && assumed.hasWaitingTask();
1212+
assumed.pendingTasks(this) == 1 &&
1213+
assumed.hasWaitingTask();
12091214

12101215
// Immediately decrement the pending count.
12111216
// We can do this, since in this mode there is no ready count to keep track of,
12121217
// and we immediately discard the result.
1213-
SWIFT_TASK_GROUP_DEBUG_LOG(this, "discard result, hadError:%d, was pending:%llu",
1214-
hadErrorResult, assumed.pendingTasks(this));
1218+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "discard result, hadError:%d, was pending:%llu, status = %s",
1219+
hadErrorResult, assumed.pendingTasks(this), assumed.to_string(this).c_str());
12151220
// If this was the last pending task, and there is a waiting task (from waitAll),
12161221
// we must resume the task; but not otherwise. There cannot be any waiters on next()
12171222
// while we're discarding results.
@@ -1301,6 +1306,8 @@ void TaskGroupBase::resumeWaitingTask(
13011306
if (statusCompletePendingReadyWaiting(assumed)) {
13021307
// Run the task.
13031308
auto result = PollResult::get(completedTask, hadErrorResult);
1309+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "resume waiting DONE, task = %p, complete with = %p, status = %s",
1310+
waitingTask, completedTask, statusString().c_str());
13041311

13051312
// Remove the child from the task group's running tasks list.
13061313
// The parent task isn't currently running (we're about to wake
@@ -1652,11 +1659,9 @@ static void swift_taskGroup_waitAllImpl(
16521659
ThrowingTaskFutureWaitContinuationFunction *resumeFunction,
16531660
AsyncContext *rawContext) {
16541661
auto waitingTask = swift_task_getCurrent();
1655-
waitingTask->ResumeTask = task_group_wait_resume_adapter;
1656-
waitingTask->ResumeContext = rawContext;
16571662

16581663
auto group = asBaseImpl(_group);
1659-
PollResult polled = group->waitAll(bodyError, waitingTask);
1664+
PollResult polled = group->waitAll(bodyError, waitingTask, rawContext);
16601665

16611666
auto context = static_cast<TaskFutureWaitAsyncContext *>(rawContext);
16621667
context->ResumeParent =
@@ -1669,19 +1674,17 @@ static void swift_taskGroup_waitAllImpl(
16691674
waitingTask, bodyError, group->statusString().c_str(), to_string(polled.status).c_str());
16701675

16711676
switch (polled.status) {
1672-
case PollStatus::MustWait:
1673-
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl MustWait, pending tasks exist, waiting task = %p",
1674-
waitingTask);
1677+
case PollStatus::MustWait: {
16751678
// The waiting task has been queued on the channel,
16761679
// there were pending tasks so it will be woken up eventually.
16771680
#ifdef __ARM_ARCH_7K__
1678-
return workaround_function_swift_taskGroup_waitAllImpl(
1681+
workaround_function_swift_taskGroup_waitAllImpl(
16791682
resultPointer, callerContext, _group, bodyError, resumeFunction, rawContext);
1680-
#else /* __ARM_ARCH_7K__ */
1681-
return;
16821683
#endif /* __ARM_ARCH_7K__ */
1684+
return;
1685+
}
16831686

1684-
case PollStatus::Error:
1687+
case PollStatus::Error: {
16851688
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl Error, waiting task = %p, body error = %p, status:%s",
16861689
waitingTask, bodyError, group->statusString().c_str());
16871690
#if SWIFT_TASK_GROUP_BODY_THROWN_ERROR_WINS
@@ -1702,9 +1705,10 @@ static void swift_taskGroup_waitAllImpl(
17021705
}
17031706

17041707
return waitingTask->runInFullyEstablishedContext();
1708+
}
17051709

17061710
case PollStatus::Empty:
1707-
case PollStatus::Success:
1711+
case PollStatus::Success: {
17081712
/// Anything else than a "MustWait" can be treated as a successful poll.
17091713
/// Only if there are in flight pending tasks do we need to wait after all.
17101714
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl %s, waiting task = %p, status:%s",
@@ -1719,14 +1723,17 @@ static void swift_taskGroup_waitAllImpl(
17191723
}
17201724

17211725
return waitingTask->runInFullyEstablishedContext();
1726+
}
17221727
}
17231728
}
17241729

1725-
/// Must be called while holding the `taskGroup.lock`!
1726-
/// This is because the discarding task group still has some follow-up operations that must
1727-
/// be performed atomically after this operation sometimes, so we cannot unlock inside `waitAll` itself.
1728-
PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask) {
1729-
lock(); // TODO: remove group lock, and use status for synchronization
1730+
PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask, AsyncContext *rawContext) {
1731+
lock();
1732+
1733+
// must mutate the waiting task while holding the group lock,
1734+
// so we don't get an offer concurrently trying to do so
1735+
waitingTask->ResumeTask = task_group_wait_resume_adapter;
1736+
waitingTask->ResumeContext = rawContext;
17301737

17311738
SWIFT_TASK_GROUP_DEBUG_LOG(this, "waitAll, bodyError = %p, status = %s", bodyError, statusString().c_str());
17321739
PollResult result = PollResult::getEmpty(this->successType);
@@ -1739,7 +1746,11 @@ PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask)
17391746
bool haveRunOneChildTaskInline = false;
17401747

17411748
reevaluate_if_TaskGroup_has_results:;
1742-
auto assumed = statusMarkWaitingAssumeAcquire();
1749+
// Paired with a release when marking Waiting,
1750+
// otherwise we don't modify the status
1751+
auto assumed = statusLoadAcquire();
1752+
1753+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "waitAll, status = %s", assumed.to_string(this).c_str());
17431754

17441755
// ==== 1) may be able to bail out early if no tasks are pending -------------
17451756
if (assumed.isEmpty(this)) {
@@ -1757,7 +1768,6 @@ PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask)
17571768
result.status = PollStatus::Error;
17581769
}
17591770
} // else, we're definitely Empty
1760-
17611771
unlock();
17621772
return result;
17631773
}
@@ -1766,7 +1776,6 @@ PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask)
17661776
// No tasks in flight, we know no tasks were submitted before this poll
17671777
// was issued, and if we parked here we'd potentially never be woken up.
17681778
// Bail out and return `nil` from `group.next()`.
1769-
statusRemoveWaitingRelease();
17701779
unlock();
17711780
return result;
17721781
}
@@ -1794,7 +1803,9 @@ PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask)
17941803
waitHead, waitingTask,
17951804
/*success*/ std::memory_order_release,
17961805
/*failure*/ std::memory_order_acquire)) {
1797-
unlock(); // TODO: remove fragment lock, and use status for synchronization
1806+
statusMarkWaitingAssumeRelease();
1807+
unlock();
1808+
17981809
#if SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL
17991810
// The logic here is paired with the logic in TaskGroupBase::offer. Once
18001811
// we run the

0 commit comments

Comments
 (0)