Skip to content

Commit bd6d4a9

Browse files
authored
Fix duplicate indices in batch NN Descent (#702)
### Purpose of this PR

Handling duplicate indices in batch NN Descent graph. Resolves the following issues:

- rapidsai/raft#2450
- #559
- #753

### Notes

- Also fixed in RAFT [here](rapidsai/raft#2586) for current use with cuML

Authors:
- Jinsol Park (https://github.com/jinsolp)
- Corey J. Nolet (https://github.com/cjnolet)

Approvers:
- Divye Gala (https://github.com/divyegala)

URL: #702
1 parent f14266c commit bd6d4a9

File tree

4 files changed

+17
-15
lines changed

4 files changed

+17
-15
lines changed

cpp/src/neighbors/detail/nn_descent.cuh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ class BloomFilter {
244244
}
245245
}
246246

247+
void set_nrow(size_t nrow) { nrow_ = nrow; }
248+
247249
bool check(size_t list_id, Index_t key)
248250
{
249251
bool is_present = true;
@@ -1273,7 +1275,9 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
12731275
cudaStream_t stream = raft::resource::get_cuda_stream(res);
12741276
nrow_ = nrow;
12751277
graph_.nrow = nrow;
1276-
graph_.h_graph = (InternalID_t<Index_t>*)output_graph;
1278+
graph_.bloom_filter.set_nrow(nrow);
1279+
update_counter_ = 0;
1280+
graph_.h_graph = (InternalID_t<Index_t>*)output_graph;
12771281

12781282
cudaPointerAttributes data_ptr_attr;
12791283
RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data));

cpp/src/neighbors/detail/nn_descent_batch.cuh

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,10 @@ void get_global_nearest_k(
170170

171171
if (i == num_batches - 1) { batch_size_ = num_rows - batch_size * i; }
172172
thrust::copy(raft::resource::get_thrust_policy(res),
173-
nearest_clusters_idx.data_handle() + i * batch_size_ * k,
174-
nearest_clusters_idx.data_handle() + (i + 1) * batch_size_ * k,
173+
nearest_clusters_idx.data_handle() + i * batch_size * k,
174+
nearest_clusters_idx.data_handle() + (i * batch_size + batch_size_) * k,
175175
nearest_clusters_idxt.data_handle());
176-
raft::copy(global_nearest_cluster.data_handle() + i * batch_size_ * k,
176+
raft::copy(global_nearest_cluster.data_handle() + i * batch_size * k,
177177
nearest_clusters_idxt.data_handle(),
178178
batch_size_ * k,
179179
resource::get_cuda_stream(res));
@@ -650,7 +650,8 @@ void batch_build(raft::resources const& res,
650650
.internal_node_degree = extended_intermediate_degree,
651651
.max_iterations = params.max_iterations,
652652
.termination_threshold = params.termination_threshold,
653-
.output_graph_degree = graph_degree};
653+
.output_graph_degree = graph_degree,
654+
.metric = params.metric};
654655

655656
auto global_indices_h = raft::make_managed_matrix<IdxT, int64_t>(res, num_rows, graph_degree);
656657
auto global_distances_h = raft::make_managed_matrix<float, int64_t>(res, num_rows, graph_degree);

cpp/tests/neighbors/ann_nn_descent.cuh

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ class AnnNNDescentBatchTest : public ::testing::TestWithParam<AnnNNDescentBatchI
241241
index_params.metric = ps.metric;
242242
index_params.graph_degree = ps.graph_degree;
243243
index_params.intermediate_graph_degree = 2 * ps.graph_degree;
244-
index_params.max_iterations = 10;
244+
index_params.max_iterations = 100;
245245
index_params.return_distances = true;
246246
index_params.n_clusters = ps.recall_cluster.second;
247247

@@ -287,8 +287,7 @@ class AnnNNDescentBatchTest : public ::testing::TestWithParam<AnnNNDescentBatchI
287287
ps.graph_degree,
288288
0.01,
289289
min_recall,
290-
true,
291-
static_cast<size_t>(ps.graph_degree * 0.1)));
290+
true));
292291
}
293292
}
294293

@@ -328,8 +327,6 @@ const std::vector<AnnNNDescentInputs> inputs =
328327
{false, true},
329328
{0.90});
330329

331-
// TODO : Investigate why this test is failing Reference issue https
332-
// : // github.com/rapidsai/raft/issues/2450
333330
const std::vector<AnnNNDescentBatchInputs> inputsBatch =
334331
raft::util::itertools::product<AnnNNDescentBatchInputs>(
335332
{std::make_pair(0.9, 3lu), std::make_pair(0.9, 2lu)}, // min_recall, n_clusters

cpp/tests/neighbors/ann_nn_descent/test_float_uint32_t.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ namespace cuvs::neighbors::nn_descent {
2323
typedef AnnNNDescentTest<float, float, std::uint32_t> AnnNNDescentTestF_U32;
2424
TEST_P(AnnNNDescentTestF_U32, AnnNNDescent) { this->testNNDescent(); }
2525

26-
// typedef AnnNNDescentBatchTest<float, float, std::uint32_t> AnnNNDescentBatchTestF_U32;
27-
// TEST_P(AnnNNDescentBatchTestF_U32, AnnNNDescentBatch) { this->testNNDescentBatch(); }
26+
typedef AnnNNDescentBatchTest<float, float, std::uint32_t> AnnNNDescentBatchTestF_U32;
27+
TEST_P(AnnNNDescentBatchTestF_U32, AnnNNDescentBatch) { this->testNNDescentBatch(); }
2828

2929
INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestF_U32, ::testing::ValuesIn(inputs));
30-
// INSTANTIATE_TEST_CASE_P(AnnNNDescentBatchTest,
31-
// AnnNNDescentBatchTestF_U32,
32-
// ::testing::ValuesIn(inputsBatch));
30+
INSTANTIATE_TEST_CASE_P(AnnNNDescentBatchTest,
31+
AnnNNDescentBatchTestF_U32,
32+
::testing::ValuesIn(inputsBatch));
3333

3434
} // namespace cuvs::neighbors::nn_descent

0 commit comments

Comments
 (0)