1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
@@ -467,6 +467,7 @@ if(NOT BUILD_CPU_ONLY)
src/neighbors/composite/merge.cpp
$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:src/neighbors/cagra.cpp>
$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:src/neighbors/hnsw.cpp>
src/neighbors/ivf_common.cu
src/neighbors/ivf_flat_index.cpp
src/neighbors/ivf_flat/ivf_flat_build_extend_float_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_build_extend_half_int64_t.cu
75 changes: 75 additions & 0 deletions cpp/src/neighbors/ivf_common.cu
@divyegala (Member) commented on Jan 5, 2026:
The kernel needs to be launched in the same TU in which it is defined. We can pass the pointer around to other TUs (though we should ideally avoid that), but they shouldn't be attempting to launch the kernel.
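To illustrate the pattern (a minimal sketch with hypothetical names, not code from this PR): the kernel and the code that takes its address and launches it live in a single .cu translation unit, and other TUs only ever see an ordinary host function declaration.

```cuda
// my_kernels.cu -- hypothetical TU that owns the kernel
#include <cuda_runtime.h>

__global__ void scale_kernel(float* data, int n, float alpha)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { data[i] *= alpha; }
}

// Host wrapper defined in the SAME TU: taking the kernel's address and launching it
// both happen here, which is what the visibility/linkage rules require.
void launch_scale(float* data, int n, float alpha, cudaStream_t stream)
{
  void* args[] = {&data, &n, &alpha};
  dim3 block(256, 1, 1);
  dim3 grid((n + 255) / 256, 1, 1);
  cudaLaunchKernel(reinterpret_cast<void*>(scale_kernel), grid, block, args, 0, stream);
}

// my_kernels.hpp -- all that other TUs see is the host declaration:
// void launch_scale(float* data, int n, float alpha, cudaStream_t stream);
```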

@@ -0,0 +1,75 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include "ivf_common.cuh"

#include <raft/core/cudart_utils.hpp>
#include <raft/util/pow2_utils.cuh>

#include <cub/cub.cuh>

namespace cuvs::neighbors::ivf::detail {

/**
* For each query, we calculate a cumulative sum of the cluster sizes that we probe, and return that
* in chunk_indices. Essentially this is a segmented inclusive scan of the cluster sizes. The total
* number of samples per query (sum of the cluster sizes that we probe) is returned in n_samples.
*/
template <int BlockDim>
__launch_bounds__(BlockDim) RAFT_KERNEL
  calc_chunk_indices_kernel(uint32_t n_probes,
                             const uint32_t* cluster_sizes,      // [n_clusters]
                             const uint32_t* clusters_to_probe,  // [n_queries, n_probes]
                             uint32_t* chunk_indices,            // [n_queries, n_probes]
                             uint32_t* n_samples                 // [n_queries]
  )
{
  using block_scan = cub::BlockScan<uint32_t, BlockDim>;
  __shared__ typename block_scan::TempStorage shm;

  // locate the query data
  clusters_to_probe += n_probes * blockIdx.x;
  chunk_indices += n_probes * blockIdx.x;

  // block scan
  const uint32_t n_probes_aligned = raft::Pow2<BlockDim>::roundUp(n_probes);
  uint32_t total = 0;
  for (uint32_t probe_ix = threadIdx.x; probe_ix < n_probes_aligned; probe_ix += BlockDim) {
    auto label = probe_ix < n_probes ? clusters_to_probe[probe_ix] : 0u;
    auto chunk = probe_ix < n_probes ? cluster_sizes[label] : 0u;
    if (threadIdx.x == 0) { chunk += total; }
    block_scan(shm).InclusiveSum(chunk, chunk, total);
    __syncthreads();
    if (probe_ix < n_probes) { chunk_indices[probe_ix] = chunk; }
  }
  // save the total size
  if (threadIdx.x == 0) { n_samples[blockIdx.x] = total; }
}

void calc_chunk_indices::configured::operator()(const uint32_t* cluster_sizes,
                                                const uint32_t* clusters_to_probe,
                                                uint32_t* chunk_indices,
                                                uint32_t* n_samples,
                                                rmm::cuda_stream_view stream)
{
  void* kernel = nullptr;
  switch (block_dim.x) {
    case 32: kernel = reinterpret_cast<void*>(calc_chunk_indices_kernel<32>); break;
    case 64: kernel = reinterpret_cast<void*>(calc_chunk_indices_kernel<64>); break;
    case 128: kernel = reinterpret_cast<void*>(calc_chunk_indices_kernel<128>); break;
    case 256: kernel = reinterpret_cast<void*>(calc_chunk_indices_kernel<256>); break;
    case 512: kernel = reinterpret_cast<void*>(calc_chunk_indices_kernel<512>); break;
    case 1024: kernel = reinterpret_cast<void*>(calc_chunk_indices_kernel<1024>); break;
    default:
      RAFT_FAIL("Unsupported block dimension for calc_chunk_indices::configured() : %d",
                block_dim.x);
  }

  void* args[] =  // NOLINT
    {&n_probes, &cluster_sizes, &clusters_to_probe, &chunk_indices, &n_samples};
  RAFT_CUDA_TRY(cudaLaunchKernel(kernel, grid_dim, block_dim, args, 0, stream));
}

} // namespace cuvs::neighbors::ivf::detail
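For intuition about how this entry point is driven and what the segmented inclusive scan produces, here is a hedged host-side sketch; the wrapper function, buffer handling, and the concrete numbers are illustrative assumptions, only configure()/operator() and the namespace come from the code above.

```cuda
#include "ivf_common.cuh"

#include <rmm/cuda_stream_view.hpp>

#include <cstdint>

// Hypothetical caller; the device buffers are assumed to be allocated and filled
// elsewhere (e.g. in rmm::device_uvector<uint32_t> containers), which is not shown.
void example_probe_sizes(const uint32_t* cluster_sizes,      // [n_clusters]
                         const uint32_t* clusters_to_probe,  // [n_queries, n_probes]
                         uint32_t* chunk_indices,            // [n_queries, n_probes], output
                         uint32_t* n_samples,                // [n_queries], output
                         uint32_t n_probes,
                         uint32_t n_queries,
                         rmm::cuda_stream_view stream)
{
  using namespace cuvs::neighbors::ivf::detail;
  // configure() only picks the launch geometry (one block per query)...
  auto launcher = calc_chunk_indices::configure(n_probes, n_queries);
  // ...and operator() resolves the kernel pointer and launches it, entirely inside ivf_common.cu.
  launcher(cluster_sizes, clusters_to_probe, chunk_indices, n_samples, stream);

  // Worked example: with n_queries = 1, n_probes = 3, clusters_to_probe = {4, 0, 2}
  // and cluster_sizes = {10, 7, 5, 3, 8}, the probed sizes are {8, 10, 5}, so the
  // kernel writes chunk_indices = {8, 18, 23} and n_samples = {23}.
}
```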
63 changes: 13 additions & 50 deletions cpp/src/neighbors/ivf_common.cuh
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

@@ -30,59 +30,25 @@ struct dummy_block_sort_t {
};

/**
* For each query, we calculate a cumulative sum of the cluster sizes that we probe, and return that
* in chunk_indices. Essentially this is a segmented inclusive scan of the cluster sizes. The total
* number of samples per query (sum of the cluster sizes that we probe) is returned in n_samples.
* Struct to configure and launch calc_chunk_indices_kernel.
*
* Both configure() and operator() are defined in ivf_common.cu to comply
* with CUDA whole compilation rules - the kernel pointer must be obtained
* and used within the same translation unit. See
* https://developer.nvidia.com/blog/cuda-c-compiler-updates-impacting-elf-visibility-and-linkage/
*/
template <int BlockDim>
__launch_bounds__(BlockDim) RAFT_KERNEL
  calc_chunk_indices_kernel(uint32_t n_probes,
                             const uint32_t* cluster_sizes,      // [n_clusters]
                             const uint32_t* clusters_to_probe,  // [n_queries, n_probes]
                             uint32_t* chunk_indices,            // [n_queries, n_probes]
                             uint32_t* n_samples                 // [n_queries]
  )
{
  using block_scan = cub::BlockScan<uint32_t, BlockDim>;
  __shared__ typename block_scan::TempStorage shm;

  // locate the query data
  clusters_to_probe += n_probes * blockIdx.x;
  chunk_indices += n_probes * blockIdx.x;

  // block scan
  const uint32_t n_probes_aligned = raft::Pow2<BlockDim>::roundUp(n_probes);
  uint32_t total = 0;
  for (uint32_t probe_ix = threadIdx.x; probe_ix < n_probes_aligned; probe_ix += BlockDim) {
    auto label = probe_ix < n_probes ? clusters_to_probe[probe_ix] : 0u;
    auto chunk = probe_ix < n_probes ? cluster_sizes[label] : 0u;
    if (threadIdx.x == 0) { chunk += total; }
    block_scan(shm).InclusiveSum(chunk, chunk, total);
    __syncthreads();
    if (probe_ix < n_probes) { chunk_indices[probe_ix] = chunk; }
  }
  // save the total size
  if (threadIdx.x == 0) { n_samples[blockIdx.x] = total; }
}

struct calc_chunk_indices {
 public:
  struct configured {
    void* kernel;
    dim3 block_dim;
    dim3 grid_dim;
    uint32_t n_probes;

    inline void operator()(const uint32_t* cluster_sizes,
                           const uint32_t* clusters_to_probe,
                           uint32_t* chunk_indices,
                           uint32_t* n_samples,
                           rmm::cuda_stream_view stream)
    {
      void* args[] =  // NOLINT
        {&n_probes, &cluster_sizes, &clusters_to_probe, &chunk_indices, &n_samples};
      RAFT_CUDA_TRY(cudaLaunchKernel(kernel, grid_dim, block_dim, args, 0, stream));
    }
    void operator()(const uint32_t* cluster_sizes,
                    const uint32_t* clusters_to_probe,
                    uint32_t* chunk_indices,
                    uint32_t* n_samples,
                    rmm::cuda_stream_view stream);
  };

  static inline auto configure(uint32_t n_probes, uint32_t n_queries) -> configured
@@ -97,10 +63,7 @@ struct calc_chunk_indices {
    if constexpr (BlockDim >= raft::WarpSize * 2) {
      if (BlockDim >= n_probes * 2) { return try_block_dim<(BlockDim / 2)>(n_probes, n_queries); }
    }
    return {reinterpret_cast<void*>(calc_chunk_indices_kernel<BlockDim>),
            dim3(BlockDim, 1, 1),
            dim3(n_queries, 1, 1),
            n_probes};
    return {dim3(BlockDim, 1, 1), dim3(n_queries, 1, 1), n_probes};
  }
};
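As a quick reference for how configure() picks the launch geometry, a small illustrative sketch; the expected values below are derived by hand from the recursion above assuming raft::WarpSize == 32, not verified output.

```cuda
#include "ivf_common.cuh"

// try_block_dim<1024> keeps halving while BlockDim >= 2 * n_probes and BlockDim is still
// at least two warps, so the block size roughly tracks n_probes (min 32, max 1024), and
// grid_dim is always one block per query.
void block_dim_examples()
{
  using cuvs::neighbors::ivf::detail::calc_chunk_indices;
  [[maybe_unused]] auto a = calc_chunk_indices::configure(10, 4);   // expect block_dim.x == 32
  [[maybe_unused]] auto b = calc_chunk_indices::configure(40, 4);   // expect block_dim.x == 64
  [[maybe_unused]] auto c = calc_chunk_indices::configure(200, 4);  // expect block_dim.x == 256
  [[maybe_unused]] auto d = calc_chunk_indices::configure(600, 4);  // expect block_dim.x == 1024
}
```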
