Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
45e2d49
Rename device_uvector_policy -> device_container_policy and add non-i…
achirkin Feb 26, 2026
65d4570
Declare the new resources in raft handle
achirkin Feb 26, 2026
d86638f
Renamed managed policy
achirkin Feb 26, 2026
d6788f6
Add raft::resources for pinned and managed resources and the type-era…
achirkin Feb 26, 2026
e7bea48
Updated container policies
achirkin Feb 26, 2026
2514621
All but host memory resource are done
achirkin Feb 26, 2026
49735a5
Simplify the implementation
achirkin Feb 27, 2026
22b4048
Make the host container policy use the resource concept
achirkin Feb 27, 2026
557cc8c
Settle down with raft::mr::*et_default_host_resource()
achirkin Feb 27, 2026
cc7a4b0
Add some thread-safety
achirkin Feb 27, 2026
e77fe2a
Merge branch 'main' into fea-unify-memory-resources
achirkin Feb 27, 2026
866211e
C++17 backwards-compatibility
achirkin Feb 28, 2026
c171d84
Merge branch 'main' into fea-unify-memory-resources
achirkin Feb 28, 2026
268eb1b
newline
achirkin Feb 28, 2026
5c718d6
Add raft::mr::device_resource wrapper for cuda::mr::any_resource
achirkin Mar 1, 2026
c5ab9c4
Copy semantics and return resource refs
achirkin Mar 2, 2026
6af142e
Rework workspace resources to avoid nesting bridge layers
achirkin Mar 2, 2026
ece1990
Fix the argument order in tests
achirkin Mar 2, 2026
4dd256b
Merge branch 'main' into fea-unify-memory-resources
achirkin Mar 3, 2026
a26357d
Add explicit conversion through cuda::mr refs to rmm ref
achirkin Mar 3, 2026
2a90680
Switch from rmm host and host_device resource reference wrappers to r…
achirkin Mar 4, 2026
59c3793
Merge branch 'main' into fea-unify-memory-resources
achirkin Mar 4, 2026
3a40d22
Prefer rmm::mr::get_current_device_resource_ref() over rmm::mr::get_c…
achirkin Mar 4, 2026
cce4f45
Remove raft pinned and managed memory resources in favor of cuda::mr …
achirkin Mar 4, 2026
bc2518c
Tracking memory resources
achirkin Mar 4, 2026
ea2d7bf
Avoid direct resource -> rmm ref conversion to fix CI errors
achirkin Mar 5, 2026
b0c7b54
Merge branch 'main' into fea-tracking-memory-resources
achirkin Mar 9, 2026
550205b
Merge branch 'main' into fea-tracking-memory-resources
achirkin Mar 13, 2026
830ec3c
Merge branch 'main' into fea-tracking-memory-resources
achirkin Mar 14, 2026
e35c4bf
Update cpp/include/raft/mr/statistics_adaptor.hpp
achirkin Mar 16, 2026
1f7b67f
Make sure to record the last updates when stop() is called
achirkin Mar 16, 2026
90683f7
Improve clarity via docs and more explicit constructors
achirkin Mar 16, 2026
2abe848
Clarify the nvtx current_range definition
achirkin Mar 16, 2026
0f60d31
Enhance memory reporting: add local peak usage and total alloc/free f…
achirkin Mar 16, 2026
e10e02d
Move the thread_local range_name_stack_instance out of a getter funct…
achirkin Mar 17, 2026
1650bcf
Benchmark tracking overhead
achirkin Mar 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# =============================================================================
# cmake-format: off
# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
# cmake-format: on
# =============================================================================
Expand Down Expand Up @@ -67,7 +67,7 @@ function(ConfigureBench)
endfunction()

if(BUILD_PRIMS_BENCH)
ConfigureBench(NAME CORE_BENCH PATH core/bitset.cu core/copy.cu main.cpp)
ConfigureBench(NAME CORE_BENCH PATH core/bitset.cu core/copy.cu core/memory_tracking.cu main.cpp)

ConfigureBench(NAME UTIL_BENCH PATH util/popc.cu main.cpp)

Expand Down
129 changes: 129 additions & 0 deletions cpp/bench/prims/core/memory_tracking.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include <common/benchmark.hpp>

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/util/memory_tracking_resources.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/resource_ref.hpp>

#include <unistd.h>

#include <chrono>
#include <cstdlib>
#include <filesystem>
#include <memory>
#include <optional>
#include <string>
#include <system_error>
#include <vector>

namespace raft::bench::core {

// Parameter set for one memory-tracking overhead benchmark case.
struct tracking_inputs {
  int num_allocs;          // number of allocate/deallocate pairs per benchmark iteration
  size_t alloc_size;       // size of each allocation, in bytes
  int64_t sample_rate_us;  // tracking sample period in microseconds; negative disables tracking
  bool batch;              // true: allocate all then free all; false: alloc/free ping-pong
};

/**
 * Benchmark fixture measuring the overhead of memory tracking on workspace
 * allocations.
 *
 * When `sample_rate_us >= 0` the fixture wraps the handle's resources in
 * `raft::memory_tracking_resources`, writing samples to a temporary file that
 * is deleted on teardown; a negative rate benchmarks the untracked baseline.
 */
struct tracking_overhead : public fixture {
  tracking_overhead(const tracking_inputs& p) : fixture(true), params(p)
  {
    if (p.sample_rate_us >= 0) {
      // mkstemp replaces the XXXXXX suffix in-place and creates the file.
      std::string tpl = (std::filesystem::temp_directory_path() / "raft_bench_XXXXXX").string();
      int fd          = mkstemp(tpl.data());
      if (fd != -1) {
        close(fd);
        tmp_path_ = std::move(tpl);
        tracked_res_.emplace(handle, tmp_path_, std::chrono::microseconds{p.sample_rate_us});
      }
      // If mkstemp failed, fall back to the untracked handle instead of
      // pointing the tracker at an unexpanded template path.
    }
  }

  ~tracking_overhead()
  {
    // Destroy the tracker (stopping its sampling) before removing its file.
    tracked_res_.reset();
    if (!tmp_path_.empty()) {
      std::error_code ec;  // best-effort cleanup; ignore errors
      std::filesystem::remove(tmp_path_, ec);
    }
  }

  void run_benchmark(::benchmark::State& state) override
  {
    state.counters["alloc_size"]     = params.alloc_size;
    state.counters["sample_rate_us"] = params.sample_rate_us;
    state.counters["batch"]          = params.batch;

    // NOTE(review): assumes memory_tracking_resources is layout-compatible
    // with raft::resources; if the type derives from (or exposes) a
    // raft::resources, a static_cast or accessor would be safer — confirm
    // against its definition.
    run_allocs(state, tracked_res_ ? reinterpret_cast<raft::resources&>(*tracked_res_) : handle);

    // Each pair counts as two items: one allocation plus one deallocation.
    state.SetItemsProcessed(state.iterations() * params.num_allocs * 2);
  }

 private:
  // Time `num_allocs` allocate/deallocate pairs per iteration, either
  // interleaved (ping-pong) or allocate-all-then-free-all (batch).
  void run_allocs(::benchmark::State& state, raft::resources& res)
  {
    auto mr = raft::resource::get_workspace_resource_ref(res);
    auto sv = raft::resource::get_cuda_stream(res);

    if (params.batch) {
      std::vector<void*> ptrs(params.num_allocs);
      for (auto _ : state) {
        auto t0 = std::chrono::high_resolution_clock::now();
        for (int i = 0; i < params.num_allocs; i++) {
          ptrs[i] = mr.allocate(sv, params.alloc_size);
        }
        // Free in reverse allocation order.
        for (int i = params.num_allocs - 1; i >= 0; i--) {
          mr.deallocate(sv, ptrs[i], params.alloc_size);
        }
        state.SetIterationTime(
          std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count());
      }
    } else {
      for (auto _ : state) {
        auto t0 = std::chrono::high_resolution_clock::now();
        for (int i = 0; i < params.num_allocs; i++) {
          void* p = mr.allocate(sv, params.alloc_size);
          mr.deallocate(sv, p, params.alloc_size);
        }
        state.SetIterationTime(
          std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count());
      }
    }
  }

  tracking_inputs params;
  std::string tmp_path_;
  std::optional<raft::memory_tracking_resources> tracked_res_ = std::nullopt;
};

const std::vector<tracking_inputs> inputs{
// ping-pong (isolates per-call overhead, pool recycles same block)
{10000, 256, -1, false},
{10000, 256, 0, false},
{10000, 256, 1, false},
{10000, 256, 10, false},
{10000, 256, 100, false},
{10000, 1 << 20, -1, false},
{10000, 1 << 20, 0, false},
{10000, 1 << 20, 1, false},
{10000, 1 << 20, 10, false},
{10000, 1 << 20, 100, false},
{1000, 1 << 26, -1, false},
{1000, 1 << 26, 0, false},
{1000, 1 << 26, 1, false},
{1000, 1 << 26, 10, false},
{1000, 1 << 26, 100, false},
// batch (allocate all, then deallocate all)
{10000, 256, -1, true},
{10000, 256, 0, true},
{10000, 256, 1, true},
{10000, 256, 10, true},
{10000, 256, 100, true},
{1000, 1 << 20, -1, true},
{1000, 1 << 20, 0, true},
{1000, 1 << 20, 1, true},
{1000, 1 << 20, 10, true},
{1000, 1 << 20, 100, true},
};

RAFT_BENCH_REGISTER(tracking_overhead, "", inputs);

} // namespace raft::bench::core
9 changes: 6 additions & 3 deletions cpp/include/raft/core/detail/nvtx.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <raft/core/detail/nvtx_range_stack.hpp>

#include <rmm/cuda_stream_view.hpp>

#ifdef NVTX_ENABLED
Expand Down Expand Up @@ -146,6 +148,7 @@ inline void push_range_name(const char* name)
event_attrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
event_attrib.message.ascii = name;
nvtxDomainRangePushEx(domain_store<Domain>::value(), &event_attrib);
detail::range_name_stack_instance.push(name);
}

template <typename Domain, typename... Args>
Expand All @@ -168,12 +171,13 @@ inline void push_range(const char* format, Args... args)
template <typename Domain>
inline void pop_range()
{
  // Drop the innermost name from this thread's range-name stack, then close
  // the matching NVTX range in this Domain (counterpart of push_range_name).
  detail::range_name_stack_instance.pop();
  nvtxDomainRangePop(domain_store<Domain>::value());
}

} // namespace raft::common::nvtx::detail

#else // NVTX_ENABLED
#else // NVTX_ENABLED

namespace raft::common::nvtx::detail {

Expand All @@ -188,5 +192,4 @@ inline void pop_range()
}

} // namespace raft::common::nvtx::detail

#endif // NVTX_ENABLED
90 changes: 90 additions & 0 deletions cpp/include/raft/core/detail/nvtx_range_stack.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once

#include <cstddef>
#include <memory>
#include <mutex>
#include <stack>
#include <string>
#include <utility>

namespace raft::common::nvtx {

namespace detail {
struct nvtx_range_name_stack;
}  // namespace detail

/**
 * Shared, read-only handle to the current NVTX range name of a specific thread.
 * Written internally by exactly one thread (via its range-name stack) and
 * readable concurrently by any number of other threads.
 */
class current_range {
  friend detail::nvtx_range_name_stack;

 public:
  /** Read the current range name and stack depth (safe to call from any thread). */
  auto get() const -> std::pair<std::string, std::size_t>
  {
    std::scoped_lock guard{mutex_};
    return std::make_pair(name_, depth_);
  }

  /** Implicit conversion reading just the current range name. */
  operator std::string() const
  {
    std::scoped_lock guard{mutex_};
    return name_;
  }

 private:
  mutable std::mutex mutex_;
  std::string name_;
  std::size_t depth_{0};

  // Invoked by the owning thread's range-name stack on every push/pop;
  // a null name is published as the empty string.
  void set(const char* name, std::size_t depth)
  {
    std::scoped_lock guard{mutex_};
    if (name == nullptr) {
      name_.clear();
    } else {
      name_ = name;
    }
    depth_ = depth;
  }
};

namespace detail {

/**
 * Per-thread stack of NVTX range names, mirroring the nesting of push/pop
 * calls. Publishes the innermost name and the current depth through a shared
 * `current_range` handle that other threads may read at any time.
 */
struct nvtx_range_name_stack {
  /** Record a newly pushed range name and publish it as the current one. */
  void push(const char* name)
  {
    // Guard against a null name: std::string(nullptr) is undefined behavior,
    // while current_range::set already maps null to the empty string.
    stack_.emplace(name != nullptr ? name : "");
    current_->set(name, stack_.size());
  }

  /** Discard the innermost name and publish the one below it (if any). */
  void pop()
  {
    if (!stack_.empty()) { stack_.pop(); }
    current_->set(stack_.empty() ? nullptr : stack_.top().c_str(), stack_.size());
  }

  /** Shared read-only view of this thread's current range name. */
  auto current() const -> std::shared_ptr<const current_range> { return current_; }

 private:
  std::stack<std::string> stack_{};
  std::shared_ptr<current_range> current_{std::make_shared<current_range>()};
};

// One name stack per thread.
inline thread_local nvtx_range_name_stack range_name_stack_instance{};

}  // namespace detail

/**
* Get a read-only handle to this thread's current NVTX range name.
* Pass the returned shared_ptr to another thread to read this thread's current NVTX range name at
* any time.
*/
inline auto thread_local_current_range() -> std::shared_ptr<const current_range>
{
return detail::range_name_stack_instance.current();
}

} // namespace raft::common::nvtx
Loading
Loading