Skip to content

Commit 2f3ecc3

Browse files
committed
fix sample with replacement in case of isolated nodes
1 parent 87c88d9 commit 2f3ecc3

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

csrc/cpu/sample_cpu.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,17 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
6363
row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
6464
row_count = row_end - row_start;
6565

66-
for (int64_t j = 0; j < num_neighbors; j++) {
67-
e = row_start + rand() % row_count;
68-
c = col_data[e];
69-
70-
if (n_id_map.count(c) == 0) {
71-
n_id_map[c] = n_ids.size();
72-
n_ids.push_back(c);
66+
if (row_count > 0) {
67+
for (int64_t j = 0; j < num_neighbors; j++) {
68+
e = row_start + rand() % row_count;
69+
c = col_data[e];
70+
71+
if (n_id_map.count(c) == 0) {
72+
n_id_map[c] = n_ids.size();
73+
n_ids.push_back(c);
74+
}
75+
cols[i].push_back(std::make_tuple(n_id_map[c], e));
7376
}
74-
cols[i].push_back(std::make_tuple(n_id_map[c], e));
7577
}
7678
out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
7779
}

0 commit comments

Comments
 (0)