Skip to content

Commit a25aa2c

Browse files
committed
fix bug in resuming tasks in offer/resume task in discarding group
1 parent 61a1835 commit a25aa2c

File tree

1 file changed

+48
-37
lines changed

1 file changed

+48
-37
lines changed

stdlib/public/Concurrency/TaskGroup.cpp

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ class DiscardingTaskGroup: public TaskGroupBase {
767767

768768
private:
769769
/// Resume waiting task with specified error
770-
void resumeWaitingTaskWithError(SwiftError *error, TaskGroupStatus &assumed);
770+
void resumeWaitingTaskWithError(SwiftError *error, TaskGroupStatus &assumed, bool alreadyDecremented);
771771
};
772772

773773
} // end anonymous namespace
@@ -1179,6 +1179,7 @@ void DiscardingTaskGroup::offer(AsyncTask *completedTask, AsyncContext *context)
11791179

11801180
auto afterComplete = statusCompletePendingAssumeRelease();
11811181
(void) afterComplete;
1182+
bool alreadyDecrementedStatus = true;
11821183
SWIFT_TASK_GROUP_DEBUG_LOG(this, "offer, complete, status:%s",
11831184
afterComplete.to_string(this).c_str());
11841185

@@ -1196,18 +1197,18 @@ void DiscardingTaskGroup::offer(AsyncTask *completedTask, AsyncContext *context)
11961197
waitQueue.load(std::memory_order_relaxed));
11971198
switch (priorErrorItem.getStatus()) {
11981199
case ReadyStatus::RawError:
1199-
resumeWaitingTaskWithError(priorErrorItem.getRawError(this), assumed);
1200+
resumeWaitingTaskWithError(priorErrorItem.getRawError(this), assumed, alreadyDecrementedStatus);
12001201
break;
12011202
case ReadyStatus::Error:
1202-
resumeWaitingTask(priorErrorItem.getTask(), assumed, /*hadErrorResult=*/true, /*alreadyDecremented=*/true);
1203+
resumeWaitingTask(priorErrorItem.getTask(), assumed, /*hadErrorResult=*/true, alreadyDecrementedStatus);
12031204
break;
12041205
default:
12051206
swift_Concurrency_fatalError(0, "only errors can be stored by a discarding task group, yet it wasn't an error!");
12061207
}
12071208
} else {
12081209
SWIFT_TASK_GROUP_DEBUG_LOG(this, "offer, last pending task, completing with completedTask:%p, completedTask.error:%d, waitingTask:%p",
12091210
completedTask, hadErrorResult, waitQueue.load(std::memory_order_relaxed));
1210-
resumeWaitingTask(completedTask, assumed, /*hadErrorResult=*/hadErrorResult);
1211+
resumeWaitingTask(completedTask, assumed, /*hadErrorResult=*/hadErrorResult, alreadyDecrementedStatus);
12111212
}
12121213
}
12131214

@@ -1224,44 +1225,48 @@ void TaskGroupBase::resumeWaitingTask(
12241225
auto waitingTask = waitQueue.load(std::memory_order_acquire);
12251226
assert(waitingTask && "waitingTask must not be null when attempting to resume it");
12261227
assert(assumed.hasWaitingTask());
1227-
SWIFT_TASK_GROUP_DEBUG_LOG(this, "resume waiting task = %p, error:%d, complete with = %p",
1228-
waitingTask, hadErrorResult, completedTask);
1228+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "resume waiting task = %p, alreadyDecremented:%d, error:%d, complete with = %p",
1229+
waitingTask, alreadyDecremented, hadErrorResult, completedTask);
12291230
while (true) {
1231+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "resumeWaitingTask, attempt CAS, waiting task = %p, waitQueue.head = %p, error:%d, complete with = %p",
1232+
waitingTask, waitQueue.load(std::memory_order_relaxed), hadErrorResult, completedTask);
1233+
12301234
// ==== a) run waiting task directly -------------------------------------
1231-
// assert(assumed.pendingTasks(this) && "offered to group with no pending tasks!");
1232-
// We are the "first" completed task to arrive,
1233-
// and since there is a task waiting we immediately claim and complete it.
1234-
if (waitQueue.compare_exchange_strong(
1235-
waitingTask, nullptr,
1236-
/*success*/ std::memory_order_seq_cst,
1237-
/*failure*/ std::memory_order_seq_cst)) {
1235+
// assert(assumed.pendingTasks(this) && "offered to group with no pending tasks!");
1236+
// We are the "first" completed task to arrive,
1237+
// and since there is a task waiting we immediately claim and complete it.
1238+
if (waitQueue.compare_exchange_strong(
1239+
waitingTask, nullptr,
1240+
/*success*/ std::memory_order_release,
1241+
/*failure*/ std::memory_order_acquire)) {
12381242

12391243
#if SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL
1240-
// In the task-to-thread model, child tasks are always actually
1241-
// run synchronously on the parent task's thread. For task groups
1242-
// specifically, this means that poll() will pick a child task
1243-
// that was added to the group and run it to completion as a
1244-
// subroutine. Therefore, when we enter offer(), we know that
1245-
// the parent task is waiting and we can just return to it.
1246-
1247-
// The task-to-thread logic in poll() currently expects the child
1248-
// task to enqueue itself instead of just filling in the result in
1249-
// the waiting task. This is a little wasteful; there's no reason
1250-
// we can't just have the parent task set itself up as a waiter.
1251-
// But since it's what we're doing, we basically take the same
1252-
// path as we would if there wasn't a waiter.
1253-
enqueueCompletedTask(completedTask, hadErrorResult);
1254-
return;
1244+
// In the task-to-thread model, child tasks are always actually
1245+
// run synchronously on the parent task's thread. For task groups
1246+
// specifically, this means that poll() will pick a child task
1247+
// that was added to the group and run it to completion as a
1248+
// subroutine. Therefore, when we enter offer(), we know that
1249+
// the parent task is waiting and we can just return to it.
1250+
1251+
// The task-to-thread logic in poll() currently expects the child
1252+
// task to enqueue itself instead of just filling in the result in
1253+
// the waiting task. This is a little wasteful; there's no reason
1254+
// we can't just have the parent task set itself up as a waiter.
1255+
// But since it's what we're doing, we basically take the same
1256+
// path as we would if there wasn't a waiter.
1257+
enqueueCompletedTask(completedTask, hadErrorResult);
1258+
return;
12551259

12561260
#else /* SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL */
1257-
fprintf(stderr, "[%s:%d](%s) assumed:%s\n", __FILE_NAME__, __LINE__, __FUNCTION__, assumed.to_string(this).c_str());
1258-
fprintf(stderr, "[%s:%d](%s) had:%s\n", __FILE_NAME__, __LINE__, __FUNCTION__, this->statusString().c_str());
1261+
if (!alreadyDecremented) {
1262+
(void) statusCompletePendingReadyWaiting(assumed);
1263+
}
12591264

1260-
if (alreadyDecremented || statusCompletePendingReadyWaiting(assumed)) {
12611265
// Run the task.
12621266
auto result = PollResult::get(completedTask, hadErrorResult);
1263-
SWIFT_TASK_GROUP_DEBUG_LOG(this, "resume waiting DONE, task = %p, complete with = %p, status = %s",
1264-
waitingTask, completedTask, statusString().c_str());
1267+
SWIFT_TASK_GROUP_DEBUG_LOG(this,
1268+
"resume waiting DONE, task = %p, backup = %p, complete with = %p, status = %s",
1269+
waitingTask, backup, completedTask, statusString().c_str());
12651270

12661271
// Remove the child from the task group's running tasks list.
12671272
// The parent task isn't currently running (we're about to wake
@@ -1283,8 +1288,10 @@ void TaskGroupBase::resumeWaitingTask(
12831288
// TODO: allow the caller to suggest an executor
12841289
waitingTask->flagAsAndEnqueueOnExecutor(ExecutorRef::generic());
12851290
return;
1286-
} // else, try again
1287-
#endif
1291+
#endif /* SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL */
1292+
} else {
1293+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "CAS failed, task = %p, backup = %p, complete with = %p, status = %s",
1294+
waitingTask, backup, completedTask, statusString().c_str());
12881295
}
12891296
}
12901297
swift_Concurrency_fatalError(0, "should have enqueued and returned.");
@@ -1293,7 +1300,8 @@ void TaskGroupBase::resumeWaitingTask(
12931300
/// Must be called while holding the TaskGroup lock.
12941301
void DiscardingTaskGroup::resumeWaitingTaskWithError(
12951302
SwiftError *error,
1296-
TaskGroupStatus &assumed) {
1303+
TaskGroupStatus &assumed,
1304+
bool alreadyDecremented) {
12971305
auto waitingTask = waitQueue.load(std::memory_order_acquire);
12981306
assert(waitingTask && "cannot resume 'null' waiting task!");
12991307
SWIFT_TASK_GROUP_DEBUG_LOG(this, "resume waiting task = %p, with error = %p",
@@ -1327,7 +1335,7 @@ void DiscardingTaskGroup::resumeWaitingTaskWithError(
13271335
return;
13281336

13291337
#else /* SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL */
1330-
if (statusCompletePendingReadyWaiting(assumed)) {
1338+
if (alreadyDecremented || statusCompletePendingReadyWaiting(assumed)) {
13311339
// Run the task.
13321340
auto result = PollResult::getError(error);
13331341

@@ -1567,6 +1575,7 @@ reevaluate_if_taskgroup_has_results:;
15671575
waitingTask->flagAsSuspended();
15681576
}
15691577
// Put the waiting task at the beginning of the wait queue.
1578+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "WATCH OUT, SET WAITER ONTO waitQueue.head = %p", waitQueue.load(std::memory_order_relaxed));
15701579
if (waitQueue.compare_exchange_strong(
15711580
waitHead, waitingTask,
15721581
/*success*/ std::memory_order_release,
@@ -1753,6 +1762,8 @@ PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask,
17531762
waitingTask->flagAsSuspended();
17541763
}
17551764
// Put the waiting task at the beginning of the wait queue.
1765+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "WATCH OUT, set waiter onto... waitQueue.head = %p", waitQueue.load(std::memory_order_relaxed));
1766+
17561767
if (waitQueue.compare_exchange_strong(
17571768
waitHead, waitingTask,
17581769
/*success*/ std::memory_order_release,

0 commit comments

Comments
 (0)