Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
3876116
adding spillable messages
nirandaperera Dec 4, 2025
dc13cdf
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 4, 2025
d4e9a1b
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 5, 2025
7163773
adding test
nirandaperera Dec 5, 2025
dd0d345
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 5, 2025
ca10c57
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 8, 2025
363b61e
new API
nirandaperera Dec 8, 2025
feb04a8
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 8, 2025
72e0be3
merge conflict
nirandaperera Dec 8, 2025
5606f6a
enabling extracting message ID
nirandaperera Dec 9, 2025
ea830ac
minor changes
nirandaperera Dec 9, 2025
f9947c1
Merge branch 'main' into Make-unbounded-fanout-state-spillable
nirandaperera Dec 9, 2025
0d1cc9a
Update cpp/include/rapidsmpf/memory/memory_type.hpp
nirandaperera Dec 9, 2025
d5db515
doxygen error
nirandaperera Dec 9, 2025
61ba76e
Revert "enabling extracting message ID"
nirandaperera Dec 10, 2025
090dfde
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 10, 2025
4488ded
Merge branch 'Make-unbounded-fanout-state-spillable' of github.com:ni…
nirandaperera Dec 10, 2025
a2ad554
remove cout
nirandaperera Dec 10, 2025
e3958a2
Apply suggestions from code review
nirandaperera Dec 11, 2025
5dc7708
addressing comments
nirandaperera Dec 11, 2025
b4c6cac
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 11, 2025
55c8114
API change
nirandaperera Dec 11, 2025
d0e6a50
Merge branch 'main' of github.com:rapidsai/rapidsmpf into Make-unboun…
nirandaperera Dec 11, 2025
6298b2e
precommit
nirandaperera Dec 11, 2025
a8c232e
minor change
nirandaperera Dec 11, 2025
c4d4a39
Merge branch 'main' into Make-unbounded-fanout-state-spillable
nirandaperera Dec 15, 2025
1373521
fix build
nirandaperera Dec 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions cpp/include/rapidsmpf/memory/memory_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include <algorithm>
#include <array>
#include <ostream>
#include <ranges>
#include <span>

namespace rapidsmpf {

Expand Down Expand Up @@ -33,6 +35,31 @@ constexpr std::array<char const*, MEMORY_TYPES.size()> MEMORY_TYPE_NAMES{
*/
constexpr std::array<MemoryType, 1> SPILL_TARGET_MEMORY_TYPES{{MemoryType::HOST}};

/**
* @brief Get the memory types with preference lower than or equal to @p mem_type.
*
* The returned span reflects the predefined ordering used in \c MEMORY_TYPES,
* which lists memory types in decreasing order of preference.
*
* @param mem_type The memory type used as the starting point.
* @return A span of memory types whose preference is lower than or equal to
* the given type.
*/
/**
 * @brief Get the memory types with preference lower than or equal to @p mem_type.
 *
 * The returned span reflects the predefined ordering used in \c MEMORY_TYPES,
 * which lists memory types in decreasing order of preference.
 *
 * @param mem_type The memory type used as the starting point.
 * @return A span of memory types whose preference is lower than or equal to
 * the given type. Empty if @p mem_type is not a known memory type.
 */
constexpr std::span<MemoryType const> leq_memory_types(MemoryType mem_type) noexcept {
    // Capture by value: MemoryType is a cheap enum, so there is no reason to
    // hold a reference to the parameter inside the view's predicate.
    return MEMORY_TYPES | std::views::drop_while([mem_type](MemoryType const mt) {
        return mt != mem_type;
    });
}

static_assert(std::ranges::equal(leq_memory_types(MemoryType::DEVICE), MEMORY_TYPES));
static_assert(std::ranges::equal(
    leq_memory_types(MemoryType::HOST), std::ranges::single_view{MemoryType::HOST}
));
// unknown memory type should return an empty view
static_assert(std::ranges::equal(
    leq_memory_types(static_cast<MemoryType>(-1)), std::ranges::empty_view<MemoryType>{}
));
Comment on lines +53 to +66
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not immediately just have:

template<typename Pred>
constexpr std::span<MemoryType const> filter_types(MemoryType mem_type, Pred pred) noexcept {
    return std::views::filter(MEMORY_TYPES, pred);
}

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would return a Range, wouldn't it? But yes, sure.


/**
* @brief Get the name of a MemoryType.
*
Expand Down
7 changes: 7 additions & 0 deletions cpp/include/rapidsmpf/streaming/core/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ class Context {
std::size_t buffer_size
) const noexcept;

/**
* @brief Returns the options.
*
* @return The Options instance.
*/
[[nodiscard]] config::Options const& options() const noexcept;

private:
config::Options options_;
std::shared_ptr<Communicator> comm_;
Expand Down
11 changes: 11 additions & 0 deletions cpp/include/rapidsmpf/streaming/core/spillable_messages.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <optional>
#include <unordered_map>

#include <rapidsmpf/memory/buffer_resource.hpp>
#include <rapidsmpf/memory/content_description.hpp>
#include <rapidsmpf/streaming/core/message.hpp>

Expand Down Expand Up @@ -122,6 +123,16 @@ class SpillableMessages {
*/
std::map<MessageId, ContentDescription> get_content_descriptions() const;

/**
* @brief Get the content description of a message by ID.
*
* @param mid Message identifier.
* @return Content description of the message.
*
* @throws std::out_of_range If the message does not exist.
*/
ContentDescription get_content_description(MessageId mid) const;

private:
/**
* @brief Thread-safe item containing a `Message`.
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/streaming/core/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ std::shared_ptr<BoundedQueue> Context::create_bounded_queue(
return std::shared_ptr<BoundedQueue>(new BoundedQueue(buffer_size));
}

// Returns a const reference to the Options instance stored in this context; the
// reference stays valid for the lifetime of the Context.
config::Options const& Context::options() const noexcept {
    return options_;
}

std::shared_ptr<SpillableMessages> Context::spillable_messages() const noexcept {
return spillable_messages_;
}
Expand Down
51 changes: 20 additions & 31 deletions cpp/src/streaming/core/fanout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,13 @@
#include <rapidsmpf/streaming/core/fanout.hpp>
#include <rapidsmpf/streaming/core/message.hpp>
#include <rapidsmpf/streaming/core/node.hpp>
#include <rapidsmpf/streaming/core/spillable_messages.hpp>

#include <coro/coro.hpp>

namespace rapidsmpf::streaming::node {
namespace {

/**
* @brief Returns the memory types to consider when allocating an output message.
*
* The returned view begins at the principal memory type of the input message
* and continues through the remaining types in `MEMORY_TYPES` in order of
* preference. This ensures we never allocate in a higher memory tier than the
* message's principal type. For example, if a message has been spilled and its
* principal type is `HOST`, only `HOST` will be considered.
*
* @param msg The message whose content determines the memory type order.
*
* @return A view of memory types to try for allocation, starting at the
* principal memory type.
*/
constexpr std::span<MemoryType const> get_output_memory_types(Message const& msg) {
auto const principal = msg.content_description().principal_memory_type();
return MEMORY_TYPES
| std::views::drop_while([principal](MemoryType m) { return m != principal; });
}

/**
* @brief Asynchronously send a message to multiple output channels.
*
Expand All @@ -56,7 +37,9 @@ Node send_to_channels(
size_t msg_sz_,
Channel& ch_) -> coro::task<bool> {
co_await ctx_.executor()->schedule();
auto res = ctx_.br()->reserve_or_fail(msg_sz_, get_output_memory_types(msg_));
auto const& cd = msg_.content_description();
auto const mem_types = leq_memory_types(cd.principal_memory_type());
auto res = ctx_.br()->reserve_or_fail(msg_sz_, mem_types);
co_return co_await ch_.send(msg_.copy(res));
};

Expand Down Expand Up @@ -207,8 +190,11 @@ struct UnboundedFanout {
};
co_await ctx.executor()->schedule();

auto spillable_messages = ctx.spillable_messages();

size_t n_available_messages = 0;
std::vector<std::reference_wrapper<Message>> messages_to_send;
std::vector<std::reference_wrapper<SpillableMessages::MessageId>>
messages_to_send;
while (true) {
{
auto lock = co_await mtx.scoped_lock();
Expand All @@ -230,12 +216,11 @@ struct UnboundedFanout {
}

for (auto const& msg : messages_to_send) {
RAPIDSMPF_EXPECTS(!msg.get().empty(), "message cannot be empty");

auto res = ctx.br()->reserve_or_fail(
msg.get().copy_cost(), get_output_memory_types(msg.get())
);
if (!co_await ch_out->send(msg.get().copy(res))) {
auto const& cd = spillable_messages->get_content_description(msg);
// try reserving into all memory types up to the highest memory type set
auto const mem_types = leq_memory_types(cd.principal_memory_type());
auto res = ctx.br()->reserve_or_fail(cd.content_size(), mem_types);
if (!co_await ch_out->send(spillable_messages->copy(msg, res))) {
// Failed to send message. Could be that the channel is shut down.
// So we need to abort the send task, and notify the process input
// task
Expand Down Expand Up @@ -334,6 +319,8 @@ struct UnboundedFanout {
// index of the first message to purge
size_t purge_idx = 0;

auto spillable_messages = ctx.spillable_messages();

// no_more_input is only set by this task, so reading without lock is safe here
while (!no_more_input) {
auto [per_ch_processed_min, per_ch_processed_max] =
Expand All @@ -351,7 +338,9 @@ struct UnboundedFanout {
if (msg.empty()) {
no_more_input = true;
} else {
recv_messages.emplace_back(std::move(msg));
recv_messages.emplace_back(
spillable_messages->insert(std::move(msg))
);
}
}

Expand All @@ -362,7 +351,7 @@ struct UnboundedFanout {
// However the deque is not resized. This guarantees that the indices are not
// invalidated.
while (purge_idx < per_ch_processed_min) {
recv_messages[purge_idx].reset();
std::ignore = spillable_messages->extract(recv_messages[purge_idx]);
purge_idx++;
}
}
Expand All @@ -383,7 +372,7 @@ struct UnboundedFanout {

/// @brief messages received from the input channel. Using a deque to avoid
/// invalidating references by reallocations.
std::deque<Message> recv_messages;
std::deque<SpillableMessages::MessageId> recv_messages;

/// @brief number of messages processed for each channel (ie. next index to send for
/// each channel)
Expand Down
15 changes: 14 additions & 1 deletion cpp/src/streaming/core/spillable_messages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
* SPDX-License-Identifier: Apache-2.0
*/


#include <rapidsmpf/streaming/core/spillable_messages.hpp>

namespace rapidsmpf::streaming {
Expand Down Expand Up @@ -98,4 +97,18 @@ SpillableMessages::get_content_descriptions() const {
std::unique_lock global_lock(global_mutex_);
return content_descriptions_;
}

/**
 * @brief Get the content description of a message by ID.
 *
 * @param mid Message identifier.
 * @return Copy of the message's content description.
 *
 * @throws std::out_of_range If no message with @p mid exists.
 */
// Note: this file is already inside `namespace rapidsmpf::streaming`, so the
// class name alone is sufficient (and consistent with the other member
// definitions above, e.g. `SpillableMessages::get_content_descriptions`).
ContentDescription SpillableMessages::get_content_description(MessageId mid) const {
    std::lock_guard global_lock(global_mutex_);
    auto const it = content_descriptions_.find(mid);
    RAPIDSMPF_EXPECTS(
        it != content_descriptions_.end(),
        "message not found " + std::to_string(mid),
        std::out_of_range
    );
    return it->second;
}

} // namespace rapidsmpf::streaming
143 changes: 142 additions & 1 deletion cpp/tests/streaming/test_fanout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,47 @@ std::vector<Message> make_int_inputs(int n) {
return inputs;
}

/**
 * @brief Helper to make a sequence of Message<Buffer>s where each buffer contains 1024
 * values of int [i, i + 1024).
 *
 * @param n Number of messages to create.
 * @param br Buffer resource used by the embedded copy callback; must outlive the
 * returned messages, since the callback holds a reference to it.
 * @return Messages with sequence numbers 0..n-1, each owning a device buffer.
 */
std::vector<Message> make_buffer_inputs(int n, rapidsmpf::BufferResource& br) {
    std::vector<Message> inputs;
    inputs.reserve(n);

    // Capture `br` explicitly: this callback is stored inside every message and
    // may outlive this function, so the capture set should be visible at a glance.
    Message::CopyCallback copy_cb = [&br](Message const& msg, MemoryReservation& res) {
        rmm::cuda_stream_view stream = br.stream_pool().get_stream();
        auto const& cd = msg.content_description();
        auto buf_cpy = br.allocate(cd.content_size(), stream, res);
        // cd needs to be updated to reflect the new buffer
        ContentDescription new_cd{
            {{buf_cpy->mem_type(), buf_cpy->size}}, ContentDescription::Spillable::YES
        };
        rapidsmpf::buffer_copy(*buf_cpy, msg.get<Buffer>(), cd.content_size());
        return Message{
            msg.sequence_number(), std::move(buf_cpy), std::move(new_cd), msg.copy_cb()
        };
    };
    for (int i = 0; i < n; ++i) {
        std::vector<int> values(1024, 0);
        std::iota(values.begin(), values.end(), i);
        rmm::cuda_stream_view stream = br.stream_pool().get_stream();
        // allocate outside of buffer resource
        auto buffer = br.move(
            std::make_unique<rmm::device_buffer>(
                values.data(), values.size() * sizeof(int), stream
            ),
            stream
        );
        // The device_buffer ctor only enqueues the host-to-device copy on
        // `stream`; synchronize before `values` goes out of scope at the end of
        // this iteration so the copy cannot read freed host memory, instead of
        // relying on CUDA's pageable-memory staging behavior.
        stream.synchronize();
        ContentDescription cd{
            std::ranges::single_view{std::pair{MemoryType::DEVICE, 1024 * sizeof(int)}},
            ContentDescription::Spillable::YES
        };
        inputs.emplace_back(i, std::move(buffer), cd, copy_cb);
    }
    return inputs;
}

std::string policy_to_string(FanoutPolicy policy) {
switch (policy) {
case FanoutPolicy::BOUNDED:
Expand Down Expand Up @@ -139,7 +180,53 @@ TEST_P(StreamingFanout, SinkPerChannel) {
// object
for (int i = 0; i < num_msgs; ++i) {
SCOPED_TRACE("channel " + std::to_string(c) + " idx " + std::to_string(i));
EXPECT_EQ(outs[c][i].get<int>(), i);
EXPECT_EQ(i, outs[c][i].get<int>());
}
}
}

TEST_P(StreamingFanout, SinkPerChannel_Buffer) {
    // Build the input messages up front; ownership moves into the source node.
    auto msgs = make_buffer_inputs(num_msgs, *ctx->br());

    std::vector<std::vector<Message>> received(num_out_chs);
    {
        std::vector<Node> pipeline;

        auto input_ch = ctx->create_channel();
        pipeline.emplace_back(node::push_to_channel(ctx, input_ch, std::move(msgs)));

        std::vector<std::shared_ptr<Channel>> output_chs;
        output_chs.reserve(num_out_chs);
        for (int ch = 0; ch < num_out_chs; ++ch) {
            output_chs.emplace_back(ctx->create_channel());
        }

        pipeline.emplace_back(node::fanout(ctx, input_ch, output_chs, policy));

        for (int ch = 0; ch < num_out_chs; ++ch) {
            pipeline.emplace_back(
                node::pull_from_channel(ctx, output_chs[ch], received[ch])
            );
        }

        run_streaming_pipeline(std::move(pipeline));
    }

    for (int ch = 0; ch < num_out_chs; ++ch) {
        // Every channel must observe every message.
        EXPECT_EQ(received[ch].size(), static_cast<size_t>(num_msgs));

        // Validate ordering and buffer contents per message.
        for (int idx = 0; idx < num_msgs; ++idx) {
            SCOPED_TRACE(
                "channel " + std::to_string(ch) + " idx " + std::to_string(idx)
            );
            auto const& payload = received[ch][idx].get<Buffer>();
            EXPECT_EQ(1024 * sizeof(int), payload.size);

            // Pull the device data back to host before inspecting it.
            std::vector<int> host_vals(1024);
            payload.stream().synchronize();
            RAPIDSMPF_CUDA_TRY(cudaMemcpy(
                host_vals.data(), payload.data(), 1024 * sizeof(int), cudaMemcpyDefault
            ));

            EXPECT_TRUE(std::ranges::equal(
                std::ranges::views::iota(idx, idx + 1024), host_vals
            ));
        }
    }
}
Expand Down Expand Up @@ -447,3 +534,57 @@ TEST_P(ManyInputSinkStreamingFanout, ChannelOrder) {
TEST_P(ManyInputSinkStreamingFanout, MessageOrder) {
EXPECT_NO_FATAL_FAILURE(run(ConsumePolicy::MESSAGE_ORDER));
}

// Fixture whose buffer resource reports zero available device memory, so
// spillable payloads are expected to end up in host memory.
class SpillingStreamingFanout : public BaseStreamingFixture {
    void SetUp() override {
        SetUpWithThreads(4);

        // override br and context with no device memory
        // (the DEVICE availability callback always reports 0 bytes free)
        std::unordered_map<MemoryType, BufferResource::MemoryAvailable> memory_available =
            {
                {MemoryType::DEVICE, []() -> std::int64_t { return 0; }},
            };
        br = std::make_shared<rapidsmpf::BufferResource>(mr_cuda, memory_available);
        // Copy (not reference) the options: `ctx` is replaced on the next line.
        // Rebuild the streaming context on top of the new buffer resource while
        // keeping the options from the context created by the base fixture.
        auto options = ctx->options();
        ctx = std::make_shared<rapidsmpf::streaming::Context>(
            options, GlobalEnvironment->comm_, br
        );
    }
};

TEST_F(SpillingStreamingFanout, Spilling) {
auto inputs = make_buffer_inputs(100, *ctx->br());
constexpr int num_out_chs = 4;
constexpr FanoutPolicy policy = FanoutPolicy::UNBOUNDED;

std::vector<std::vector<Message>> outs(num_out_chs);
{
std::vector<Node> nodes;

auto in = ctx->create_channel();
nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs)));

std::vector<std::shared_ptr<Channel>> out_chs;
for (int i = 0; i < num_out_chs; ++i) {
out_chs.emplace_back(ctx->create_channel());
}

nodes.push_back(node::fanout(ctx, in, out_chs, policy));

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

nodes.push_back(
many_input_sink(ctx, out_chs, ConsumePolicy::CHANNEL_ORDER, outs)
);

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

run_streaming_pipeline(std::move(nodes));
}

for (int c = 0; c < num_out_chs; ++c) {
SCOPED_TRACE("channel " + std::to_string(c));
// all messages should be in host memory
EXPECT_TRUE(std::ranges::all_of(outs[c], [](const Message& m) {
auto const& cd = m.content_description();
return cd.principal_memory_type() == MemoryType::HOST
&& cd.content_size() == 1024 * sizeof(int);
}));
}
}