Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
3876116
adding spillable messages
nirandaperera Dec 4, 2025
dc13cdf
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 4, 2025
d4e9a1b
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 5, 2025
7163773
adding test
nirandaperera Dec 5, 2025
dd0d345
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 5, 2025
ca10c57
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 8, 2025
363b61e
new API
nirandaperera Dec 8, 2025
feb04a8
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 8, 2025
72e0be3
merge conflict
nirandaperera Dec 8, 2025
5606f6a
enabling extracting message ID
nirandaperera Dec 9, 2025
ea830ac
minor changes
nirandaperera Dec 9, 2025
f9947c1
Merge branch 'main' into Make-unbounded-fanout-state-spillable
nirandaperera Dec 9, 2025
0d1cc9a
Update cpp/include/rapidsmpf/memory/memory_type.hpp
nirandaperera Dec 9, 2025
d5db515
doxygen error
nirandaperera Dec 9, 2025
61ba76e
Revert "enabling extracting message ID"
nirandaperera Dec 10, 2025
090dfde
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 10, 2025
4488ded
Merge branch 'Make-unbounded-fanout-state-spillable' of github.com:ni…
nirandaperera Dec 10, 2025
a2ad554
remove cout
nirandaperera Dec 10, 2025
e3958a2
Apply suggestions from code review
nirandaperera Dec 11, 2025
5dc7708
addressing comments
nirandaperera Dec 11, 2025
b4c6cac
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 11, 2025
55c8114
API change
nirandaperera Dec 11, 2025
d0e6a50
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 11, 2025
6298b2e
precommit
nirandaperera Dec 11, 2025
a8c232e
minro change
nirandaperera Dec 11, 2025
c4d4a39
Merge branch 'main' into Make-unbounded-fanout-state-spillable
nirandaperera Dec 15, 2025
1373521
fix build
nirandaperera Dec 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cpp/include/rapidsmpf/memory/memory_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include <array>
#include <ostream>
#include <ranges>
#include <span>

namespace rapidsmpf {

Expand Down Expand Up @@ -33,6 +35,27 @@ constexpr std::array<char const*, MEMORY_TYPES.size()> MEMORY_TYPE_NAMES{
*/
constexpr std::array<MemoryType, 1> SPILL_TARGET_MEMORY_TYPES{{MemoryType::HOST}};

/**
* @brief Get the lower memory types than or equal to the @p mem_type .
*
* @param mem_type The memory type.
* @return A span of the lower memory types than the given memory type.
*/
constexpr std::span<MemoryType const> leq_memory_types(MemoryType mem_type) noexcept {
return std::views::drop_while(MEMORY_TYPES, [&](MemoryType const& mt) {
return mt != mem_type;
});
}

static_assert(std::ranges::equal(leq_memory_types(MemoryType::DEVICE), MEMORY_TYPES));
static_assert(std::ranges::equal(
leq_memory_types(MemoryType::HOST), std::ranges::single_view{MemoryType::HOST}
));
// unknown memory type should return an empty view
static_assert(std::ranges::equal(
leq_memory_types(static_cast<MemoryType>(-1)), std::ranges::empty_view<MemoryType>{}
));
Comment on lines +53 to +66
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not immediately just have:

template<typename Pred>
constexpr std::span<MemoryType const> filter_types(MemoryType mem_type, Pred pred) noexcept {
    return std::views::filter(MEMORY_TYPES, pred);
}

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would return a Range isnt it? But yes, sure.


/**
* @brief Get the name of a MemoryType.
*
Expand Down
15 changes: 13 additions & 2 deletions cpp/include/rapidsmpf/streaming/core/channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,22 @@ class Channel {
*
* @return A coroutine that evaluates to the message, which will be empty if the
* channel is shut down.
*
* @throws std::logic_error If the received message is empty.
*/
coro::task<Message> receive();

/**
* @brief Asynchronously receive a message id from the channel.
*
* Suspends if the channel is empty. Once the message id is received, the message can
* be extracted using the `SpillableMessages::extract` method. This could be useful
* when a node wants to consume a message from a channel, but leave it in the
* spillable messages container for later extraction.
*
* @return A coroutine that evaluates to the message id. If the channel is shut down,
* the message id will be `SpillableMessages::InvalidMessageId`.
*/
coro::task<SpillableMessages::MessageId> receive_message_id();

/**
* @brief Drains all pending messages from the channel and shuts it down.
*
Expand Down
7 changes: 7 additions & 0 deletions cpp/include/rapidsmpf/streaming/core/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ class Context {
std::size_t buffer_size
) const noexcept;

/**
* @brief Returns the options.
*
* @return The Options instance.
*/
[[nodiscard]] config::Options const& options() const noexcept;

private:
config::Options options_;
std::shared_ptr<Communicator> comm_;
Expand Down
14 changes: 14 additions & 0 deletions cpp/include/rapidsmpf/streaming/core/spillable_messages.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <optional>
#include <unordered_map>

#include <rapidsmpf/memory/buffer_resource.hpp>
#include <rapidsmpf/memory/content_description.hpp>
#include <rapidsmpf/streaming/core/message.hpp>

Expand Down Expand Up @@ -40,6 +41,9 @@ class SpillableMessages {
/// @brief Unique identifier assigned to each message.
using MessageId = std::uint64_t;

/// @brief Invalid message identifier.
static constexpr MessageId InvalidMessageId = std::numeric_limits<MessageId>::max();

SpillableMessages() = default;
SpillableMessages(SpillableMessages const&) = delete;
SpillableMessages& operator=(SpillableMessages const&) = delete;
Expand Down Expand Up @@ -122,6 +126,16 @@ class SpillableMessages {
*/
std::map<MessageId, ContentDescription> get_content_descriptions() const;

/**
* @brief Get the content description of a message by ID.
*
* @param mid Message identifier.
* @return Content description of the message.
*
* @throws std::out_of_range If the message does not exist.
*/
ContentDescription get_content_description(MessageId mid) const;

private:
/**
* @brief Thread-safe item containing a `Message`.
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/streaming/core/channel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ coro::task<Message> Channel::receive() {
}
}

coro::task<SpillableMessages::MessageId> Channel::receive_message_id() {
auto msg_id = co_await rb_.consume();
co_return msg_id.has_value() ? *msg_id : SpillableMessages::InvalidMessageId;
}

Node Channel::drain(std::unique_ptr<coro::thread_pool>& executor) {
return rb_.shutdown_drain(executor);
}
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/streaming/core/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ std::shared_ptr<BoundedQueue> Context::create_bounded_queue(
return std::shared_ptr<BoundedQueue>(new BoundedQueue(buffer_size));
}

config::Options const& Context::options() const noexcept {
return options_;
}

std::shared_ptr<SpillableMessages> Context::spillable_messages() const noexcept {
return spillable_messages_;
}
Expand Down
55 changes: 21 additions & 34 deletions cpp/src/streaming/core/fanout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,13 @@
#include <rapidsmpf/streaming/core/fanout.hpp>
#include <rapidsmpf/streaming/core/message.hpp>
#include <rapidsmpf/streaming/core/node.hpp>
#include <rapidsmpf/streaming/core/spillable_messages.hpp>

#include <coro/coro.hpp>

namespace rapidsmpf::streaming::node {
namespace {

/**
* @brief Returns the memory types to consider when allocating an output message.
*
* The returned view begins at the principal memory type of the input message
* and continues through the remaining types in `MEMORY_TYPES` in order of
* preference. This ensures we never allocate in a higher memory tier than the
* message's principal type. For example, if a message has been spilled and its
* principal type is `HOST`, only `HOST` will be considered.
*
* @param msg The message whose content determines the memory type order.
*
* @return A view of memory types to try for allocation, starting at the
* principal memory type.
*/
constexpr std::span<MemoryType const> get_output_memory_types(Message const& msg) {
auto const principal = msg.content_description().principal_memory_type();
return MEMORY_TYPES
| std::views::drop_while([principal](MemoryType m) { return m != principal; });
}

/**
* @brief Asynchronously send a message to multiple output channels.
*
Expand All @@ -56,7 +37,9 @@ Node send_to_channels(
size_t msg_sz_,
Channel& ch_) -> coro::task<bool> {
co_await ctx_.executor()->schedule();
auto res = ctx_.br()->reserve_or_fail(msg_sz_, get_output_memory_types(msg_));
auto const& cd = msg_.content_description();
auto const mem_types = leq_memory_types(cd.principal_memory_type());
auto res = ctx_.br()->reserve_or_fail(msg_sz_, mem_types);
co_return co_await ch_.send(msg_.copy(res));
};

Expand Down Expand Up @@ -207,8 +190,11 @@ struct UnboundedFanout {
};
co_await ctx.executor()->schedule();

auto spillable_messages = ctx.spillable_messages();

size_t n_available_messages = 0;
std::vector<std::reference_wrapper<Message>> messages_to_send;
std::vector<std::reference_wrapper<SpillableMessages::MessageId>>
messages_to_send;
while (true) {
{
auto lock = co_await mtx.scoped_lock();
Expand All @@ -230,12 +216,11 @@ struct UnboundedFanout {
}

for (auto const& msg : messages_to_send) {
RAPIDSMPF_EXPECTS(!msg.get().empty(), "message cannot be empty");

auto res = ctx.br()->reserve_or_fail(
msg.get().copy_cost(), get_output_memory_types(msg.get())
);
if (!co_await ch_out->send(msg.get().copy(res))) {
auto const& cd = spillable_messages->get_content_description(msg);
// try reserving into all memory types up to the highest memory type set
auto const mem_types = leq_memory_types(cd.principal_memory_type());
auto res = ctx.br()->reserve_or_fail(cd.content_size(), mem_types);
if (!co_await ch_out->send(spillable_messages->copy(msg, res))) {
// Failed to send message. Could be that the channel is shut down.
// So we need to abort the send task, and notify the process input
// task
Expand Down Expand Up @@ -334,6 +319,8 @@ struct UnboundedFanout {
// index of the first message to purge
size_t purge_idx = 0;

auto spillable_messages = ctx.spillable_messages();

// no_more_input is only set by this task, so reading without lock is safe here
while (!no_more_input) {
auto [per_ch_processed_min, per_ch_processed_max] =
Expand All @@ -343,15 +330,15 @@ struct UnboundedFanout {
break;
}

// receive a message from the input channel
auto msg = co_await ch_in->receive();
// receive a message id from the input channel
auto msg_id = co_await ch_in->receive_message_id();

{
auto lock = co_await mtx.scoped_lock();
if (msg.empty()) {
if (msg_id == SpillableMessages::InvalidMessageId) {
no_more_input = true;
} else {
recv_messages.emplace_back(std::move(msg));
recv_messages.emplace_back(msg_id);
}
}

Expand All @@ -362,7 +349,7 @@ struct UnboundedFanout {
// However the deque is not resized. This guarantees that the indices are not
// invalidated.
while (purge_idx < per_ch_processed_min) {
recv_messages[purge_idx].reset();
std::ignore = spillable_messages->extract(recv_messages[purge_idx]);
purge_idx++;
}
}
Expand All @@ -383,7 +370,7 @@ struct UnboundedFanout {

/// @brief messages received from the input channel. Using a deque to avoid
/// invalidating references by reallocations.
std::deque<Message> recv_messages;
std::deque<SpillableMessages::MessageId> recv_messages;

/// @brief number of messages processed for each channel (ie. next index to send for
/// each channel)
Expand Down
15 changes: 14 additions & 1 deletion cpp/src/streaming/core/spillable_messages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
* SPDX-License-Identifier: Apache-2.0
*/


#include <rapidsmpf/streaming/core/spillable_messages.hpp>

namespace rapidsmpf::streaming {
Expand Down Expand Up @@ -98,4 +97,18 @@ SpillableMessages::get_content_descriptions() const {
std::unique_lock global_lock(global_mutex_);
return content_descriptions_;
}

ContentDescription rapidsmpf::streaming::SpillableMessages::get_content_description(
MessageId mid
) const {
std::lock_guard global_lock(global_mutex_);
auto it = content_descriptions_.find(mid);
RAPIDSMPF_EXPECTS(
it != content_descriptions_.end(),
"message not found " + std::to_string(mid),
std::out_of_range
);
return it->second;
}

} // namespace rapidsmpf::streaming
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ if(RAPIDSMPF_HAVE_STREAMING)
target_sources(
test_sources
INTERFACE streaming/test_allgather.cpp
streaming/test_channel.cpp
streaming/test_error_handling.cpp
streaming/test_fanout.cpp
streaming/test_leaf_node.cpp
Expand Down
75 changes: 75 additions & 0 deletions cpp/tests/streaming/test_channel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: Apache-2.0
*/

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <rapidsmpf/streaming/core/channel.hpp>
#include <rapidsmpf/streaming/core/context.hpp>
#include <rapidsmpf/streaming/core/node.hpp>

#include "base_streaming_fixture.hpp"

using namespace rapidsmpf;
using namespace rapidsmpf::streaming;

using StreamingChannel = BaseStreamingFixture;

TEST_F(StreamingChannel, ReceiveMessageId) {
constexpr int num_messages = 10;
auto ch = ctx->create_channel();
std::vector<Node> nodes;

nodes.push_back([](auto ctx, auto ch_out) -> Node {
ShutdownAtExit c{ch_out};
co_await ctx->executor()->schedule();
for (int i = 0; i < num_messages; ++i) {
co_await ch_out->send(
Message{
static_cast<uint64_t>(i),
std::make_unique<int>(i * 10),
ContentDescription{},
[](Message const& /* msg */, MemoryReservation& /* res */)
-> Message { RAPIDSMPF_FAIL("should not be called"); }
}
);
}
co_await ch_out->drain(ctx->executor());
}(ctx, ch));

std::vector<int> recv_vals;
std::vector<uint64_t> recv_seq_nums;
nodes.push_back(
[](auto ctx,
auto ch_in,
std::vector<int>& values,
std::vector<uint64_t>& seq_nums) -> Node {
ShutdownAtExit c{ch_in};
co_await ctx->executor()->schedule();
while (true) {
auto msg_id = co_await ch_in->receive_message_id();
if (msg_id == SpillableMessages::InvalidMessageId) {
break;
}
auto msg = ctx->spillable_messages()->extract(msg_id);
seq_nums.push_back(msg.sequence_number());
values.push_back(msg.template get<int>());
}
}(ctx, ch, recv_vals, recv_seq_nums)
);

run_streaming_pipeline(std::move(nodes));

// Verify all messages were received correctly.
EXPECT_EQ(num_messages, recv_vals.size());
EXPECT_EQ(num_messages, recv_seq_nums.size());

std::ranges::sort(recv_seq_nums);
std::ranges::sort(recv_vals);
for (int i = 0; i < num_messages; ++i) {
EXPECT_EQ(static_cast<uint64_t>(i), recv_seq_nums[i]);
EXPECT_EQ(i * 10, recv_vals[i]);
}
}
Loading