Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions c/include/cuvs/neighbors/ivf_pq.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,19 @@ cuvsError_t cuvsIvfPqIndexGetPqLen(cuvsIvfPqIndex_t index, int64_t* pq_len);
*/
cuvsError_t cuvsIvfPqIndexGetCenters(cuvsIvfPqIndex_t index, DLManagedTensor* centers);

/**
* @brief Get the padded cluster centers [n_lists, dim_ext]
* where dim_ext = round_up(dim + 1, 8)
*
* This returns the full padded centers as a contiguous array (padding columns included).
* The shape matches what cuvsIvfPqBuildPrecomputed expects for its `centers` argument, so the
* output is suitable for use with cuvsIvfPqBuildPrecomputed.
*
* @param[in] index cuvsIvfPqIndex_t Built Ivf-Pq index
* @param[out] centers Output tensor that will be populated with a non-owning view of the data
* @return cuvsError_t
*/
cuvsError_t cuvsIvfPqIndexGetCentersPadded(cuvsIvfPqIndex_t index, DLManagedTensor* centers);

/**
* @brief Get the PQ cluster centers
*
Expand All @@ -290,6 +303,28 @@ cuvsError_t cuvsIvfPqIndexGetCenters(cuvsIvfPqIndex_t index, DLManagedTensor* ce
*/
cuvsError_t cuvsIvfPqIndexGetPqCenters(cuvsIvfPqIndex_t index, DLManagedTensor* pq_centers);

/**
* @brief Get the rotated cluster centers [n_lists, rot_dim]
* where rot_dim = pq_len * pq_dim
*
* The shape matches what cuvsIvfPqBuildPrecomputed expects for its `centers_rot` argument.
*
* @param[in] index cuvsIvfPqIndex_t Built Ivf-Pq index
* @param[out] centers_rot Output tensor that will be populated with a non-owning view of the data
* @return cuvsError_t
*/
cuvsError_t cuvsIvfPqIndexGetCentersRot(cuvsIvfPqIndex_t index, DLManagedTensor* centers_rot);

/**
* @brief Get the rotation matrix [rot_dim, dim], i.e. the transform matrix that maps the
* original space to the rotated padded space
*
* The shape matches what cuvsIvfPqBuildPrecomputed expects for its `rotation_matrix` argument.
*
* @param[in] index cuvsIvfPqIndex_t Built Ivf-Pq index
* @param[out] rotation_matrix Output tensor that will be populated with a non-owning view of the
* data
* @return cuvsError_t
*/
cuvsError_t cuvsIvfPqIndexGetRotationMatrix(cuvsIvfPqIndex_t index,
DLManagedTensor* rotation_matrix);

/**
* @brief Get the sizes of each list
*
Expand Down Expand Up @@ -389,6 +424,44 @@ cuvsError_t cuvsIvfPqBuild(cuvsResources_t res,
cuvsIvfPqIndexParams_t params,
DLManagedTensor* dataset,
cuvsIvfPqIndex_t index);

/**
* @brief Build a view-type IVF-PQ index from device memory precomputed centroids and codebook.
*
* This function creates a non-owning index that stores a reference to the provided device data.
* All parameters must be provided with correct extents. The caller is responsible for ensuring
* the lifetime of the input data exceeds the lifetime of the returned index.
*
* The index_params must be consistent with the provided matrices. Specifically:
* - index_params.codebook_kind determines the expected shape of pq_centers
* - index_params.metric will be stored in the index
* - index_params.conservative_memory_allocation will be stored in the index
* The function will verify consistency between index_params, dim, and the matrix extents.
*
* All four input tensors (pq_centers, centers, centers_rot, rotation_matrix) must hold float32
* data in device-compatible memory; the implementation checks both conditions before building.
*
* The data type of the returned index is left unset by this function. It is recorded the first
* time vectors are added with cuvsIvfPqExtend, from the data type of those vectors.
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] params cuvsIvfPqIndexParams_t used to configure the index (must be consistent with
* matrices)
* @param[in] dim dimensionality of the input data
* @param[in] pq_centers PQ codebook on device memory with required shape:
* - codebook_kind PER_SUBSPACE: [pq_dim, pq_len, pq_book_size]
* - codebook_kind PER_CLUSTER: [n_lists, pq_len, pq_book_size]
* @param[in] centers Cluster centers in the original space [n_lists, dim_ext]
* where dim_ext = round_up(dim + 1, 8)
* @param[in] centers_rot Rotated cluster centers [n_lists, rot_dim]
* where rot_dim = pq_len * pq_dim
* @param[in] rotation_matrix Transform matrix (original space -> rotated padded space) [rot_dim,
* dim]
* @param[out] index cuvsIvfPqIndex_t Newly built view-type IVF-PQ index
* @return cuvsError_t
*/
cuvsError_t cuvsIvfPqBuildPrecomputed(cuvsResources_t res,
cuvsIvfPqIndexParams_t params,
uint32_t dim,
DLManagedTensor* pq_centers,
DLManagedTensor* centers,
DLManagedTensor* centers_rot,
DLManagedTensor* rotation_matrix,
cuvsIvfPqIndex_t index);
/**
* @}
*/
Expand Down
108 changes: 108 additions & 0 deletions c/src/neighbors/ivf_pq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,35 @@ void _get_centers(cuvsIvfPqIndex index, DLManagedTensor* output)
cuvs::core::to_dlpack(strided_centers, output);
}

/**
 * Export the index's padded cluster centers into `output` as a DLPack tensor.
 *
 * @tparam IdxT index type parameter of the underlying ivf_pq::index
 * @param index  C index handle whose `addr` field holds the ivf_pq::index<IdxT> pointer
 * @param output DLPack tensor to populate
 */
template <typename IdxT>
void _get_centers_padded(cuvsIvfPqIndex index, DLManagedTensor* output)
{
  using typed_index_t = cuvs::neighbors::ivf_pq::index<IdxT>;

  auto* const typed_index = reinterpret_cast<typed_index_t*>(index.addr);

  // No slicing is applied: the complete padded [n_lists, dim_ext] centers are exported as a
  // contiguous array.
  cuvs::core::to_dlpack(typed_index->centers(), output);
}

/**
 * Export the index's PQ centers into `centers` as a DLPack tensor.
 *
 * @tparam IdxT index type parameter of the underlying ivf_pq::index
 * @param index   C index handle whose `addr` field holds the ivf_pq::index<IdxT> pointer
 * @param centers DLPack tensor to populate
 */
template <typename IdxT>
void _get_pq_centers(cuvsIvfPqIndex index, DLManagedTensor* centers)
{
  using typed_index_t = cuvs::neighbors::ivf_pq::index<IdxT>;

  auto* const typed_index = reinterpret_cast<typed_index_t*>(index.addr);
  cuvs::core::to_dlpack(typed_index->pq_centers(), centers);
}

/**
 * Export the index's rotated cluster centers into `centers_rot` as a DLPack tensor.
 *
 * @tparam IdxT index type parameter of the underlying ivf_pq::index
 * @param index       C index handle whose `addr` field holds the ivf_pq::index<IdxT> pointer
 * @param centers_rot DLPack tensor to populate
 */
template <typename IdxT>
void _get_centers_rot(cuvsIvfPqIndex index, DLManagedTensor* centers_rot)
{
  using typed_index_t = cuvs::neighbors::ivf_pq::index<IdxT>;

  auto* const typed_index = reinterpret_cast<typed_index_t*>(index.addr);
  cuvs::core::to_dlpack(typed_index->centers_rot(), centers_rot);
}

/**
 * Export the index's rotation matrix into `rotation_matrix` as a DLPack tensor.
 *
 * @tparam IdxT index type parameter of the underlying ivf_pq::index
 * @param index           C index handle whose `addr` field holds the ivf_pq::index<IdxT> pointer
 * @param rotation_matrix DLPack tensor to populate
 */
template <typename IdxT>
void _get_rotation_matrix(cuvsIvfPqIndex index, DLManagedTensor* rotation_matrix)
{
  using typed_index_t = cuvs::neighbors::ivf_pq::index<IdxT>;

  auto* const typed_index = reinterpret_cast<typed_index_t*>(index.addr);
  cuvs::core::to_dlpack(typed_index->rotation_matrix(), rotation_matrix);
}

template <typename IdxT>
void _get_list_sizes(cuvsIvfPqIndex index, DLManagedTensor* list_sizes)
{
Expand Down Expand Up @@ -355,6 +377,12 @@ extern "C" cuvsError_t cuvsIvfPqExtend(cuvsResources_t res,
return cuvs::core::translate_exceptions([=] {
auto vectors = new_vectors->dl_tensor;

// Set the index dtype if not already set (e.g., for view-type indices built from precomputed data)
if (index->dtype.code == 0 && index->dtype.bits == 0) {
index->dtype.code = vectors.dtype.code;
index->dtype.bits = vectors.dtype.bits;
}

if (vectors.dtype.code == kDLFloat && vectors.dtype.bits == 32) {
_extend<float, int64_t>(res, new_vectors, new_indices, *index);
} else if (vectors.dtype.code == kDLFloat && vectors.dtype.bits == 16) {
Expand Down Expand Up @@ -422,12 +450,92 @@ extern "C" cuvsError_t cuvsIvfPqIndexGetCenters(cuvsIvfPqIndex_t index, DLManage
return cuvs::core::translate_exceptions([=] { _get_centers<int64_t>(*index, centers); });
}

// C entry point: exports the padded cluster centers of `index` into `centers`, routing the call
// through translate_exceptions to obtain the cuvsError_t result.
extern "C" cuvsError_t cuvsIvfPqIndexGetCentersPadded(cuvsIvfPqIndex_t index,
                                                      DLManagedTensor* centers)
{
  auto export_padded_centers = [index, centers] {
    _get_centers_padded<int64_t>(*index, centers);
  };
  return cuvs::core::translate_exceptions(export_padded_centers);
}

// C entry point: exports the PQ centers of `index` into `pq_centers`, routing the call through
// translate_exceptions to obtain the cuvsError_t result.
extern "C" cuvsError_t cuvsIvfPqIndexGetPqCenters(cuvsIvfPqIndex_t index,
                                                  DLManagedTensor* pq_centers)
{
  auto export_pq_centers = [index, pq_centers] { _get_pq_centers<int64_t>(*index, pq_centers); };
  return cuvs::core::translate_exceptions(export_pq_centers);
}

// C entry point: exports the rotated cluster centers of `index` into `centers_rot`, routing the
// call through translate_exceptions to obtain the cuvsError_t result.
extern "C" cuvsError_t cuvsIvfPqIndexGetCentersRot(cuvsIvfPqIndex_t index,
                                                   DLManagedTensor* centers_rot)
{
  auto export_rotated_centers = [index, centers_rot] {
    _get_centers_rot<int64_t>(*index, centers_rot);
  };
  return cuvs::core::translate_exceptions(export_rotated_centers);
}

// C entry point: exports the rotation matrix of `index` into `rotation_matrix`, routing the call
// through translate_exceptions to obtain the cuvsError_t result.
extern "C" cuvsError_t cuvsIvfPqIndexGetRotationMatrix(cuvsIvfPqIndex_t index,
                                                       DLManagedTensor* rotation_matrix)
{
  auto export_rotation_matrix = [index, rotation_matrix] {
    _get_rotation_matrix<int64_t>(*index, rotation_matrix);
  };
  return cuvs::core::translate_exceptions(export_rotation_matrix);
}

/**
 * C entry point that builds an IVF-PQ index from precomputed, caller-provided matrices.
 *
 * The four DLPack tensors are validated (non-null, device-compatible memory, float32), converted
 * to read-only mdspan views and handed to cuvs::neighbors::ivf_pq::build. The resulting C++ index
 * is heap-allocated and its address is stored in `index->addr`.
 *
 * `index->dtype` is deliberately left as {code = 0, bits = 0}: no dataset is seen here, so the
 * data type is recorded later, when cuvsIvfPqExtend is called with actual vectors.
 *
 * @param res                    cuvsResources_t handle, reinterpreted as raft::resources*
 * @param params                 index parameters, converted with convert_c_index_params
 * @param dim                    dimensionality forwarded to ivf_pq::build
 * @param pq_centers_tensor      PQ codebook (rank-3 float32 device tensor)
 * @param centers_tensor         cluster centers (float32 device matrix)
 * @param centers_rot_tensor     rotated cluster centers (float32 device matrix)
 * @param rotation_matrix_tensor rotation matrix (float32 device matrix)
 * @param index                  output index handle; `addr` and `dtype` are overwritten
 * @return cuvsError_t produced by cuvs::core::translate_exceptions
 */
extern "C" cuvsError_t cuvsIvfPqBuildPrecomputed(cuvsResources_t res,
                                                 cuvsIvfPqIndexParams_t params,
                                                 uint32_t dim,
                                                 DLManagedTensor* pq_centers_tensor,
                                                 DLManagedTensor* centers_tensor,
                                                 DLManagedTensor* centers_rot_tensor,
                                                 DLManagedTensor* rotation_matrix_tensor,
                                                 cuvsIvfPqIndex_t index)
{
  return cuvs::core::translate_exceptions([=] {
    // Every pointer below is dereferenced further down. Reject nulls here so that a bad argument
    // from a C caller is reported as a failed expectation instead of crashing the process.
    RAFT_EXPECTS(params != nullptr, "params must not be null");
    RAFT_EXPECTS(index != nullptr, "index must not be null");
    RAFT_EXPECTS(pq_centers_tensor != nullptr, "pq_centers must not be null");
    RAFT_EXPECTS(centers_tensor != nullptr, "centers must not be null");
    RAFT_EXPECTS(centers_rot_tensor != nullptr, "centers_rot must not be null");
    RAFT_EXPECTS(rotation_matrix_tensor != nullptr, "rotation_matrix must not be null");

    auto res_ptr = reinterpret_cast<raft::resources*>(res);

    auto build_params = cuvs::neighbors::ivf_pq::index_params();
    convert_c_index_params(*params, &build_params);

    // Verify all tensors are on device
    RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(pq_centers_tensor->dl_tensor),
                 "pq_centers should have device compatible memory");
    RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(centers_tensor->dl_tensor),
                 "centers should have device compatible memory");
    RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(centers_rot_tensor->dl_tensor),
                 "centers_rot should have device compatible memory");
    RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(rotation_matrix_tensor->dl_tensor),
                 "rotation_matrix should have device compatible memory");

    // Verify all tensors are float32
    const auto& pq_centers_dl      = pq_centers_tensor->dl_tensor;
    const auto& centers_dl         = centers_tensor->dl_tensor;
    const auto& centers_rot_dl     = centers_rot_tensor->dl_tensor;
    const auto& rotation_matrix_dl = rotation_matrix_tensor->dl_tensor;

    RAFT_EXPECTS(pq_centers_dl.dtype.code == kDLFloat && pq_centers_dl.dtype.bits == 32,
                 "pq_centers must be float32");
    RAFT_EXPECTS(centers_dl.dtype.code == kDLFloat && centers_dl.dtype.bits == 32,
                 "centers must be float32");
    RAFT_EXPECTS(centers_rot_dl.dtype.code == kDLFloat && centers_rot_dl.dtype.bits == 32,
                 "centers_rot must be float32");
    RAFT_EXPECTS(rotation_matrix_dl.dtype.code == kDLFloat && rotation_matrix_dl.dtype.bits == 32,
                 "rotation_matrix must be float32");

    // Convert DLPack tensors to mdspan views
    using pq_centers_mdspan_type =
      raft::device_mdspan<const float, raft::extent_3d<uint32_t>, raft::row_major>;
    using matrix_mdspan_type = raft::device_matrix_view<const float, uint32_t, raft::row_major>;

    auto pq_centers_mds      = cuvs::core::from_dlpack<pq_centers_mdspan_type>(pq_centers_tensor);
    auto centers_mds         = cuvs::core::from_dlpack<matrix_mdspan_type>(centers_tensor);
    auto centers_rot_mds     = cuvs::core::from_dlpack<matrix_mdspan_type>(centers_rot_tensor);
    auto rotation_matrix_mds = cuvs::core::from_dlpack<matrix_mdspan_type>(rotation_matrix_tensor);

    // Build the index. Ownership of the heap-allocated C++ index is transferred to the C handle
    // through `addr`.
    auto* idx = new cuvs::neighbors::ivf_pq::index<int64_t>(
      cuvs::neighbors::ivf_pq::build(*res_ptr,
                                     build_params,
                                     dim,
                                     pq_centers_mds,
                                     centers_mds,
                                     centers_rot_mds,
                                     rotation_matrix_mds));

    index->addr = reinterpret_cast<uintptr_t>(idx);
    // Leave dtype unset (0) - it will be set when extend() is called with actual data
    index->dtype.code = 0;
    index->dtype.bits = 0;
  });
}

extern "C" cuvsError_t cuvsIvfPqIndexGetListSizes(cuvsIvfPqIndex_t index,
DLManagedTensor* list_sizes)
{
Expand Down
4 changes: 3 additions & 1 deletion python/cuvs/cuvs/neighbors/ivf_pq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0


Expand All @@ -7,6 +7,7 @@
IndexParams,
SearchParams,
build,
build_precomputed,
extend,
load,
save,
Expand All @@ -18,6 +19,7 @@
"IndexParams",
"SearchParams",
"build",
"build_precomputed",
"extend",
"load",
"save",
Expand Down
18 changes: 18 additions & 0 deletions python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,21 @@ cdef extern from "cuvs/neighbors/ivf_pq.h" nogil:
cuvsError_t cuvsIvfPqIndexGetCenters(cuvsIvfPqIndex_t index,
DLManagedTensor * centers)

# Padded cluster centers [n_lists, dim_ext]; see cuvs/neighbors/ivf_pq.h.
cuvsError_t cuvsIvfPqIndexGetCentersPadded(cuvsIvfPqIndex_t index,
DLManagedTensor * centers)

cuvsError_t cuvsIvfPqIndexGetListSizes(cuvsIvfPqIndex_t index,
DLManagedTensor * list_sizes)

cuvsError_t cuvsIvfPqIndexGetPqCenters(cuvsIvfPqIndex_t index,
DLManagedTensor * centers)

# Rotated cluster centers [n_lists, rot_dim]; see cuvs/neighbors/ivf_pq.h.
cuvsError_t cuvsIvfPqIndexGetCentersRot(cuvsIvfPqIndex_t index,
DLManagedTensor * centers_rot)

# Rotation matrix [rot_dim, dim]; see cuvs/neighbors/ivf_pq.h.
cuvsError_t cuvsIvfPqIndexGetRotationMatrix(cuvsIvfPqIndex_t index,
DLManagedTensor * rotation_matrix)

cuvsError_t cuvsIvfPqIndexUnpackContiguousListData(cuvsResources_t res,
cuvsIvfPqIndex_t index,
DLManagedTensor* out,
Expand All @@ -113,6 +122,15 @@ cdef extern from "cuvs/neighbors/ivf_pq.h" nogil:
DLManagedTensor* dataset,
cuvsIvfPqIndex_t index)

# Builds a view-type index from precomputed centers, rotated centers, rotation
# matrix and PQ codebook; see cuvs/neighbors/ivf_pq.h for the expected shapes.
cuvsError_t cuvsIvfPqBuildPrecomputed(cuvsResources_t res,
cuvsIvfPqIndexParams_t params,
uint32_t dim,
DLManagedTensor* pq_centers,
DLManagedTensor* centers,
DLManagedTensor* centers_rot,
DLManagedTensor* rotation_matrix,
cuvsIvfPqIndex_t index)

cuvsError_t cuvsIvfPqSearch(cuvsResources_t res,
cuvsIvfPqSearchParams* params,
cuvsIvfPqIndex_t index,
Expand Down
Loading