-
Notifications
You must be signed in to change notification settings - Fork 28
Make unbounded fanout messages spillable #711
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 18 commits
3876116
dc13cdf
d4e9a1b
7163773
dd0d345
ca10c57
363b61e
feb04a8
72e0be3
5606f6a
ea830ac
f9947c1
0d1cc9a
d5db515
61ba76e
090dfde
4488ded
a2ad554
e3958a2
5dc7708
b4c6cac
55c8114
d0e6a50
6298b2e
a8c232e
c4d4a39
1373521
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,8 @@ | |
|
|
||
| #include <array> | ||
| #include <ostream> | ||
| #include <ranges> | ||
| #include <span> | ||
|
|
||
| namespace rapidsmpf { | ||
|
|
||
|
|
@@ -33,6 +35,31 @@ constexpr std::array<char const*, MEMORY_TYPES.size()> MEMORY_TYPE_NAMES{ | |
| */ | ||
| constexpr std::array<MemoryType, 1> SPILL_TARGET_MEMORY_TYPES{{MemoryType::HOST}}; | ||
|
|
||
| /** | ||
| * @brief Get the memory types with preference lower than or equal to @p mem_type. | ||
| * | ||
| * The returned span reflects the predefined ordering used in \c MEMORY_TYPES, | ||
| * which lists memory types in decreasing order of preference. | ||
| * | ||
| * @param mem_type The memory type used as the starting point. | ||
| * @return A span of memory types whose preference is lower than or equal to | ||
| * the given type. | ||
| */ | ||
| constexpr std::span<MemoryType const> leq_memory_types(MemoryType mem_type) noexcept { | ||
| return std::views::drop_while(MEMORY_TYPES, [&](MemoryType const& mt) { | ||
| return mt != mem_type; | ||
| }); | ||
| } | ||
|
|
||
| static_assert(std::ranges::equal(leq_memory_types(MemoryType::DEVICE), MEMORY_TYPES)); | ||
| static_assert(std::ranges::equal( | ||
| leq_memory_types(MemoryType::HOST), std::ranges::single_view{MemoryType::HOST} | ||
| )); | ||
| // unknown memory type should return an empty view | ||
| static_assert(std::ranges::equal( | ||
| leq_memory_types(static_cast<MemoryType>(-1)), std::ranges::empty_view<MemoryType>{} | ||
| )); | ||
|
Comment on lines
+53
to
+66
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why not immediately just have: ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This would return a |
||
|
|
||
| /** | ||
| * @brief Get the name of a MemoryType. | ||
| * | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -47,6 +47,47 @@ std::vector<Message> make_int_inputs(int n) { | |||
| return inputs; | ||||
| } | ||||
|
|
||||
| /** | ||||
| * @brief Helper to make a sequence of Message<Buffer>s where each buffer contains 1024 | ||||
| * values of int [i, i + 1024). | ||||
| */ | ||||
| std::vector<Message> make_buffer_inputs(int n, rapidsmpf::BufferResource& br) { | ||||
| std::vector<Message> inputs; | ||||
| inputs.reserve(n); | ||||
|
|
||||
| Message::CopyCallback copy_cb = [&](Message const& msg, MemoryReservation& res) { | ||||
| rmm::cuda_stream_view stream = br.stream_pool().get_stream(); | ||||
| auto const& cd = msg.content_description(); | ||||
nirandaperera marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
| auto buf_cpy = br.allocate(cd.content_size(), stream, res); | ||||
| // cd needs to be updated to reflect the new buffer | ||||
| ContentDescription new_cd{ | ||||
| {{buf_cpy->mem_type(), buf_cpy->size}}, ContentDescription::Spillable::YES | ||||
| }; | ||||
| rapidsmpf::buffer_copy(*buf_cpy, msg.get<Buffer>(), cd.content_size()); | ||||
| return Message{ | ||||
| msg.sequence_number(), std::move(buf_cpy), std::move(new_cd), msg.copy_cb() | ||||
| }; | ||||
| }; | ||||
| for (int i = 0; i < n; ++i) { | ||||
| std::vector<int> values(1024, 0); | ||||
| std::iota(values.begin(), values.end(), i); | ||||
| rmm::cuda_stream_view stream = br.stream_pool().get_stream(); | ||||
| // allocate outside of buffer resource | ||||
| auto buffer = br.move( | ||||
| std::make_unique<rmm::device_buffer>( | ||||
| values.data(), values.size() * sizeof(int), stream | ||||
| ), | ||||
| stream | ||||
| ); | ||||
| ContentDescription cd{ | ||||
| std::ranges::single_view{std::pair{MemoryType::DEVICE, 1024 * sizeof(int)}}, | ||||
| ContentDescription::Spillable::YES | ||||
| }; | ||||
| inputs.emplace_back(i, std::move(buffer), cd, copy_cb); | ||||
| } | ||||
| return inputs; | ||||
| } | ||||
|
|
||||
| std::string policy_to_string(FanoutPolicy policy) { | ||||
| switch (policy) { | ||||
| case FanoutPolicy::BOUNDED: | ||||
|
|
@@ -139,7 +180,53 @@ TEST_P(StreamingFanout, SinkPerChannel) { | |||
| // object | ||||
| for (int i = 0; i < num_msgs; ++i) { | ||||
| SCOPED_TRACE("channel " + std::to_string(c) + " idx " + std::to_string(i)); | ||||
| EXPECT_EQ(outs[c][i].get<int>(), i); | ||||
| EXPECT_EQ(i, outs[c][i].get<int>()); | ||||
| } | ||||
| } | ||||
| } | ||||
|
|
||||
| TEST_P(StreamingFanout, SinkPerChannel_Buffer) { | ||||
| auto inputs = make_buffer_inputs(num_msgs, *ctx->br()); | ||||
|
|
||||
| std::vector<std::vector<Message>> outs(num_out_chs); | ||||
| { | ||||
| std::vector<Node> nodes; | ||||
|
|
||||
| auto in = ctx->create_channel(); | ||||
| nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); | ||||
|
|
||||
| std::vector<std::shared_ptr<Channel>> out_chs; | ||||
| for (int i = 0; i < num_out_chs; ++i) { | ||||
| out_chs.emplace_back(ctx->create_channel()); | ||||
| } | ||||
|
|
||||
| nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); | ||||
|
|
||||
| for (int i = 0; i < num_out_chs; ++i) { | ||||
| nodes.emplace_back(node::pull_from_channel(ctx, out_chs[i], outs[i])); | ||||
| } | ||||
|
|
||||
| run_streaming_pipeline(std::move(nodes)); | ||||
| } | ||||
|
|
||||
| for (int c = 0; c < num_out_chs; ++c) { | ||||
| // Validate sizes | ||||
| EXPECT_EQ(outs[c].size(), static_cast<size_t>(num_msgs)); | ||||
|
|
||||
| // Validate ordering/content and that shallow copies share the same underlying | ||||
| // object | ||||
| for (int i = 0; i < num_msgs; ++i) { | ||||
| SCOPED_TRACE("channel " + std::to_string(c) + " idx " + std::to_string(i)); | ||||
| auto const& buf = outs[c][i].get<Buffer>(); | ||||
| EXPECT_EQ(1024 * sizeof(int), buf.size); | ||||
|
|
||||
| std::vector<int> recv(1024); | ||||
| buf.stream().synchronize(); | ||||
| RAPIDSMPF_CUDA_TRY( | ||||
| cudaMemcpy(recv.data(), buf.data(), 1024 * sizeof(int), cudaMemcpyDefault) | ||||
| ); | ||||
|
|
||||
| EXPECT_TRUE(std::ranges::equal(std::ranges::views::iota(i, i + 1024), recv)); | ||||
| } | ||||
| } | ||||
| } | ||||
|
|
@@ -447,3 +534,57 @@ TEST_P(ManyInputSinkStreamingFanout, ChannelOrder) { | |||
| TEST_P(ManyInputSinkStreamingFanout, MessageOrder) { | ||||
| EXPECT_NO_FATAL_FAILURE(run(ConsumePolicy::MESSAGE_ORDER)); | ||||
| } | ||||
|
|
||||
| class SpillingStreamingFanout : public BaseStreamingFixture { | ||||
| void SetUp() override { | ||||
| SetUpWithThreads(4); | ||||
|
|
||||
| // override br and context with no device memory | ||||
| std::unordered_map<MemoryType, BufferResource::MemoryAvailable> memory_available = | ||||
| { | ||||
| {MemoryType::DEVICE, []() -> std::int64_t { return 0; }}, | ||||
| }; | ||||
| br = std::make_shared<rapidsmpf::BufferResource>(mr_cuda, memory_available); | ||||
| auto options = ctx->options(); | ||||
| ctx = std::make_shared<rapidsmpf::streaming::Context>( | ||||
| options, GlobalEnvironment->comm_, br | ||||
| ); | ||||
| } | ||||
| }; | ||||
|
|
||||
| TEST_F(SpillingStreamingFanout, Spilling) { | ||||
| auto inputs = make_buffer_inputs(100, *ctx->br()); | ||||
| constexpr int num_out_chs = 4; | ||||
| constexpr FanoutPolicy policy = FanoutPolicy::UNBOUNDED; | ||||
|
|
||||
| std::vector<std::vector<Message>> outs(num_out_chs); | ||||
| { | ||||
| std::vector<Node> nodes; | ||||
|
|
||||
| auto in = ctx->create_channel(); | ||||
| nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs))); | ||||
|
|
||||
| std::vector<std::shared_ptr<Channel>> out_chs; | ||||
| for (int i = 0; i < num_out_chs; ++i) { | ||||
| out_chs.emplace_back(ctx->create_channel()); | ||||
| } | ||||
|
|
||||
| nodes.push_back(node::fanout(ctx, in, out_chs, policy)); | ||||
|
|
||||
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Uh oh!
There was an error while loading. Please reload this page.