Skip to content

Commit b883077

Browse files
authored
Merge pull request NVIDIA#1740 from ericniebler/fix-multi-gpu-stream-scheduler-get-completion-scheduler-query
fix code that was C-style casting a multi_`gpu_stream_scheduler` to a `stream_scheduler`
2 parents aee092f + 74589f0 commit b883077

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

include/nvexec/multi_gpu_context.cuh

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ STDEXEC_PRAGMA_IGNORE_EDG(cuda_compile)
2727

2828
namespace nvexec {
2929
namespace _strm {
30-
struct multi_gpu_stream_scheduler : private stream_scheduler_env {
30+
struct multi_gpu_stream_scheduler : private stream_scheduler_env<multi_gpu_stream_scheduler> {
3131
using __t = multi_gpu_stream_scheduler;
3232
using __id = multi_gpu_stream_scheduler;
3333

@@ -138,6 +138,13 @@ namespace nvexec {
138138
int num_devices_{};
139139
context_state_t context_state_;
140140
};
141+
142+
template <>
143+
STDEXEC_ATTRIBUTE(nodiscard)
144+
inline auto stream_scheduler_env<multi_gpu_stream_scheduler>::query(
145+
get_completion_scheduler_t<set_value_t>) const noexcept -> multi_gpu_stream_scheduler {
146+
return stdexec::__c_downcast<multi_gpu_stream_scheduler>(*this);
147+
}
141148
} // namespace _strm
142149

143150
using _strm::multi_gpu_stream_scheduler;

include/nvexec/stream_context.cuh

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,22 +45,23 @@ namespace nvexec {
4545
namespace _strm {
4646
struct stream_scheduler;
4747

48-
struct stream_scheduler_env {
48+
template <class StreamScheduler>
49+
struct stream_scheduler_env { // NOLINT(bugprone-crtp-constructor-accessibility)
4950
STDEXEC_ATTRIBUTE(nodiscard)
5051
static auto query(get_forward_progress_guarantee_t) noexcept -> forward_progress_guarantee {
5152
return forward_progress_guarantee::weakly_parallel;
5253
}
5354

5455
STDEXEC_ATTRIBUTE(nodiscard)
55-
auto query(get_completion_scheduler_t<set_value_t>) const noexcept -> stream_scheduler;
56+
auto query(get_completion_scheduler_t<set_value_t>) const noexcept -> StreamScheduler;
5657

5758
STDEXEC_ATTRIBUTE(nodiscard)
5859
constexpr auto query(get_completion_domain_t<set_value_t>) const noexcept -> stream_domain {
5960
return {};
6061
}
6162
};
6263

63-
struct stream_scheduler : private stream_scheduler_env {
64+
struct stream_scheduler : private stream_scheduler_env<stream_scheduler> {
6465
using __t = stream_scheduler;
6566
using __id = stream_scheduler;
6667

@@ -153,10 +154,11 @@ namespace nvexec {
153154
context_state_t context_state_;
154155
};
155156

157+
template <>
156158
STDEXEC_ATTRIBUTE(nodiscard)
157-
inline auto stream_scheduler_env::query(get_completion_scheduler_t<set_value_t>) const noexcept
158-
-> stream_scheduler {
159-
return (const stream_scheduler&) *this;
159+
inline auto stream_scheduler_env<stream_scheduler>::query(
160+
get_completion_scheduler_t<set_value_t>) const noexcept -> stream_scheduler {
161+
return stdexec::__c_downcast<stream_scheduler>(*this);
160162
}
161163
} // namespace _strm
162164

0 commit comments

Comments
 (0)