2121#include " __basic_sender.hpp"
2222#include " __cpo.hpp"
2323#include " __env.hpp"
24- #include " __intrusive_ptr.hpp"
2524#include " __intrusive_slist.hpp"
2625#include " __optional.hpp"
2726#include " __meta.hpp"
3231#include " ../stop_token.hpp"
3332#include " ../functional.hpp"
3433
34+ #include < atomic>
3535#include < exception>
3636#include < mutex>
37+ #include < type_traits>
38+ #include < utility>
3739
3840namespace stdexec {
3941 // //////////////////////////////////////////////////////////////////////////
@@ -102,7 +104,9 @@ namespace stdexec {
102104 }
103105
104106 ~__local_state () {
105- __sh_state_t::__detach (__sh_state_);
107+ if (__sh_state_) {
108+ __sh_state_->__detach ();
109+ }
106110 }
107111
108112 // Stop request callback:
@@ -126,13 +130,14 @@ namespace stdexec {
126130 // __notify function is called from the shared state's __notify_waiters function, which
127131 // first sets __waiters_ to the completed state. As a result, the attempt to remove `this`
128132 // from the waiters list above will fail and this stop request is ignored.
129- __sh_state_t::__detach (__sh_state_);
133+ std::exchange (__sh_state_, nullptr )-> __detach ( );
130134 stdexec::set_stopped (static_cast <_Receiver&&>(this ->__receiver ()));
131135 }
132136
133137 // This is called from __shared_state::__notify_waiters when the input async operation
134138 // completes; or, if it has already completed when start is called, it is called from start:
135- // __notify cannot race with __on_stop_request. See comment in __on_stop_request.
139+ // __notify cannot race with __local_state::operator(). See comment in
140+ // __local_state::operator().
136141 template <class _Tag >
137142 static void __notify (__local_state_base* __base) noexcept {
138143 auto * const __self = static_cast <__local_state*>(__base);
@@ -150,11 +155,12 @@ namespace stdexec {
150155 }
151156
152157 static auto __get_sh_state (_CvrefSender& __sndr) noexcept {
153- return __sndr.apply (static_cast <_CvrefSender&&>(__sndr), __detail::__get_data ()).__sh_state_ ;
158+ auto __box = __sndr.apply (static_cast <_CvrefSender&&>(__sndr), __detail::__get_data ());
159+ return std::exchange (__box.__sh_state_ , nullptr );
154160 }
155161
156162 using __sh_state_ptr_t = __result_of<__get_sh_state, _CvrefSender&>;
157- using __sh_state_t = typename __sh_state_ptr_t ::element_type ;
163+ using __sh_state_t = std:: remove_pointer_t < __sh_state_ptr_t > ;
158164
159165 __optional<stop_callback_for_t <__stok_t , __local_state&>> __on_stop_{};
160166 __sh_state_ptr_t __sh_state_;
@@ -193,14 +199,13 @@ namespace stdexec {
193199 };
194200
195201 inline __local_state_base* __get_tombstone () noexcept {
196- static __local_state_base __tombstone_{{}, nullptr , nullptr };
202+ static constinit __local_state_base __tombstone_{{}, nullptr , nullptr };
197203 return &__tombstone_;
198204 }
199205
200206 // ! Heap-allocatable shared state for things like `stdexec::split`.
201207 template <class _CvrefSender , class _Env >
202- struct __shared_state
203- : private __enable_intrusive_from_this<__shared_state<_CvrefSender, _Env>, 2 > {
208+ struct __shared_state {
204209 using __receiver_t = __t <__receiver<__cvref_id<_CvrefSender>, __id<_Env>>>;
205210 using __waiters_list_t = __intrusive_slist<&__local_state_base::__next_>;
206211
@@ -213,70 +218,82 @@ namespace stdexec {
213218 __munique<__mbind_front_q<__variant_for, __tuple_for<set_stopped_t >>>::__f,
214219 __tuple_for<set_error_t , std::exception_ptr>>;
215220
216- static constexpr std::size_t __started_bit = 0 ;
217- static constexpr std::size_t __completed_bit = 1 ;
218-
219221 inplace_stop_source __stop_source_{};
220222 __env_t <_Env> __env_;
221223 __variant_t __results_{}; // Defaults to the "set_stopped" state
222224 std::mutex __mutex_; // This mutex guards access to __waiters_.
223225 __waiters_list_t __waiters_{};
224226 connect_result_t <_CvrefSender, __receiver_t > __shared_op_;
227+ std::atomic_flag __started_{};
228+ std::atomic<std::size_t > __ref_count_{2 };
229+
230+ // Let a "consumer" be either a split/ensure_started sender, or an operation
231+ // state created by connecting a split/ensure_started sender to a receiver.
232+ // Let is_running be 1 if the shared operation is currently executing (after
233+ // start has been called but before the receiver's completion functions have
234+ // executed), and 0 otherwise. Then __ref_count_ is equal to:
235+ //
236+ // (2 * (nbr of consumers)) + is_running
225237
226238 explicit __shared_state (_CvrefSender&& __sndr, _Env __env)
227239 : __env_(
228240 __env::__join (
229241 prop{get_stop_token, __stop_source_.get_token ()},
230242 static_cast <_Env&&>(__env)))
231243 , __shared_op_(connect(static_cast <_CvrefSender&&>(__sndr), __receiver_t{this })) {
232- // add one ref count to account for the case where there are no watchers left but the
233- // shared op is still running.
234- this ->__inc_ref ();
235244 }
236245
237- // The caller of this wants to release their reference to the shared state. The ref
238- // count must be at least 2 at this point: one owned by the caller, and one added in the
239- // __shared_state ctor.
240- static void __detach (__intrusive_ptr<__shared_state, 2 >& __ptr) noexcept {
241- // Ask the intrusive ptr to stop managing the reference count so we can manage it manually.
242- if (auto * __self = __ptr.__release_ ()) {
243- auto __old = __self->__dec_ref ();
244- STDEXEC_ASSERT (__count (__old) >= 2 );
245-
246- if (__count (__old) == 2 ) {
247- // The last watcher has released its reference. Asked the shared op to stop.
248- static_cast <__shared_state*>(__self)->__stop_source_ .request_stop ();
249-
250- // Additionally, if the shared op was never started, or if it has already completed,
251- // then the shared state is no longer needed. Decrement the ref count to 0 here, which
252- // will delete __self.
253- if (!__bit<__started_bit>(__old) || __bit<__completed_bit>(__old)) {
254- __self->__dec_ref ();
255- }
256- }
246+ void __inc_ref () noexcept {
247+ __ref_count_.fetch_add (2ul , std::memory_order_relaxed);
248+ }
249+
250+ void __dec_ref () noexcept {
251+ if (2ul == __ref_count_.fetch_sub (2ul , std::memory_order_acq_rel)) {
252+ delete this ;
257253 }
258254 }
259255
260- // / @post The started bit is set in the shared state's ref count, OR the __waiters_ list
256+ bool __set_started () noexcept {
257+ if (__started_.test_and_set (std::memory_order_acq_rel)) {
258+ return false ; // already started
259+ }
260+ __ref_count_.fetch_add (1ul , std::memory_order_relaxed);
261+ return true ;
262+ }
263+
264+ void __set_completed () noexcept {
265+ if (1ul == __ref_count_.fetch_sub (1ul , std::memory_order_acq_rel)) {
266+ delete this ;
267+ }
268+ }
269+
270+ void __detach () noexcept {
271+ if (__ref_count_.load () < 4ul ) {
272+ // We are the final "consumer", and we are about to release our reference
273+ // to the shared state. Ask the operation to stop early.
274+ __stop_source_.request_stop ();
275+ }
276+ __dec_ref ();
277+ }
278+
279+ // / @post The "is running" bit is set in the shared state's ref count, OR the __waiters_ list
261280 // / is set to the known "tombstone" value indicating completion.
262281 void __try_start () noexcept {
263282 // With the split algorithm, multiple split senders can be started simultaneously, but
264- // only one should start the shared async operation. If the "started" bit is set, then
283+ // only one should start the shared async operation. If the low bit is set, then
265284 // someone else has already started the shared operation. Do nothing.
266- if (this ->template __is_set <__started_bit>()) {
267- return ;
268- } else if (__bit<__started_bit>(this ->template __set_bit <__started_bit>())) {
269- return ;
270- } else if (__stop_source_.stop_requested ()) {
271- // Stop has already been requested. Rather than starting the operation, complete with
272- // set_stopped immediately.
273- // 1. Sets __waiters_ to a known "tombstone" value
274- // 2. Notifies all the waiters that the operation has stopped
275- // 3. Sets the "completed" bit in the ref count.
276- __notify_waiters ();
277- return ;
278- } else {
279- stdexec::start (__shared_op_);
285+ if (__set_started ()) {
286+ // we are the first to start the underlying operation
287+ if (__stop_source_.stop_requested ()) {
288+ // Stop has already been requested. Rather than starting the operation, complete with
289+ // set_stopped immediately.
290+ // 1. Sets __waiters_ to a known "tombstone" value.
291+ // 2. Notifies all the waiters that the operation has stopped.
292+ // 3. Sets the "is running" bit in the ref count to 0.
293+ __notify_waiters ();
294+ } else {
295+ stdexec::start (__shared_op_);
296+ }
280297 }
281298 }
282299
@@ -328,22 +345,22 @@ namespace stdexec {
328345 for (auto __itr = __waiters_copy.begin (); __itr != __waiters_copy.end ();) {
329346 __local_state_base* __item = *__itr;
330347
331- // We must increment the iterator before calling notify, since notify
332- // may end up triggering *__item to be destructed on another thread,
333- // and the intrusive slist's iterator increment relies on __item.
348+ // We must increment the iterator before calling notify, since notify may end up
349+ // triggering *__item to be destructed on another thread, and the intrusive slist's
350+ // iterator increment relies on __item.
334351 ++__itr;
335-
336352 __item->__notify_ (__item);
337353 }
338354
339- // Set the "completed" bit in the ref count. If the ref count is 1, then there are no more
340- // waiters. Release the final reference.
341- if (__count (this ->template __set_bit <__completed_bit>()) == 1 ) {
342- this ->__dec_ref (); // release the extra ref count, deletes this
343- }
355+ // Set the "is running" bit in the ref count to zero. Delete the shared state if the
356+ // ref-count is now zero.
357+ __set_completed ();
344358 }
345359 };
346360
361+ template <class _CvrefSender , class _Env >
362+ __shared_state (_CvrefSender&&, _Env) -> __shared_state<_CvrefSender, _Env>;
363+
347364 template <class _Cvref , class _CvrefSender , class _Env >
348365 using __make_completions = //
349366 __try_make_completion_signatures<
@@ -374,30 +391,36 @@ namespace stdexec {
374391 using __tag_t = __if_c<_Copyable, __split::__split_t , __ensure_started::__ensure_started_t >;
375392 using __sh_state_t = __shared_state<_CvrefSender, _Env>;
376393
377- __box (__tag_t , __intrusive_ptr<__sh_state_t , 2 > __sh_state) noexcept
378- : __sh_state_(std::move(__sh_state)) {
394+ __box (__tag_t , __sh_state_t * __sh_state) noexcept
395+ : __sh_state_(__sh_state) {
396+ }
397+
398+ __box (__box&& __other) noexcept
399+ : __sh_state_(std::exchange(__other.__sh_state_, nullptr )) {
379400 }
380401
381- __box (__box&&) noexcept = default ;
382- __box (const __box&) noexcept
402+ __box (const __box& __other) noexcept
383403 requires _Copyable
384- = default ;
404+ : __sh_state_(__other.__sh_state_) {
405+ __sh_state_->__inc_ref ();
406+ }
385407
386408 ~__box () {
387- __sh_state_t::__detach (__sh_state_);
409+ if (__sh_state_) {
410+ __sh_state_->__detach ();
411+ }
388412 }
389413
390- __intrusive_ptr< __sh_state_t , 2 > __sh_state_;
414+ __sh_state_t * __sh_state_;
391415 };
392416
393417 template <class _CvrefSender , class _Env >
394- __box (__split::__split_t , __intrusive_ptr< __shared_state<_CvrefSender, _Env>, 2 > ) //
418+ __box (__split::__split_t , __shared_state<_CvrefSender, _Env>* ) //
395419 ->__box<_CvrefSender, _Env, true >;
396420
397421 template <class _CvrefSender , class _Env >
398- __box (
399- __ensure_started::__ensure_started_t ,
400- __intrusive_ptr<__shared_state<_CvrefSender, _Env>, 2 >) -> __box<_CvrefSender, _Env, false >;
422+ __box (__ensure_started::__ensure_started_t , __shared_state<_CvrefSender, _Env>*)
423+ -> __box<_CvrefSender, _Env, false >;
401424
402425 template <class _Tag >
403426 struct __shared_impl : __sexpr_defaults {
@@ -419,14 +442,13 @@ namespace stdexec {
419442 []<class _Sender , class _Receiver >(
420443 __local_state<_Sender, _Receiver>& __self,
421444 _Receiver& __rcvr) noexcept -> void {
422- using __sh_state_t = typename __local_state<_Sender, _Receiver>::__sh_state_t ;
423445 // Scenario: there are no more split senders, this is the only operation state, the
424446 // underlying operation has not yet been started, and the receiver's stop token is already
425447 // in the "stop requested" state. Then registering the stop callback will call
426- // __on_stop_request on __self synchronously. It may also be called asynchronously at
427- // any point after the callback is registered. Beware. We are guaranteed, however, that
428- // __on_stop_request will not complete the operation or decrement the shared state's ref
429- // count until after __self has been added to the waiters list.
448+ // __local_state::operator() on __self synchronously. It may also be called asynchronously
449+ // at any point after the callback is registered. Beware. We are guaranteed, however, that
450+ // __local_state::operator() will not complete the operation or decrement the shared state's
451+ // ref count until after __self has been added to the waiters list.
430452 const auto __stok = stdexec::get_stop_token (stdexec::get_env (__rcvr));
431453 __self.__on_stop_ .emplace (__stok, __self);
432454
@@ -446,7 +468,7 @@ namespace stdexec {
446468 // Otherwise, failed to add the waiter because of a stop-request.
447469 // Complete synchronously with set_stopped().
448470 __self.__on_stop_ .reset ();
449- __sh_state_t::__detach (__self.__sh_state_ );
471+ std::exchange (__self.__sh_state_ , nullptr )-> __detach ( );
450472 stdexec::set_stopped (static_cast <_Receiver&&>(__rcvr));
451473 };
452474 };
0 commit comments