Skip to content

Commit 6c1ed99

Browse files
committed
Avoid use-after-free with stdexec::run_loop
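`__run_loop_base::run` must not return while a potentially concurrent call to `__run_loop_base::finish` is still in progress; otherwise the run loop can be destroyed while `finish` is still accessing it. This is done by introducing a counter of in-flight tasks: `run` only returns once the counter has dropped to zero, ensuring proper completion of all tasks in flight. Also see NVIDIA#1742 for additional information.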
1 parent 4851608 commit 6c1ed99

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

include/stdexec/__detail/__run_loop.hpp

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
#include "__schedulers.hpp"
2727

2828
#include "__atomic.hpp"
29+
#include "stdexec/__detail/__config.hpp"
30+
#include <atomic>
31+
#include <cstddef>
2932

3033
namespace stdexec {
3134
/////////////////////////////////////////////////////////////////////////////
@@ -34,24 +37,35 @@ namespace stdexec {
3437
public:
3538
__run_loop_base() = default;
3639

40+
~__run_loop_base() noexcept {
41+
STDEXEC_ASSERT(__task_count_.load(__std::memory_order_acquire) == 0);
42+
}
43+
3744
STDEXEC_ATTRIBUTE(host, device) void run() noexcept {
3845
// execute work items until the __finishing_ flag is set:
3946
while (!__finishing_.load(__std::memory_order_acquire)) {
4047
__queue_.wait_for_item();
4148
__execute_all();
4249
}
4350
// drain the queue, taking care to execute any tasks that get added while
44-
// executing the remaining tasks:
45-
while (__execute_all())
51+
// executing the remaining tasks (also wait for other tasks that might still be in flight):
52+
while (__execute_all() || __task_count_.load(__std::memory_order_acquire) > 0)
4653
;
4754
}
4855

4956
STDEXEC_ATTRIBUTE(host, device) void finish() noexcept {
57+
// Increment our task count to avoid lifetime issues. This prevents
58+
// a use-after-free issue if finish is called from a different thread.
59+
__task_count_.fetch_add(1, __std::memory_order_release);
5060
if (!__finishing_.exchange(true, __std::memory_order_acq_rel)) {
5161
// push an empty work item to the queue to wake up the consuming thread
52-
// and let it finish:
62+
// and let it finish.
63+
// The count will be decremented once the task executes.
5364
__queue_.push(&__noop_task);
65+
return;
5466
}
67+
// We are done finishing. Decrement the count, which signals final completion.
68+
__task_count_.fetch_sub(1, __std::memory_order_release);
5569
}
5670

5771
struct __task : __immovable {
@@ -73,6 +87,7 @@ namespace stdexec {
7387

7488
template <class _Rcvr>
7589
struct __opstate_t : __task {
90+
__std::atomic<std::size_t>* __task_count_;
7691
__atomic_intrusive_queue<&__task::__next_>* __queue_;
7792
_Rcvr __rcvr_;
7893

@@ -89,14 +104,17 @@ namespace stdexec {
89104

90105
STDEXEC_ATTRIBUTE(host, device)
91106
constexpr explicit __opstate_t(
107+
__std::atomic<std::size_t>* __task_count,
92108
__atomic_intrusive_queue<&__task::__next_>* __queue,
93109
_Rcvr __rcvr)
94110
: __task{&__execute_impl}
111+
, __task_count_(__task_count)
95112
, __queue_{__queue}
96113
, __rcvr_{static_cast<_Rcvr&&>(__rcvr)} {
97114
}
98115

99116
STDEXEC_ATTRIBUTE(host, device) constexpr void start() noexcept {
117+
__task_count_->fetch_add(1, __std::memory_order_release);
100118
__queue_->push(this);
101119
}
102120
};
@@ -112,20 +130,25 @@ namespace stdexec {
112130
return false; // No tasks to execute.
113131
}
114132

133+
std::size_t __task_count = 0;
134+
115135
do {
116136
// Take care to increment the iterator before executing the task,
117137
// because __execute() may invalidate the current node.
118138
auto __prev = __it++;
119139
(*__prev)->__execute();
140+
++__task_count;
120141
} while (__it != __queue.end());
121142

122143
__queue.clear();
144+
__task_count_.fetch_sub(__task_count, __std::memory_order_release);
123145
return true;
124146
}
125147

126148
STDEXEC_ATTRIBUTE(host, device) static void __noop_(__task*) noexcept {
127149
}
128150

151+
__std::atomic<std::size_t> __task_count_{0};
129152
__std::atomic<bool> __finishing_{false};
130153
__atomic_intrusive_queue<&__task::__next_> __queue_{};
131154
__task __noop_task{&__noop_};
@@ -186,7 +209,7 @@ namespace stdexec {
186209
template <class _Rcvr>
187210
STDEXEC_ATTRIBUTE(nodiscard, host, device)
188211
constexpr auto connect(_Rcvr __rcvr) const noexcept -> __opstate_t<_Rcvr> {
189-
return __opstate_t<_Rcvr>{&__loop_->__queue_, static_cast<_Rcvr&&>(__rcvr)};
212+
return __opstate_t<_Rcvr>{&__loop_->__task_count_, &__loop_->__queue_, static_cast<_Rcvr&&>(__rcvr)};
190213
}
191214

192215
STDEXEC_ATTRIBUTE(nodiscard, host, device)

0 commit comments

Comments
 (0)