11/*
2- * SPDX-FileCopyrightText: Copyright (c) 2023-2024 , NVIDIA CORPORATION.
2+ * SPDX-FileCopyrightText: Copyright (c) 2023-2026 , NVIDIA CORPORATION.
33 * SPDX-License-Identifier: Apache-2.0
44 */
55#pragma once
@@ -76,11 +76,11 @@ template <typename T>
7676RAFT_DEVICE_INLINE_FUNCTION auto team_sum (T x, uint32_t team_size_bitshift) -> T
7777{
7878 switch (team_size_bitshift) {
79- case 5 : x += raft::shfl_xor (x, 16 );
80- case 4 : x += raft::shfl_xor (x, 8 );
81- case 3 : x += raft::shfl_xor (x, 4 );
82- case 2 : x += raft::shfl_xor (x, 2 );
83- case 1 : x += raft::shfl_xor (x, 1 );
79+ case 5 : x += raft::shfl_xor (x, 16 ); [[fallthrough]];
80+ case 4 : x += raft::shfl_xor (x, 8 ); [[fallthrough]];
81+ case 3 : x += raft::shfl_xor (x, 4 ); [[fallthrough]];
82+ case 2 : x += raft::shfl_xor (x, 2 ); [[fallthrough]];
83+ case 1 : x += raft::shfl_xor (x, 1 ); [[fallthrough]];
8484 default : return x;
8585 }
8686}
@@ -106,7 +106,6 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
106106{
107107 const auto team_size_bits = dataset_desc.team_size_bitshift_from_smem ();
108108 const auto max_i = raft::round_up_safe<uint32_t >(num_pickup, warp_size >> team_size_bits);
109- const auto compute_distance = dataset_desc.compute_distance_impl ;
110109
111110 for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += (blockDim.x >> team_size_bits)) {
112111 const bool valid_i = (i < num_pickup);
@@ -115,7 +114,7 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
115114 DistanceT best_norm2_team_local = raft::upper_bound<DistanceT>();
116115 for (uint32_t j = 0 ; j < num_distilation; j++) {
117116 // Select a node randomly and compute the distance to it
118- IndexT seed_index;
117+ IndexT seed_index = 0 ;
119118 if (valid_i) {
120119 // uint32_t gid = i + (num_pickup * (j + (num_distilation * block_id)));
121120 uint32_t gid = block_id + (num_blocks * (i + (num_pickup * j)));
0 commit comments