Skip to content

Commit 150a82f

Browse files
authored
Merge pull request NVIDIA#1454 from maikel/any-sender-improve-overloads
Relax overload resolution of any receivers completion functions
2 parents e6fb836 + a762e9b commit 150a82f

File tree

3 files changed

+39
-24
lines changed

3 files changed

+39
-24
lines changed

include/exec/any_sender_of.hpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -439,9 +439,7 @@ namespace exec {
439439
(*__other.__vtable_)(__copy_construct, this, __other);
440440
}
441441

442-
auto operator=(const __t& __other) -> __t&
443-
requires(_Copyable)
444-
{
442+
auto operator=(const __t& __other) -> __t& requires(_Copyable) {
445443
if (&__other != this) {
446444
__t tmp(__other);
447445
*this = std::move(tmp);
@@ -615,6 +613,7 @@ namespace exec {
615613
, public __query_vfun<_Queries>... {
616614
public:
617615
using __query_vfun<_Queries>::operator()...;
616+
using __any_::__rcvr_vfun<_Sigs>::operator()...;
618617

619618
private:
620619
template <class _Rcvr>
@@ -674,24 +673,21 @@ namespace exec {
674673
}
675674

676675
template <class... _As>
677-
requires __one_of<set_value_t(_As...), _Sigs...>
676+
requires __callable<__vtable_t, void*, set_value_t, _As...>
678677
void set_value(_As&&... __as) noexcept {
679-
const __any_::__rcvr_vfun<set_value_t(_As...)>* __vfun = __env_.__vtable_;
680-
(*__vfun->__complete_)(__env_.__rcvr_, static_cast<_As&&>(__as)...);
678+
(*__env_.__vtable_)(__env_.__rcvr_, set_value_t(), static_cast<_As&&>(__as)...);
681679
}
682680

683681
template <class _Error>
684-
requires __one_of<set_error_t(_Error), _Sigs...>
682+
requires __callable<__vtable_t, void*, set_error_t, _Error>
685683
void set_error(_Error&& __err) noexcept {
686-
const __any_::__rcvr_vfun<set_error_t(_Error)>* __vfun = __env_.__vtable_;
687-
(*__vfun->__complete_)(__env_.__rcvr_, static_cast<_Error&&>(__err));
684+
(*__env_.__vtable_)(__env_.__rcvr_, set_error_t(), static_cast<_Error&&>(__err));
688685
}
689686

690687
void set_stopped() noexcept
691-
requires __one_of<set_stopped_t(), _Sigs...>
688+
requires __callable<__vtable_t, void*, set_stopped_t>
692689
{
693-
const __any_::__rcvr_vfun<set_stopped_t()>* __vfun = __env_.__vtable_;
694-
(*__vfun->__complete_)(__env_.__rcvr_);
690+
(*__env_.__vtable_)(__env_.__rcvr_, set_stopped_t());
695691
}
696692

697693
auto get_env() const noexcept -> const __env_t& {

include/stdexec/__detail/__receiver_ref.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,16 @@ namespace stdexec { namespace __any_ {
3030

3131
template <class _Tag, class... _Args>
3232
struct __rcvr_vfun<_Tag(_Args...)> {
33-
void (*__complete_)(void*, _Args&&...) noexcept;
33+
void (*__complete_)(void*, _Args...) noexcept;
3434

35-
void operator()(void* __obj, _Tag, _Args&&... __args) const noexcept {
35+
void operator()(void* __obj, _Tag, _Args... __args) const noexcept {
3636
__complete_(__obj, static_cast<_Args&&>(__args)...);
3737
}
3838
};
3939

4040
template <class _GetReceiver = std::identity, class _Obj, class _Tag, class... _Args>
4141
constexpr auto __rcvr_vfun_fn(_Obj*, _Tag (*)(_Args...)) noexcept {
42-
return +[](void* __ptr, _Args&&... __args) noexcept {
42+
return +[](void* __ptr, _Args... __args) noexcept {
4343
_Obj* __obj = static_cast<_Obj*>(__ptr);
4444
_Tag()(std::move(_GetReceiver()(*__obj)), static_cast<_Args&&>(__args)...);
4545
};
@@ -95,16 +95,20 @@ namespace stdexec { namespace __any_ {
9595
}
9696

9797
template <class... _As>
98+
requires __callable<__receiver_vtable_for<_Sigs, _Env>, void*, set_value_t, _As...>
9899
void set_value(_As&&... __as) noexcept {
99100
(*__vtable_)(__op_state_, set_value_t(), static_cast<_As&&>(__as)...);
100101
}
101102

102103
template <class _Error>
104+
requires __callable<__receiver_vtable_for<_Sigs, _Env>, void*, set_error_t, _Error>
103105
void set_error(_Error&& __err) noexcept {
104106
(*__vtable_)(__op_state_, set_error_t(), static_cast<_Error&&>(__err));
105107
}
106108

107-
void set_stopped() noexcept {
109+
void set_stopped() noexcept
110+
requires __callable<__receiver_vtable_for<_Sigs, _Env>, void*, set_stopped_t>
111+
{
108112
(*__vtable_)(__op_state_, set_stopped_t());
109113
}
110114

test/exec/test_any_sender.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,9 @@ namespace {
259259
TEST_CASE("sync_wait works on any_sender_of", "[types][any_sender]") {
260260
int value = 0;
261261
any_sender_of<set_value_t()> sender = just(42) | then([&](int v) noexcept { value = v; });
262-
CHECK(
263-
std::same_as<
264-
completion_signatures_of_t<any_sender_of<set_value_t()>>,
265-
completion_signatures<set_value_t()>>);
262+
CHECK(std::same_as<
263+
completion_signatures_of_t<any_sender_of<set_value_t()>>,
264+
completion_signatures<set_value_t()>>);
266265
sync_wait(std::move(sender));
267266
CHECK(value == 42);
268267
}
@@ -276,10 +275,9 @@ namespace {
276275

277276
TEST_CASE("sync_wait returns value", "[types][any_sender]") {
278277
any_sender_of<set_value_t(int)> sender = just(21) | then([&](int v) noexcept { return 2 * v; });
279-
CHECK(
280-
std::same_as<
281-
completion_signatures_of_t<any_sender_of<set_value_t(int)>>,
282-
completion_signatures<set_value_t(int)>>);
278+
CHECK(std::same_as<
279+
completion_signatures_of_t<any_sender_of<set_value_t(int)>>,
280+
completion_signatures<set_value_t(int)>>);
283281
auto [value1] = *sync_wait(std::move(sender));
284282
CHECK(value1 == 42);
285283
}
@@ -330,6 +328,23 @@ namespace {
330328
}
331329
}
332330

331+
template <class... Vals>
332+
using my_stoppable_sender_of =
333+
any_sender_of<set_value_t(Vals)..., set_error_t(std::exception_ptr), set_stopped_t()>;
334+
335+
TEST_CASE("any_sender uses overload rules for completion signatures", "[types][any_sender]") {
336+
auto split_sender = split(just(42));
337+
static_assert(sender_of<decltype(split_sender), set_error_t(const std::exception_ptr&)>);
338+
static_assert(sender_of<decltype(split_sender), set_value_t(const int&)>);
339+
my_stoppable_sender_of<int> sender = split_sender;
340+
341+
auto [value] = *sync_wait(std::move(sender));
342+
CHECK(value == 42);
343+
344+
sender = just(21) | then([&](int) -> int { throw 420; });
345+
CHECK_THROWS_AS(sync_wait(std::move(sender)), int);
346+
}
347+
333348
class stopped_token {
334349
private:
335350
bool stopped_{true};

0 commit comments

Comments
 (0)