Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
45e2d49
Rename device_uvector_policy -> device_container_policy and add non-i…
achirkin Feb 26, 2026
65d4570
Declare the new resources in raft handle
achirkin Feb 26, 2026
d86638f
Renamed managed policy
achirkin Feb 26, 2026
d6788f6
Add raft::resources for pinned and managed resources and the type-era…
achirkin Feb 26, 2026
e7bea48
Updated container policies
achirkin Feb 26, 2026
2514621
All but host memory resource are done
achirkin Feb 26, 2026
49735a5
Simplify the implementation
achirkin Feb 27, 2026
22b4048
Make the host container policy use the resource concept
achirkin Feb 27, 2026
557cc8c
Settle down with raft::mr::*et_default_host_resource()
achirkin Feb 27, 2026
cc7a4b0
Add some thread-safety
achirkin Feb 27, 2026
e77fe2a
Merge branch 'main' into fea-unify-memory-resources
achirkin Feb 27, 2026
866211e
C++17 backwards-compatibility
achirkin Feb 28, 2026
c171d84
Merge branch 'main' into fea-unify-memory-resources
achirkin Feb 28, 2026
268eb1b
newline
achirkin Feb 28, 2026
5c718d6
Add raft::mr::device_resource wrapper for cuda::mr::any_resource
achirkin Mar 1, 2026
c5ab9c4
Copy semantics and return resource refs
achirkin Mar 2, 2026
6af142e
Rework workspace resources to avoid nesting bridge layers
achirkin Mar 2, 2026
ece1990
Fix the argument order in tests
achirkin Mar 2, 2026
4dd256b
Merge branch 'main' into fea-unify-memory-resources
achirkin Mar 3, 2026
a26357d
Add explicit conversion through cuda::mr refs to rmm ref
achirkin Mar 3, 2026
2a90680
Switch from rmm host and host_device resource reference wrappers to r…
achirkin Mar 4, 2026
59c3793
Merge branch 'main' into fea-unify-memory-resources
achirkin Mar 4, 2026
3a40d22
Prefer rmm::mr::get_current_device_resource_ref() over rmm::mr::get_c…
achirkin Mar 4, 2026
cce4f45
Remove raft pinned and managed memory resources in favor of cuda::mr …
achirkin Mar 4, 2026
a3fe671
Merge branch 'main' into fea-unify-memory-resources
achirkin Mar 9, 2026
d01efbf
Merge branch 'main' into fea-unify-memory-resources
achirkin Mar 10, 2026
f6fc1f1
Update cpp/include/raft/mr/mmap_memory_resource.hpp
achirkin Mar 12, 2026
e95c67c
Update cpp/include/raft/mr/mmap_memory_resource.hpp
achirkin Mar 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions cpp/include/raft/core/device_container_policy.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
* SPDX-FileCopyrightText: Copyright (2019) Sandia Corporation
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
*/
/*
Expand Down Expand Up @@ -134,7 +134,7 @@ class device_uvector {
* @brief A container policy for device mdarray.
*/
template <typename ElementType>
class device_uvector_policy {
class device_container_policy {
public:
using element_type = ElementType;
using container_type = device_uvector<element_type>;
Expand All @@ -153,8 +153,8 @@ class device_uvector_policy {
return container_type(n, resource::get_cuda_stream(res), mr_);
}

constexpr device_uvector_policy() = default;
explicit device_uvector_policy(rmm::device_async_resource_ref mr) noexcept : mr_(mr) {}
constexpr device_container_policy() = default;
explicit device_container_policy(rmm::device_async_resource_ref mr) noexcept : mr_(mr) {}

[[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference
{
Expand All @@ -170,7 +170,7 @@ class device_uvector_policy {
[[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; }

private:
rmm::device_async_resource_ref mr_{rmm::mr::get_current_device_resource()};
rmm::device_async_resource_ref mr_{rmm::mr::get_current_device_resource_ref()};
};

} // namespace raft
Expand All @@ -189,7 +189,7 @@ template <typename T>
using device_uvector = detail::fail_container<T>;

template <typename ElementType>
using device_uvector_policy = detail::fail_container_policy<ElementType>;
using device_container_policy = detail::fail_container_policy<ElementType>;

} // namespace raft
#endif
10 changes: 5 additions & 5 deletions cpp/include/raft/core/device_coo_matrix.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
Expand Down Expand Up @@ -29,7 +29,7 @@ using device_coordinate_structure_view = coordinate_structure_view<RowType, ColT
template <typename RowType,
typename ColType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy>
template <typename T> typename ContainerPolicy = device_container_policy>
using device_coordinate_structure =
coordinate_structure<RowType, ColType, NZType, true, ContainerPolicy>;

Expand All @@ -43,7 +43,7 @@ template <typename ElementType,
typename RowType,
typename ColType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy,
template <typename T> typename ContainerPolicy = device_container_policy,
SparsityType sparsity_type = SparsityType::OWNING>
using device_coo_matrix =
coo_matrix<ElementType, RowType, ColType, NZType, true, ContainerPolicy, sparsity_type>;
Expand All @@ -55,15 +55,15 @@ template <typename ElementType,
typename RowType,
typename ColType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy>
template <typename T> typename ContainerPolicy = device_container_policy>
using device_sparsity_owning_coo_matrix =
coo_matrix<ElementType, RowType, ColType, NZType, true, ContainerPolicy>;

template <typename ElementType,
typename RowType,
typename ColType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy>
template <typename T> typename ContainerPolicy = device_container_policy>
using device_sparsity_preserving_coo_matrix = coo_matrix<ElementType,
RowType,
ColType,
Expand Down
10 changes: 5 additions & 5 deletions cpp/include/raft/core/device_csr_matrix.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
Expand Down Expand Up @@ -33,7 +33,7 @@ using device_compressed_structure_view =
template <typename IndptrType,
typename IndicesType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy>
template <typename T> typename ContainerPolicy = device_container_policy>
using device_compressed_structure =
compressed_structure<IndptrType, IndicesType, NZType, true, ContainerPolicy>;

Expand All @@ -47,7 +47,7 @@ template <typename ElementType,
typename IndptrType,
typename IndicesType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy,
template <typename T> typename ContainerPolicy = device_container_policy,
SparsityType sparsity_type = SparsityType::OWNING>
using device_csr_matrix =
csr_matrix<ElementType, IndptrType, IndicesType, NZType, true, ContainerPolicy, sparsity_type>;
Expand All @@ -59,7 +59,7 @@ template <typename ElementType,
typename IndptrType,
typename IndicesType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy>
template <typename T> typename ContainerPolicy = device_container_policy>
using device_sparsity_owning_csr_matrix =
csr_matrix<ElementType, IndptrType, IndicesType, NZType, true, ContainerPolicy>;

Expand All @@ -70,7 +70,7 @@ template <typename ElementType,
typename IndptrType,
typename IndicesType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy>
template <typename T> typename ContainerPolicy = device_container_policy>
using device_sparsity_preserving_csr_matrix = csr_matrix<ElementType,
IndptrType,
IndicesType,
Expand Down
26 changes: 22 additions & 4 deletions cpp/include/raft/core/device_mdarray.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -26,7 +26,7 @@ namespace raft {
template <typename ElementType,
typename Extents,
typename LayoutPolicy = layout_c_contiguous,
typename ContainerPolicy = device_uvector_policy<ElementType>>
typename ContainerPolicy = device_container_policy<ElementType>>
using device_mdarray =
mdarray<ElementType, Extents, LayoutPolicy, device_accessor<ContainerPolicy>>;

Expand Down Expand Up @@ -130,12 +130,30 @@ auto make_device_matrix(raft::resources const& handle, IndexType n_rows, IndexTy
}

/**
* @brief Create a device scalar from v.
* @brief Create an uninitialized device scalar.
*
* @tparam ElementType the data type of the scalar element
* @tparam IndexType the index type of the extents
* @param[in] handle raft handle for managing expensive cuda resources
* @param[in] v scalar to wrap on device
* @return raft::device_scalar
*/
template <typename ElementType, typename IndexType = std::uint32_t>
auto make_device_scalar(raft::resources const& handle)
{
  // Build the scalar through its own container-policy type so the default
  // device memory resource configured for this handle is used.
  using scalar_t = device_scalar<ElementType, IndexType>;
  using policy_t = typename scalar_t::container_policy_type;
  scalar_extent<IndexType> exts{};
  return scalar_t{handle, exts, policy_t{}};
}

/**
* @brief Create a device scalar from v
* (async copy in the resource-provided stream).
*
* @tparam ElementType the data type of the scalar element
* @tparam IndexType the index type of the extents
* @param[in] handle raft handle for managing expensive cuda resources
* @param[in] v scalar to copy to device
* @return raft::device_scalar
*/
template <typename ElementType, typename IndexType = std::uint32_t>
Expand Down
52 changes: 31 additions & 21 deletions cpp/include/raft/core/host_container_policy.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
* SPDX-FileCopyrightText: Copyright (2019) Sandia Corporation
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2019 Sandia Corporation
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
*/
/*
Expand All @@ -14,18 +14,23 @@

#include <raft/core/mdspan_types.hpp>
#include <raft/core/resources.hpp>
#include <raft/util/integer_utils.hpp>

#include <memory_resource>
#include <raft/mr/host_memory_resource.hpp>

namespace raft {

/**
* @brief A container using the std::pmr::memory_resource for allocations.
* @brief A container backed by a host-accessible cuda::mr::synchronous_resource.
*
* @tparam T element type
* @tparam MR a type satisfying cuda::mr::synchronous_resource_with<cuda::mr::host_accessible>
*/
template <typename T>
struct host_container {
template <typename T, typename MR>
#ifdef __cpp_concepts
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think RAFT is using C++20 now so it should be safe to use requires without the #ifdef guard?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately some components of cuvs still use C++17 and it breaks if I remove the #ifdef in this header. I figured, I'd keep it here to keep cuvs passing CI without changes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should get cuVS updated to C++20, RMM will be requiring C++20 soon.

requires cuda::mr::synchronous_resource_with<MR, cuda::mr::host_accessible>
#endif
struct host_container {
static_assert(cuda::mr::synchronous_resource_with<MR, cuda::mr::host_accessible>,
"MR must be a host-accessible synchronous resource");
using value_type = std::remove_cv_t<T>;
using size_type = std::size_t;

Expand All @@ -39,25 +44,26 @@ struct host_container {
using const_iterator = const_pointer;

private:
std::pmr::memory_resource* mr_;
MR mr_;
size_type bytesize_ = 0;
value_type* data_ = nullptr;

public:
host_container(size_type count, std::pmr::memory_resource* mr = nullptr)
: mr_(mr == nullptr ? std::pmr::get_default_resource() : mr),
host_container(size_type count,
MR mr) // NB: pass by value, as we expect a resource_ref anyway
: mr_(std::move(mr)),
bytesize_(sizeof(value_type) * count),
data_(bytesize_ > 0 ? static_cast<pointer>(mr_->allocate(bytesize_)) : nullptr)
data_(bytesize_ > 0 ? static_cast<pointer>(mr_.allocate_sync(bytesize_)) : nullptr)
{
}

~host_container() noexcept
{
if (bytesize_ > 0 && data_ != nullptr) { mr_->deallocate(data_, bytesize_); }
if (bytesize_ > 0 && data_ != nullptr) { mr_.deallocate_sync(data_, bytesize_); }
}

host_container(host_container&& other) noexcept
: mr_{std::exchange(other.mr_, nullptr)},
: mr_{std::move(other.mr_)},
bytesize_{std::exchange(other.bytesize_, 0)},
data_{std::exchange(other.data_, nullptr)}
{
Expand Down Expand Up @@ -103,25 +109,29 @@ struct host_container {
};

/**
* @brief A container policy for host mdarray.
* @brief Container policy for host mdarray.
*
* Defaults to raft::mr::get_default_host_resource().
*/
template <typename ElementType>
class host_container_policy {
public:
using element_type = ElementType;
using container_type = host_container<element_type>;
using container_type = host_container<element_type, raft::mr::host_resource_ref>;
using pointer = typename container_type::pointer;
using const_pointer = typename container_type::const_pointer;
using reference = typename container_type::reference;
using const_reference = typename container_type::const_reference;
using accessor_policy = cuda::std::default_accessor<element_type>;
using const_accessor_policy = cuda::std::default_accessor<element_type const>;

public:
auto create(raft::resources const&, size_t n) -> container_type { return container_type(n, mr_); }
host_container_policy() = default;
explicit host_container_policy(raft::mr::host_resource_ref ref) noexcept : ref_(ref) {}

constexpr host_container_policy() noexcept = default;
explicit host_container_policy(std::pmr::memory_resource* mr) noexcept : mr_(mr) {}
auto create(raft::resources const&, size_t n) -> container_type
{
return container_type(n, ref_);
}

[[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference
{
Expand All @@ -137,7 +147,7 @@ class host_container_policy {
[[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; }

private:
std::pmr::memory_resource* mr_{std::pmr::get_default_resource()};
raft::mr::host_resource_ref ref_ = raft::mr::get_default_host_resource();
};

} // namespace raft
40 changes: 31 additions & 9 deletions cpp/include/raft/core/host_mdarray.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -74,7 +74,7 @@ template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
size_t... Extents>
auto make_host_mdarray(raft::resources& res, extents<IndexType, Extents...> exts)
auto make_host_mdarray(raft::resources const& res, extents<IndexType, Extents...> exts)
{
using mdarray_t = host_mdarray<ElementType, decltype(exts), LayoutPolicy>;

Expand All @@ -85,21 +85,21 @@ auto make_host_mdarray(raft::resources& res, extents<IndexType, Extents...> exts
}

/**
* @brief Create a host mdarray.
* @brief Create a host mdarray with a custom memory resource.
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param res raft::resources
* @param mr std::pmr::memory_resource used for allocating the memory for the array
* @param mr host memory resource reference used for allocating the memory for the array
* @param exts dimensionality of the array (series of integers)
* @return raft::device_mdarray
* @return raft::host_mdarray
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
size_t... Extents>
auto make_host_mdarray(raft::resources const& res,
std::pmr::memory_resource* mr,
raft::mr::host_resource_ref mr,
extents<IndexType, Extents...> exts)
{
using mdarray_t = host_mdarray<ElementType, decltype(exts), LayoutPolicy>;
Expand Down Expand Up @@ -154,7 +154,7 @@ auto make_host_mdarray(extents<IndexType, Extents...> exts)
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous>
auto make_host_matrix(raft::resources& res, IndexType n_rows, IndexType n_cols)
auto make_host_matrix(raft::resources const& res, IndexType n_rows, IndexType n_cols)
{
return make_host_mdarray<ElementType, IndexType, LayoutPolicy>(
res, make_extents<IndexType>(n_rows, n_cols));
Expand All @@ -181,6 +181,28 @@ auto make_host_matrix(IndexType n_rows, IndexType n_cols)
make_extents<IndexType>(n_rows, n_cols));
}

/**
* @ingroup host_mdarray_factories
* @brief Create an uninitialized host scalar.
*
* @tparam ElementType the data type of the scalar element
* @tparam IndexType the index type of the extents
* @param[in] res raft handle for managing expensive resources
* @return raft::host_scalar
*/
template <typename ElementType, typename IndexType = std::uint32_t>
auto make_host_scalar(raft::resources const& res)
{
  // FIXME(jiamingy): We can optimize this by using std::array as container policy, which
  // requires some more compile time dispatching. This is enabled in the ref impl but
  // hasn't been ported here yet.
  using scalar_t = host_scalar<ElementType, IndexType>;
  using policy_t = typename scalar_t::container_policy_type;
  scalar_extent<IndexType> exts{};
  return scalar_t{res, exts, policy_t{}};
}

/**
* @ingroup host_mdarray_factories
* @brief Create a host scalar from v.
Expand All @@ -192,7 +214,7 @@ auto make_host_matrix(IndexType n_rows, IndexType n_cols)
* @return raft::host_scalar
*/
template <typename ElementType, typename IndexType = std::uint32_t>
auto make_host_scalar(raft::resources& res, ElementType const& v)
auto make_host_scalar(raft::resources const& res, ElementType const& v)
{
// FIXME(jiamingy): We can optimize this by using std::array as container policy, which
// requires some more compile time dispatching. This is enabled in the ref impl but
Expand Down Expand Up @@ -244,7 +266,7 @@ auto make_host_scalar(ElementType const& v)
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous>
auto make_host_vector(raft::resources& res, IndexType n)
auto make_host_vector(raft::resources const& res, IndexType n)
{
return make_host_mdarray<ElementType, IndexType, LayoutPolicy>(res, make_extents<IndexType>(n));
}
Expand Down
Loading
Loading