Skip to content

Commit 5fe7bdb

Browse files
authored
Merge pull request NVIDIA#1460 from NVIDIA/fix-split-data-race
fix long-standing race condition in `split`
2 parents c211de1 + 8029a5c commit 5fe7bdb

File tree

3 files changed

+106
-85
lines changed

3 files changed

+106
-85
lines changed

include/stdexec/__detail/__ensure_started.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,13 @@ namespace stdexec {
7979
static_cast<_Sender&&>(__sndr),
8080
[&]<class _Env, class _Child>(__ignore, _Env&& __env, _Child&& __child) {
8181
// The shared state starts life with a ref-count of two: one "consumer" under the (2 * consumers) + is_running scheme.
82-
auto __sh_state = __make_intrusive<__shared_state<_Child, __decay_t<_Env>>, 2>(
83-
static_cast<_Child&&>(__child), static_cast<_Env&&>(__env));
82+
auto* __sh_state =
83+
new __shared_state{static_cast<_Child&&>(__child), static_cast<_Env&&>(__env)};
8484

8585
// Eagerly start the work:
86-
__sh_state->__try_start();
86+
__sh_state->__try_start(); // cannot throw
8787

88-
return __make_sexpr<__ensure_started_t>(
89-
__box{__ensure_started_t(), std::move(__sh_state)});
88+
return __make_sexpr<__ensure_started_t>(__box{__ensure_started_t(), __sh_state});
9089
});
9190
}
9291
};

include/stdexec/__detail/__shared.hpp

Lines changed: 99 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include "__basic_sender.hpp"
2222
#include "__cpo.hpp"
2323
#include "__env.hpp"
24-
#include "__intrusive_ptr.hpp"
2524
#include "__intrusive_slist.hpp"
2625
#include "__optional.hpp"
2726
#include "__meta.hpp"
@@ -32,8 +31,11 @@
3231
#include "../stop_token.hpp"
3332
#include "../functional.hpp"
3433

34+
#include <atomic>
3535
#include <exception>
3636
#include <mutex>
37+
#include <type_traits>
38+
#include <utility>
3739

3840
namespace stdexec {
3941
////////////////////////////////////////////////////////////////////////////
@@ -102,7 +104,9 @@ namespace stdexec {
102104
}
103105

104106
~__local_state() {
105-
__sh_state_t::__detach(__sh_state_);
107+
if (__sh_state_) {
108+
__sh_state_->__detach();
109+
}
106110
}
107111

108112
// Stop request callback:
@@ -126,13 +130,14 @@ namespace stdexec {
126130
// __notify function is called from the shared state's __notify_waiters function, which
127131
// first sets __waiters_ to the completed state. As a result, the attempt to remove `this`
128132
// from the waiters list above will fail and this stop request is ignored.
129-
__sh_state_t::__detach(__sh_state_);
133+
std::exchange(__sh_state_, nullptr)->__detach();
130134
stdexec::set_stopped(static_cast<_Receiver&&>(this->__receiver()));
131135
}
132136

133137
// This is called from __shared_state::__notify_waiters when the input async operation
134138
// completes; or, if it has already completed when start is called, it is called from start:
135-
// __notify cannot race with __on_stop_request. See comment in __on_stop_request.
139+
// __notify cannot race with __local_state::operator(). See comment in
140+
// __local_state::operator().
136141
template <class _Tag>
137142
static void __notify(__local_state_base* __base) noexcept {
138143
auto* const __self = static_cast<__local_state*>(__base);
@@ -150,11 +155,12 @@ namespace stdexec {
150155
}
151156

152157
static auto __get_sh_state(_CvrefSender& __sndr) noexcept {
153-
return __sndr.apply(static_cast<_CvrefSender&&>(__sndr), __detail::__get_data()).__sh_state_;
158+
auto __box = __sndr.apply(static_cast<_CvrefSender&&>(__sndr), __detail::__get_data());
159+
return std::exchange(__box.__sh_state_, nullptr);
154160
}
155161

156162
using __sh_state_ptr_t = __result_of<__get_sh_state, _CvrefSender&>;
157-
using __sh_state_t = typename __sh_state_ptr_t::element_type;
163+
using __sh_state_t = std::remove_pointer_t<__sh_state_ptr_t>;
158164

159165
__optional<stop_callback_for_t<__stok_t, __local_state&>> __on_stop_{};
160166
__sh_state_ptr_t __sh_state_;
@@ -193,14 +199,13 @@ namespace stdexec {
193199
};
194200

195201
inline __local_state_base* __get_tombstone() noexcept {
196-
static __local_state_base __tombstone_{{}, nullptr, nullptr};
202+
static constinit __local_state_base __tombstone_{{}, nullptr, nullptr};
197203
return &__tombstone_;
198204
}
199205

200206
//! Heap-allocatable shared state for things like `stdexec::split`.
201207
template <class _CvrefSender, class _Env>
202-
struct __shared_state
203-
: private __enable_intrusive_from_this<__shared_state<_CvrefSender, _Env>, 2> {
208+
struct __shared_state {
204209
using __receiver_t = __t<__receiver<__cvref_id<_CvrefSender>, __id<_Env>>>;
205210
using __waiters_list_t = __intrusive_slist<&__local_state_base::__next_>;
206211

@@ -213,70 +218,82 @@ namespace stdexec {
213218
__munique<__mbind_front_q<__variant_for, __tuple_for<set_stopped_t>>>::__f,
214219
__tuple_for<set_error_t, std::exception_ptr>>;
215220

216-
static constexpr std::size_t __started_bit = 0;
217-
static constexpr std::size_t __completed_bit = 1;
218-
219221
inplace_stop_source __stop_source_{};
220222
__env_t<_Env> __env_;
221223
__variant_t __results_{}; // Defaults to the "set_stopped" state
222224
std::mutex __mutex_; // This mutex guards access to __waiters_.
223225
__waiters_list_t __waiters_{};
224226
connect_result_t<_CvrefSender, __receiver_t> __shared_op_;
227+
std::atomic_flag __started_{};
228+
std::atomic<std::size_t> __ref_count_{2};
229+
230+
// Let a "consumer" be either a split/ensure_started sender, or an operation
231+
// state created by connecting a split/ensure_started sender to a receiver.
232+
// Let is_running be 1 if the shared operation is currently executing (after
233+
// start has been called but before the receiver's completion functions have
234+
// executed), and 0 otherwise. Then __ref_count_ is equal to:
235+
//
236+
// (2 * (nbr of consumers)) + is_running
225237

226238
explicit __shared_state(_CvrefSender&& __sndr, _Env __env)
227239
: __env_(
228240
__env::__join(
229241
prop{get_stop_token, __stop_source_.get_token()},
230242
static_cast<_Env&&>(__env)))
231243
, __shared_op_(connect(static_cast<_CvrefSender&&>(__sndr), __receiver_t{this})) {
232-
// add one ref count to account for the case where there are no watchers left but the
233-
// shared op is still running.
234-
this->__inc_ref();
235244
}
236245

237-
// The caller of this wants to release their reference to the shared state. The ref
238-
// count must be at least 2 at this point: one owned by the caller, and one added in the
239-
// __shared_state ctor.
240-
static void __detach(__intrusive_ptr<__shared_state, 2>& __ptr) noexcept {
241-
// Ask the intrusive ptr to stop managing the reference count so we can manage it manually.
242-
if (auto* __self = __ptr.__release_()) {
243-
auto __old = __self->__dec_ref();
244-
STDEXEC_ASSERT(__count(__old) >= 2);
245-
246-
if (__count(__old) == 2) {
247-
// The last watcher has released its reference. Ask the shared op to stop.
248-
static_cast<__shared_state*>(__self)->__stop_source_.request_stop();
249-
250-
// Additionally, if the shared op was never started, or if it has already completed,
251-
// then the shared state is no longer needed. Decrement the ref count to 0 here, which
252-
// will delete __self.
253-
if (!__bit<__started_bit>(__old) || __bit<__completed_bit>(__old)) {
254-
__self->__dec_ref();
255-
}
256-
}
246+
void __inc_ref() noexcept {
247+
__ref_count_.fetch_add(2ul, std::memory_order_relaxed);
248+
}
249+
250+
void __dec_ref() noexcept {
251+
if (2ul == __ref_count_.fetch_sub(2ul, std::memory_order_acq_rel)) {
252+
delete this;
257253
}
258254
}
259255

260-
/// @post The started bit is set in the shared state's ref count, OR the __waiters_ list
256+
bool __set_started() noexcept {
257+
if (__started_.test_and_set(std::memory_order_acq_rel)) {
258+
return false; // already started
259+
}
260+
__ref_count_.fetch_add(1ul, std::memory_order_relaxed);
261+
return true;
262+
}
263+
264+
void __set_completed() noexcept {
265+
if (1ul == __ref_count_.fetch_sub(1ul, std::memory_order_acq_rel)) {
266+
delete this;
267+
}
268+
}
269+
270+
void __detach() noexcept {
271+
if (__ref_count_.load() < 4ul) {
272+
// We are the final "consumer", and we are about to release our reference
273+
// to the shared state. Ask the operation to stop early.
274+
__stop_source_.request_stop();
275+
}
276+
__dec_ref();
277+
}
278+
279+
/// @post The "is running" bit is set in the shared state's ref count, OR the __waiters_ list
261280
/// is set to the known "tombstone" value indicating completion.
262281
void __try_start() noexcept {
263282
// With the split algorithm, multiple split senders can be started simultaneously, but
264-
// only one should start the shared async operation. If the "started" bit is set, then
283+
// only one should start the shared async operation. If the low bit is set, then
265284
// someone else has already started the shared operation. Do nothing.
266-
if (this->template __is_set<__started_bit>()) {
267-
return;
268-
} else if (__bit<__started_bit>(this->template __set_bit<__started_bit>())) {
269-
return;
270-
} else if (__stop_source_.stop_requested()) {
271-
// Stop has already been requested. Rather than starting the operation, complete with
272-
// set_stopped immediately.
273-
// 1. Sets __waiters_ to a known "tombstone" value
274-
// 2. Notifies all the waiters that the operation has stopped
275-
// 3. Sets the "completed" bit in the ref count.
276-
__notify_waiters();
277-
return;
278-
} else {
279-
stdexec::start(__shared_op_);
285+
if (__set_started()) {
286+
// we are the first to start the underlying operation
287+
if (__stop_source_.stop_requested()) {
288+
// Stop has already been requested. Rather than starting the operation, complete with
289+
// set_stopped immediately.
290+
// 1. Sets __waiters_ to a known "tombstone" value.
291+
// 2. Notifies all the waiters that the operation has stopped.
292+
// 3. Sets the "is running" bit in the ref count to 0.
293+
__notify_waiters();
294+
} else {
295+
stdexec::start(__shared_op_);
296+
}
280297
}
281298
}
282299

@@ -328,22 +345,22 @@ namespace stdexec {
328345
for (auto __itr = __waiters_copy.begin(); __itr != __waiters_copy.end();) {
329346
__local_state_base* __item = *__itr;
330347

331-
// We must increment the iterator before calling notify, since notify
332-
// may end up triggering *__item to be destructed on another thread,
333-
// and the intrusive slist's iterator increment relies on __item.
348+
// We must increment the iterator before calling notify, since notify may end up
349+
// triggering *__item to be destructed on another thread, and the intrusive slist's
350+
// iterator increment relies on __item.
334351
++__itr;
335-
336352
__item->__notify_(__item);
337353
}
338354

339-
// Set the "completed" bit in the ref count. If the ref count is 1, then there are no more
340-
// waiters. Release the final reference.
341-
if (__count(this->template __set_bit<__completed_bit>()) == 1) {
342-
this->__dec_ref(); // release the extra ref count, deletes this
343-
}
355+
// Set the "is running" bit in the ref count to zero. Delete the shared state if the
356+
// ref-count is now zero.
357+
__set_completed();
344358
}
345359
};
346360

361+
template <class _CvrefSender, class _Env>
362+
__shared_state(_CvrefSender&&, _Env) -> __shared_state<_CvrefSender, _Env>;
363+
347364
template <class _Cvref, class _CvrefSender, class _Env>
348365
using __make_completions = //
349366
__try_make_completion_signatures<
@@ -374,30 +391,36 @@ namespace stdexec {
374391
using __tag_t = __if_c<_Copyable, __split::__split_t, __ensure_started::__ensure_started_t>;
375392
using __sh_state_t = __shared_state<_CvrefSender, _Env>;
376393

377-
__box(__tag_t, __intrusive_ptr<__sh_state_t, 2> __sh_state) noexcept
378-
: __sh_state_(std::move(__sh_state)) {
394+
__box(__tag_t, __sh_state_t* __sh_state) noexcept
395+
: __sh_state_(__sh_state) {
396+
}
397+
398+
__box(__box&& __other) noexcept
399+
: __sh_state_(std::exchange(__other.__sh_state_, nullptr)) {
379400
}
380401

381-
__box(__box&&) noexcept = default;
382-
__box(const __box&) noexcept
402+
__box(const __box& __other) noexcept
383403
requires _Copyable
384-
= default;
404+
: __sh_state_(__other.__sh_state_) {
405+
__sh_state_->__inc_ref();
406+
}
385407

386408
~__box() {
387-
__sh_state_t::__detach(__sh_state_);
409+
if (__sh_state_) {
410+
__sh_state_->__detach();
411+
}
388412
}
389413

390-
__intrusive_ptr<__sh_state_t, 2> __sh_state_;
414+
__sh_state_t* __sh_state_;
391415
};
392416

393417
template <class _CvrefSender, class _Env>
394-
__box(__split::__split_t, __intrusive_ptr<__shared_state<_CvrefSender, _Env>, 2>) //
418+
__box(__split::__split_t, __shared_state<_CvrefSender, _Env>*) //
395419
->__box<_CvrefSender, _Env, true>;
396420

397421
template <class _CvrefSender, class _Env>
398-
__box(
399-
__ensure_started::__ensure_started_t,
400-
__intrusive_ptr<__shared_state<_CvrefSender, _Env>, 2>) -> __box<_CvrefSender, _Env, false>;
422+
__box(__ensure_started::__ensure_started_t, __shared_state<_CvrefSender, _Env>*)
423+
-> __box<_CvrefSender, _Env, false>;
401424

402425
template <class _Tag>
403426
struct __shared_impl : __sexpr_defaults {
@@ -419,14 +442,13 @@ namespace stdexec {
419442
[]<class _Sender, class _Receiver>(
420443
__local_state<_Sender, _Receiver>& __self,
421444
_Receiver& __rcvr) noexcept -> void {
422-
using __sh_state_t = typename __local_state<_Sender, _Receiver>::__sh_state_t;
423445
// Scenario: there are no more split senders, this is the only operation state, the
424446
// underlying operation has not yet been started, and the receiver's stop token is already
425447
// in the "stop requested" state. Then registering the stop callback will call
426-
// __on_stop_request on __self synchronously. It may also be called asynchronously at
427-
// any point after the callback is registered. Beware. We are guaranteed, however, that
428-
// __on_stop_request will not complete the operation or decrement the shared state's ref
429-
// count until after __self has been added to the waiters list.
448+
// __local_state::operator() on __self synchronously. It may also be called asynchronously
449+
// at any point after the callback is registered. Beware. We are guaranteed, however, that
450+
// __local_state::operator() will not complete the operation or decrement the shared state's
451+
// ref count until after __self has been added to the waiters list.
430452
const auto __stok = stdexec::get_stop_token(stdexec::get_env(__rcvr));
431453
__self.__on_stop_.emplace(__stok, __self);
432454

@@ -446,7 +468,7 @@ namespace stdexec {
446468
// Otherwise, failed to add the waiter because of a stop-request.
447469
// Complete synchronously with set_stopped().
448470
__self.__on_stop_.reset();
449-
__sh_state_t::__detach(__self.__sh_state_);
471+
std::exchange(__self.__sh_state_, nullptr)->__detach();
450472
stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr));
451473
};
452474
};

include/stdexec/__detail/__split.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ namespace stdexec {
7373
static_cast<_Sender&&>(__sndr),
7474
[&]<class _Env, class _Child>(__ignore, _Env&& __env, _Child&& __child) {
7575
// The shared state starts life with a ref-count of two: one "consumer" under the (2 * consumers) + is_running scheme.
76-
auto __sh_state = __make_intrusive<__shared_state<_Child, __decay_t<_Env>>, 2>(
77-
static_cast<_Child&&>(__child), static_cast<_Env&&>(__env));
76+
auto* __sh_state =
77+
new __shared_state{static_cast<_Child&&>(__child), static_cast<_Env&&>(__env)};
7878

79-
return __make_sexpr<__split_t>(__box{__split_t(), std::move(__sh_state)});
79+
return __make_sexpr<__split_t>(__box{__split_t(), __sh_state});
8080
});
8181
}
8282
};

0 commit comments

Comments
 (0)