Merged
Changes from 4 commits
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
@@ -462,6 +462,7 @@ if(NOT BUILD_CPU_ONLY)
 src/neighbors/composite/merge.cpp
 $<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:src/neighbors/cagra.cpp>
 $<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:src/neighbors/hnsw.cpp>
+src/neighbors/ivf_common.cu
 src/neighbors/ivf_flat_index.cpp
 src/neighbors/ivf_flat/ivf_flat_build_extend_float_int64_t.cu
 src/neighbors/ivf_flat/ivf_flat_build_extend_half_int64_t.cu
64 changes: 64 additions & 0 deletions cpp/src/neighbors/ivf_common.cu
@divyegala (Member) commented on Jan 5, 2026:
The kernel needs to be launched in the same TU as the one in which it is defined. We can pass the pointer around to other TUs (though ideally we should avoid even that), but those TUs shouldn't attempt to launch the kernel themselves.
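A minimal sketch of the pattern being described, with hypothetical names (this is not the PR's code): the kernel and the address-taking code live together in one .cu file, and everything outside that TU only ever sees an opaque void*.

// scale_kernel.cu -- hypothetical defining TU
#include <cstdint>

namespace example {

// The __global__ symbol stays private to this TU.
template <int BlockDim>
__global__ void scale_kernel(float* x, float s, uint32_t n)
{
  uint32_t i = blockIdx.x * BlockDim + threadIdx.x;
  if (i < n) { x[i] *= s; }
}

// Take the kernel's address here, in the defining TU, and hand out an
// opaque pointer. Other TUs may store or forward this pointer, but per
// the comment above they should not launch the kernel through it.
template <int BlockDim>
void* get_scale_kernel()
{
  return reinterpret_cast<void*>(scale_kernel<BlockDim>);
}

template void* get_scale_kernel<256>();

}  // namespace example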

@@ -0,0 +1,64 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include <raft/core/detail/macros.hpp>  // for RAFT_KERNEL
#include <raft/util/pow2_utils.cuh>

#include <cub/cub.cuh>

#include <cstdint>

namespace cuvs::neighbors::ivf::detail {

/**
* For each query, we calculate a cumulative sum of the cluster sizes that we probe, and return that
* in chunk_indices. Essentially this is a segmented inclusive scan of the cluster sizes. The total
* number of samples per query (sum of the cluster sizes that we probe) is returned in n_samples.
*/
template <int BlockDim>
__launch_bounds__(BlockDim) RAFT_KERNEL
  calc_chunk_indices_kernel(uint32_t n_probes,
                            const uint32_t* cluster_sizes,      // [n_clusters]
                            const uint32_t* clusters_to_probe,  // [n_queries, n_probes]
                            uint32_t* chunk_indices,            // [n_queries, n_probes]
                            uint32_t* n_samples                 // [n_queries]
  )
{
  using block_scan = cub::BlockScan<uint32_t, BlockDim>;
  __shared__ typename block_scan::TempStorage shm;

  // locate the query data; one block per query
  clusters_to_probe += n_probes * blockIdx.x;
  chunk_indices += n_probes * blockIdx.x;

  // block scan
  const uint32_t n_probes_aligned = raft::Pow2<BlockDim>::roundUp(n_probes);
  uint32_t total = 0;
  for (uint32_t probe_ix = threadIdx.x; probe_ix < n_probes_aligned; probe_ix += BlockDim) {
    auto label = probe_ix < n_probes ? clusters_to_probe[probe_ix] : 0u;
    auto chunk = probe_ix < n_probes ? cluster_sizes[label] : 0u;
    // thread 0 seeds this tile's scan with the carry from the previous tile
    if (threadIdx.x == 0) { chunk += total; }
    block_scan(shm).InclusiveSum(chunk, chunk, total);
    __syncthreads();  // make TempStorage safe to reuse in the next iteration
    if (probe_ix < n_probes) { chunk_indices[probe_ix] = chunk; }
  }
  // save the total size
  if (threadIdx.x == 0) { n_samples[blockIdx.x] = total; }
}

/**
* Returns a pointer to calc_chunk_indices_kernel for the given BlockDim.
*/
template <int BlockDim>
void* get_calc_chunk_indices_kernel()
{
  return reinterpret_cast<void*>(calc_chunk_indices_kernel<BlockDim>);
}

template void* get_calc_chunk_indices_kernel<32>();
template void* get_calc_chunk_indices_kernel<64>();
template void* get_calc_chunk_indices_kernel<128>();
template void* get_calc_chunk_indices_kernel<256>();
template void* get_calc_chunk_indices_kernel<512>();
template void* get_calc_chunk_indices_kernel<1024>();

} // namespace cuvs::neighbors::ivf::detail
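As a reading aid (not part of the PR), here is a hypothetical single-threaded host reference for what calc_chunk_indices_kernel computes; with cluster_sizes = {3, 5, 2} and a query probing clusters {2, 0}, that query's chunk_indices row becomes {2, 5} and its n_samples entry is 5.

#include <cstdint>

// Hypothetical host-side reference, e.g. for unit tests.
void calc_chunk_indices_ref(uint32_t n_queries,
                            uint32_t n_probes,
                            const uint32_t* cluster_sizes,      // [n_clusters]
                            const uint32_t* clusters_to_probe,  // [n_queries, n_probes]
                            uint32_t* chunk_indices,            // [n_queries, n_probes]
                            uint32_t* n_samples)                // [n_queries]
{
  for (uint32_t q = 0; q < n_queries; ++q) {
    uint32_t total = 0;
    for (uint32_t p = 0; p < n_probes; ++p) {
      // segmented inclusive scan over the sizes of the probed clusters
      total += cluster_sizes[clusters_to_probe[q * n_probes + p]];
      chunk_indices[q * n_probes + p] = total;
    }
    n_samples[q] = total;
  }
}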
44 changes: 12 additions & 32 deletions cpp/src/neighbors/ivf_common.cuh
@@ -30,40 +30,20 @@ struct dummy_block_sort_t {
 };
 
 /**
- * For each query, we calculate a cumulative sum of the cluster sizes that we probe, and return that
- * in chunk_indices. Essentially this is a segmented inclusive scan of the cluster sizes. The total
- * number of samples per query (sum of the cluster sizes that we probe) is returned in n_samples.
+ * Returns a pointer to calc_chunk_indices_kernel for the given BlockDim.
+ * The kernel is defined and the pointer is taken in ivf_common.cu to comply
+ * with CUDA whole compilation rules. See
+ * https://developer.nvidia.com/blog/cuda-c-compiler-updates-impacting-elf-visibility-and-linkage/
  */
 template <int BlockDim>
-__launch_bounds__(BlockDim) RAFT_KERNEL
-  calc_chunk_indices_kernel(uint32_t n_probes,
-                            const uint32_t* cluster_sizes,      // [n_clusters]
-                            const uint32_t* clusters_to_probe,  // [n_queries, n_probes]
-                            uint32_t* chunk_indices,            // [n_queries, n_probes]
-                            uint32_t* n_samples                 // [n_queries]
-  )
-{
-  using block_scan = cub::BlockScan<uint32_t, BlockDim>;
-  __shared__ typename block_scan::TempStorage shm;
-
-  // locate the query data
-  clusters_to_probe += n_probes * blockIdx.x;
-  chunk_indices += n_probes * blockIdx.x;
-
-  // block scan
-  const uint32_t n_probes_aligned = raft::Pow2<BlockDim>::roundUp(n_probes);
-  uint32_t total = 0;
-  for (uint32_t probe_ix = threadIdx.x; probe_ix < n_probes_aligned; probe_ix += BlockDim) {
-    auto label = probe_ix < n_probes ? clusters_to_probe[probe_ix] : 0u;
-    auto chunk = probe_ix < n_probes ? cluster_sizes[label] : 0u;
-    if (threadIdx.x == 0) { chunk += total; }
-    block_scan(shm).InclusiveSum(chunk, chunk, total);
-    __syncthreads();
-    if (probe_ix < n_probes) { chunk_indices[probe_ix] = chunk; }
-  }
-  // save the total size
-  if (threadIdx.x == 0) { n_samples[blockIdx.x] = total; }
-}
+void* get_calc_chunk_indices_kernel();
+
+extern template void* get_calc_chunk_indices_kernel<32>();
+extern template void* get_calc_chunk_indices_kernel<64>();
+extern template void* get_calc_chunk_indices_kernel<128>();
+extern template void* get_calc_chunk_indices_kernel<256>();
+extern template void* get_calc_chunk_indices_kernel<512>();
+extern template void* get_calc_chunk_indices_kernel<1024>();

 struct calc_chunk_indices {
  public:
@@ -97,7 +77,7 @@ struct calc_chunk_indices {
     if constexpr (BlockDim >= raft::WarpSize * 2) {
       if (BlockDim >= n_probes * 2) { return try_block_dim<(BlockDim / 2)>(n_probes, n_queries); }
     }
-    return {reinterpret_cast<void*>(calc_chunk_indices_kernel<BlockDim>),
+    return {get_calc_chunk_indices_kernel<BlockDim>(),
             dim3(BlockDim, 1, 1),
             dim3(n_queries, 1, 1),
             n_probes};
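For context, one way the opaque pointer returned by get_calc_chunk_indices_kernel can be consumed is through the CUDA runtime API rather than a <<<...>>> launch of the (now TU-local) kernel symbol. A hedged sketch; the helper name and its placement are assumptions, not part of the PR:

#include <cstdint>
#include <cuda_runtime_api.h>

// Hypothetical launch helper. 'kernel' is the void* obtained in
// ivf_common.cu; grid/block come from calc_chunk_indices::try_block_dim.
inline cudaError_t launch_calc_chunk_indices(void* kernel,
                                             dim3 grid_dim,
                                             dim3 block_dim,
                                             uint32_t n_probes,
                                             const uint32_t* cluster_sizes,
                                             const uint32_t* clusters_to_probe,
                                             uint32_t* chunk_indices,
                                             uint32_t* n_samples,
                                             cudaStream_t stream)
{
  // cudaLaunchKernel expects an array of pointers to the kernel's
  // arguments, in declaration order.
  void* args[] = {&n_probes, &cluster_sizes, &clusters_to_probe, &chunk_indices, &n_samples};
  return cudaLaunchKernel(kernel, grid_dim, block_dim, args, /*sharedMem=*/0, stream);
}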