diff --git a/cpp/include/rapidsmpf/memory/memory_type.hpp b/cpp/include/rapidsmpf/memory/memory_type.hpp index 485c63f75..01b75d9a6 100644 --- a/cpp/include/rapidsmpf/memory/memory_type.hpp +++ b/cpp/include/rapidsmpf/memory/memory_type.hpp @@ -6,6 +6,8 @@ #include #include +#include +#include namespace rapidsmpf { @@ -38,6 +40,31 @@ constexpr std::array SPILL_TARGET_MEMORY_TYPES{ {MemoryType::PINNED_HOST, MemoryType::HOST} }; +/** + * @brief Get the memory types with preference lower than or equal to @p mem_type. + * + * The returned span reflects the predefined ordering used in \c MEMORY_TYPES, + * which lists memory types in decreasing order of preference. + * + * @param mem_type The memory type used as the starting point. + * @return A span of memory types whose preference is lower than or equal to + * the given type. + */ +constexpr std::span leq_memory_types(MemoryType mem_type) noexcept { + return std::views::drop_while(MEMORY_TYPES, [&](MemoryType const& mt) { + return mt != mem_type; + }); +} + +static_assert(std::ranges::equal(leq_memory_types(MemoryType::DEVICE), MEMORY_TYPES)); +static_assert(std::ranges::equal( + leq_memory_types(MemoryType::HOST), std::ranges::single_view{MemoryType::HOST} +)); +// unknown memory type should return an empty view +static_assert(std::ranges::equal( + leq_memory_types(static_cast(-1)), std::ranges::empty_view{} +)); + /** * @brief Get the name of a MemoryType. * diff --git a/cpp/include/rapidsmpf/streaming/core/context.hpp b/cpp/include/rapidsmpf/streaming/core/context.hpp index 8c4fb914a..710c49e7c 100644 --- a/cpp/include/rapidsmpf/streaming/core/context.hpp +++ b/cpp/include/rapidsmpf/streaming/core/context.hpp @@ -138,6 +138,13 @@ class Context { std::size_t buffer_size ) const noexcept; + /** + * @brief Returns the options. + * + * @return The Options instance. + */ + [[nodiscard]] config::Options const& options() const noexcept; + private: config::Options options_; std::shared_ptr comm_; diff --git a/cpp/include/rapidsmpf/streaming/core/spillable_messages.hpp b/cpp/include/rapidsmpf/streaming/core/spillable_messages.hpp index ade43c088..ccc0ee371 100644 --- a/cpp/include/rapidsmpf/streaming/core/spillable_messages.hpp +++ b/cpp/include/rapidsmpf/streaming/core/spillable_messages.hpp @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -122,6 +123,16 @@ class SpillableMessages { */ std::map get_content_descriptions() const; + /** + * @brief Get the content description of a message by ID. + * + * @param mid Message identifier. + * @return Content description of the message. + * + * @throws std::out_of_range If the message does not exist. + */ + ContentDescription get_content_description(MessageId mid) const; + private: /** * @brief Thread-safe item containing a `Message`. diff --git a/cpp/src/streaming/core/context.cpp b/cpp/src/streaming/core/context.cpp index 1714ad011..fcda3b4cd 100644 --- a/cpp/src/streaming/core/context.cpp +++ b/cpp/src/streaming/core/context.cpp @@ -156,6 +156,10 @@ std::shared_ptr Context::create_bounded_queue( return std::shared_ptr(new BoundedQueue(buffer_size)); } +config::Options const& Context::options() const noexcept { + return options_; +} + std::shared_ptr Context::spillable_messages() const noexcept { return spillable_messages_; } diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index e28fc1755..e5f58a496 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -13,32 +14,13 @@ #include #include #include +#include #include namespace rapidsmpf::streaming::node { namespace { -/** - * @brief Returns the memory types to consider when allocating an output message. - * - * The returned view begins at the principal memory type of the input message - * and continues through the remaining types in `MEMORY_TYPES` in order of - * preference. This ensures we never allocate in a higher memory tier than the - * message's principal type. For example, if a message has been spilled and its - * principal type is `HOST`, only `HOST` will be considered. - * - * @param msg The message whose content determines the memory type order. - * - * @return A view of memory types to try for allocation, starting at the - * principal memory type. - */ -constexpr std::span get_output_memory_types(Message const& msg) { - auto const principal = msg.content_description().principal_memory_type(); - return MEMORY_TYPES - | std::views::drop_while([principal](MemoryType m) { return m != principal; }); -} - /** * @brief Asynchronously send a message to multiple output channels. * @@ -56,7 +38,9 @@ Node send_to_channels( size_t msg_sz_, Channel& ch_) -> coro::task { co_await ctx_.executor()->schedule(); - auto res = ctx_.br()->reserve_or_fail(msg_sz_, get_output_memory_types(msg_)); + auto const& cd = msg_.content_description(); + auto const mem_types = leq_memory_types(cd.principal_memory_type()); + auto res = ctx_.br()->reserve_or_fail(msg_sz_, mem_types); co_return co_await ch_.send(msg_.copy(res)); }; @@ -196,7 +180,8 @@ struct UnboundedFanout { * @brief Send messages to multiple output channels. * * @param ctx The context to use. - * @param self_next_idx Next index to send for the current channel + * @param self_next_idx Next index to send for the current channel (passed by ref + * because it needs to be updated) * @param ch_out The output channel to send messages to. * @return A coroutine representing the task. */ @@ -207,35 +192,40 @@ struct UnboundedFanout { }; co_await ctx.executor()->schedule(); + auto spillable_messages = ctx.spillable_messages(); + size_t n_available_messages = 0; - std::vector> messages_to_send; + std::vector msg_ids_to_send; while (true) { { auto lock = co_await mtx.scoped_lock(); co_await data_ready.wait(lock, [&] { // irrespective of no_more_input, update the end_idx to the total // number of messages - n_available_messages = recv_messages.size(); + n_available_messages = recv_msg_ids.size(); return no_more_input || self_next_idx < n_available_messages; }); if (no_more_input && self_next_idx == n_available_messages) { // no more messages will be received, and all messages have been sent break; } - // stash msg references under the lock - messages_to_send.reserve(n_available_messages - self_next_idx); - for (size_t i = self_next_idx; i < n_available_messages; i++) { - messages_to_send.emplace_back(recv_messages[i]); - } + // copy msg ids to send under the lock + msg_ids_to_send.reserve(n_available_messages - self_next_idx); + std::ranges::copy( + std::ranges::drop_view( + recv_msg_ids, static_cast(self_next_idx) + ), + std::back_inserter(msg_ids_to_send) + ); } - for (auto const& msg : messages_to_send) { - RAPIDSMPF_EXPECTS(!msg.get().empty(), "message cannot be empty"); - - auto res = ctx.br()->reserve_or_fail( - msg.get().copy_cost(), get_output_memory_types(msg.get()) - ); - if (!co_await ch_out->send(msg.get().copy(res))) { + for (auto const msg_id : msg_ids_to_send) { + auto const cd = spillable_messages->get_content_description(msg_id); + // Reserve memory for the output using the input message's memory type, or + // a lower-priority type if needed. + auto const mem_types = leq_memory_types(cd.principal_memory_type()); + auto res = ctx.br()->reserve_or_fail(cd.content_size(), mem_types); + if (!co_await ch_out->send(spillable_messages->copy(msg_id, res))) { // Failed to send message. Could be that the channel is shut down. // So we need to abort the send task, and notify the process input // task @@ -243,13 +233,13 @@ struct UnboundedFanout { co_return; } } - messages_to_send.clear(); + msg_ids_to_send.clear(); // now next_idx can be updated to end_idx, and if !no_more_input, we need to // request the recv task for more data auto lock = co_await mtx.scoped_lock(); self_next_idx = n_available_messages; - if (self_next_idx == recv_messages.size()) { + if (self_next_idx == recv_msg_ids.size()) { if (no_more_input) { // no more messages will be received, and all messages have been sent break; @@ -313,7 +303,7 @@ struct UnboundedFanout { per_ch_processed_min = *min_it; per_ch_processed_max = *max_it; - return per_ch_processed_max == recv_messages.size(); + return per_ch_processed_max == recv_msg_ids.size(); }); co_return std::make_pair(per_ch_processed_min, per_ch_processed_max); @@ -334,6 +324,10 @@ struct UnboundedFanout { // index of the first message to purge size_t purge_idx = 0; + // To make staged input messages spillable, we insert them into the Context's + // spillable_messages container while they are in transit. + auto spillable_messages = ctx.spillable_messages(); + // no_more_input is only set by this task, so reading without lock is safe here while (!no_more_input) { auto [per_ch_processed_min, per_ch_processed_max] = @@ -351,7 +345,7 @@ struct UnboundedFanout { if (msg.empty()) { no_more_input = true; } else { - recv_messages.emplace_back(std::move(msg)); + recv_msg_ids.emplace_back(spillable_messages->insert(std::move(msg))); } } @@ -362,7 +356,7 @@ struct UnboundedFanout { // However the deque is not resized. This guarantees that the indices are not // invalidated. while (purge_idx < per_ch_processed_min) { - recv_messages[purge_idx].reset(); + std::ignore = spillable_messages->extract(recv_msg_ids[purge_idx]); purge_idx++; } } @@ -383,7 +377,7 @@ struct UnboundedFanout { /// @brief messages received from the input channel. Using a deque to avoid /// invalidating references by reallocations. - std::deque recv_messages; + std::deque recv_msg_ids; /// @brief number of messages processed for each channel (ie. next index to send for /// each channel) diff --git a/cpp/src/streaming/core/spillable_messages.cpp b/cpp/src/streaming/core/spillable_messages.cpp index 39ad9efc4..1fc6855ee 100644 --- a/cpp/src/streaming/core/spillable_messages.cpp +++ b/cpp/src/streaming/core/spillable_messages.cpp @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - #include namespace rapidsmpf::streaming { @@ -98,4 +97,18 @@ SpillableMessages::get_content_descriptions() const { std::unique_lock global_lock(global_mutex_); return content_descriptions_; } + +ContentDescription rapidsmpf::streaming::SpillableMessages::get_content_description( + MessageId mid +) const { + std::lock_guard global_lock(global_mutex_); + auto it = content_descriptions_.find(mid); + RAPIDSMPF_EXPECTS( + it != content_descriptions_.end(), + "message not found " + std::to_string(mid), + std::out_of_range + ); + return it->second; +} + } // namespace rapidsmpf::streaming diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index 74a2420be..f2d2d8e03 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -47,6 +48,47 @@ std::vector make_int_inputs(int n) { return inputs; } +/** + * @brief Helper to make a sequence of Messages where each buffer contains 1024 + * values of int [i, i + 1024). + */ +std::vector make_buffer_inputs(int n, rapidsmpf::BufferResource& br) { + std::vector inputs; + inputs.reserve(n); + + Message::CopyCallback copy_cb = [&](Message const& msg, MemoryReservation& res) { + rmm::cuda_stream_view stream = br.stream_pool().get_stream(); + auto const cd = msg.content_description(); + auto buf_cpy = br.allocate(cd.content_size(), stream, res); + // cd needs to be updated to reflect the new buffer + ContentDescription new_cd{ + {{buf_cpy->mem_type(), buf_cpy->size}}, ContentDescription::Spillable::YES + }; + rapidsmpf::buffer_copy(*buf_cpy, msg.get(), cd.content_size()); + return Message{ + msg.sequence_number(), std::move(buf_cpy), std::move(new_cd), msg.copy_cb() + }; + }; + for (int i = 0; i < n; ++i) { + std::vector values(1024, 0); + std::iota(values.begin(), values.end(), i); + rmm::cuda_stream_view stream = br.stream_pool().get_stream(); + // allocate outside of buffer resource + auto buffer = br.move( + std::make_unique( + values.data(), values.size() * sizeof(int), stream + ), + stream + ); + ContentDescription cd{ + std::ranges::single_view{std::pair{MemoryType::DEVICE, 1024 * sizeof(int)}}, + ContentDescription::Spillable::YES + }; + inputs.emplace_back(i, std::move(buffer), cd, copy_cb); + } + return inputs; +} + std::string policy_to_string(FanoutPolicy policy) { switch (policy) { case FanoutPolicy::BOUNDED: @@ -139,7 +181,53 @@ TEST_P(StreamingFanout, SinkPerChannel) { // object for (int i = 0; i < num_msgs; ++i) { SCOPED_TRACE("channel " + std::to_string(c) + " idx " + std::to_string(i)); - EXPECT_EQ(outs[c][i].get(), i); + EXPECT_EQ(i, outs[c][i].get()); + } + } +} + +TEST_P(StreamingFanout, SinkPerChannel_Buffer) { + auto inputs = make_buffer_inputs(num_msgs, *ctx->br()); + + std::vector> outs(num_out_chs); + { + std::vector nodes; + + auto in = ctx->create_channel(); + nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } + + nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); + + for (int i = 0; i < num_out_chs; ++i) { + nodes.emplace_back(node::pull_from_channel(ctx, out_chs[i], outs[i])); + } + + run_streaming_pipeline(std::move(nodes)); + } + + for (int c = 0; c < num_out_chs; ++c) { + // Validate sizes + EXPECT_EQ(outs[c].size(), static_cast(num_msgs)); + + // Validate ordering/content and that shallow copies share the same underlying + // object + for (int i = 0; i < num_msgs; ++i) { + SCOPED_TRACE("channel " + std::to_string(c) + " idx " + std::to_string(i)); + auto const& buf = outs[c][i].get(); + EXPECT_EQ(1024 * sizeof(int), buf.size); + + std::vector recv(1024); + buf.stream().synchronize(); + RAPIDSMPF_CUDA_TRY( + cudaMemcpy(recv.data(), buf.data(), 1024 * sizeof(int), cudaMemcpyDefault) + ); + + EXPECT_TRUE(std::ranges::equal(std::ranges::views::iota(i, i + 1024), recv)); } } } @@ -447,3 +535,58 @@ TEST_P(ManyInputSinkStreamingFanout, ChannelOrder) { TEST_P(ManyInputSinkStreamingFanout, MessageOrder) { EXPECT_NO_FATAL_FAILURE(run(ConsumePolicy::MESSAGE_ORDER)); } + +class SpillingStreamingFanout : public BaseStreamingFixture { + void SetUp() override { + SetUpWithThreads(4); + + // override br and context with no device memory + std::unordered_map memory_available = + { + {MemoryType::DEVICE, []() -> std::int64_t { return 0; }}, + }; + br = std::make_shared( + mr_cuda, rapidsmpf::PinnedMemoryResource::Disabled, memory_available + ); + auto options = ctx->options(); + ctx = std::make_shared( + options, GlobalEnvironment->comm_, br + ); + } +}; + +TEST_F(SpillingStreamingFanout, Spilling) { + auto inputs = make_buffer_inputs(100, *ctx->br()); + constexpr int num_out_chs = 4; + constexpr FanoutPolicy policy = FanoutPolicy::UNBOUNDED; + + std::vector> outs(num_out_chs); + { + std::vector nodes; + + auto in = ctx->create_channel(); + nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs))); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } + + nodes.push_back(node::fanout(ctx, in, out_chs, policy)); + nodes.push_back( + many_input_sink(ctx, out_chs, ConsumePolicy::CHANNEL_ORDER, outs) + ); + + run_streaming_pipeline(std::move(nodes)); + } + + for (int c = 0; c < num_out_chs; ++c) { + SCOPED_TRACE("channel " + std::to_string(c)); + // all messages should be in host memory + EXPECT_TRUE(std::ranges::all_of(outs[c], [](const Message& m) { + auto const cd = m.content_description(); + return cd.principal_memory_type() == MemoryType::HOST + && cd.content_size() == 1024 * sizeof(int); + })); + } +}