Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions cpp/src/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
Expand Down Expand Up @@ -76,11 +76,11 @@ template <typename T>
RAFT_DEVICE_INLINE_FUNCTION auto team_sum(T x, uint32_t team_size_bitshift) -> T
{
switch (team_size_bitshift) {
case 5: x += raft::shfl_xor(x, 16);
case 4: x += raft::shfl_xor(x, 8);
case 3: x += raft::shfl_xor(x, 4);
case 2: x += raft::shfl_xor(x, 2);
case 1: x += raft::shfl_xor(x, 1);
case 5: x += raft::shfl_xor(x, 16); [[fallthrough]];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

case 4: x += raft::shfl_xor(x, 8); [[fallthrough]];
case 3: x += raft::shfl_xor(x, 4); [[fallthrough]];
case 2: x += raft::shfl_xor(x, 2); [[fallthrough]];
case 1: x += raft::shfl_xor(x, 1); [[fallthrough]];
default: return x;
}
}
Expand All @@ -106,7 +106,6 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
{
const auto team_size_bits = dataset_desc.team_size_bitshift_from_smem();
const auto max_i = raft::round_up_safe<uint32_t>(num_pickup, warp_size >> team_size_bits);
const auto compute_distance = dataset_desc.compute_distance_impl;

for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += (blockDim.x >> team_size_bits)) {
const bool valid_i = (i < num_pickup);
Expand All @@ -115,7 +114,7 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
DistanceT best_norm2_team_local = raft::upper_bound<DistanceT>();
for (uint32_t j = 0; j < num_distilation; j++) {
// Select a node randomly and compute the distance to it
IndexT seed_index;
IndexT seed_index = 0;
Copy link
Contributor

@mythrocks mythrocks Jan 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not completely convinced that there is a problem here. seed_index's uninitialized value isn't read.

But this does indicate that this can be tightened up.

if (valid_i) {
// uint32_t gid = i + (num_pickup * (j + (num_distilation * block_id)));
uint32_t gid = block_id + (num_blocks * (i + (num_pickup * j)));
Expand Down