Skip to content

Commit 6411798

Browse files
authored
Fixes for stricter compilers (#1703)
Added [[fallthrough]] attribute to switch cases in team_sum function for clarity and correctness. Initialized seed_index to 0 to avoid potential uninitialized variable usage and remove an unused variable Authors: - Max Buckley (https://github.com/maxwbuckley) - MithunR (https://github.com/mythrocks) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #1703
1 parent f1e19af commit 6411798

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

cpp/src/neighbors/detail/cagra/device_common.hpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55
#pragma once
@@ -76,11 +76,11 @@ template <typename T>
7676
RAFT_DEVICE_INLINE_FUNCTION auto team_sum(T x, uint32_t team_size_bitshift) -> T
7777
{
7878
switch (team_size_bitshift) {
79-
case 5: x += raft::shfl_xor(x, 16);
80-
case 4: x += raft::shfl_xor(x, 8);
81-
case 3: x += raft::shfl_xor(x, 4);
82-
case 2: x += raft::shfl_xor(x, 2);
83-
case 1: x += raft::shfl_xor(x, 1);
79+
case 5: x += raft::shfl_xor(x, 16); [[fallthrough]];
80+
case 4: x += raft::shfl_xor(x, 8); [[fallthrough]];
81+
case 3: x += raft::shfl_xor(x, 4); [[fallthrough]];
82+
case 2: x += raft::shfl_xor(x, 2); [[fallthrough]];
83+
case 1: x += raft::shfl_xor(x, 1); [[fallthrough]];
8484
default: return x;
8585
}
8686
}
@@ -106,7 +106,6 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
106106
{
107107
const auto team_size_bits = dataset_desc.team_size_bitshift_from_smem();
108108
const auto max_i = raft::round_up_safe<uint32_t>(num_pickup, warp_size >> team_size_bits);
109-
const auto compute_distance = dataset_desc.compute_distance_impl;
110109

111110
for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += (blockDim.x >> team_size_bits)) {
112111
const bool valid_i = (i < num_pickup);
@@ -115,7 +114,7 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
115114
DistanceT best_norm2_team_local = raft::upper_bound<DistanceT>();
116115
for (uint32_t j = 0; j < num_distilation; j++) {
117116
// Select a node randomly and compute the distance to it
118-
IndexT seed_index;
117+
IndexT seed_index = 0;
119118
if (valid_i) {
120119
// uint32_t gid = i + (num_pickup * (j + (num_distilation * block_id)));
121120
uint32_t gid = block_id + (num_blocks * (i + (num_pickup * j)));

0 commit comments

Comments
 (0)