Skip to content

Commit 122811a

Browse files
ZenoTanrusty1s
andauthored
rand() -> torch::randint() (#217)
* torch randint * update * fix type * update * add test Co-authored-by: rusty1s <[email protected]>
1 parent 124bc09 commit 122811a

File tree

7 files changed

+45
-32
lines changed

7 files changed

+45
-32
lines changed

csrc/cpu/ego_sample_cpu.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
1919
torch::Tensor idx, int64_t depth,
2020
int64_t num_neighbors, bool replace) {
2121

22-
srand(time(NULL) + 1000 * getpid()); // Initialize random seed.
23-
2422
std::vector<torch::Tensor> out_rowptrs(idx.numel() + 1);
2523
std::vector<torch::Tensor> out_cols(idx.numel());
2624
std::vector<torch::Tensor> out_n_ids(idx.numel());
@@ -56,14 +54,14 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
5654
}
5755
} else if (replace) {
5856
for (int64_t j = 0; j < num_neighbors; j++) {
59-
w = col_data[row_start + (rand() % row_count)];
57+
w = col_data[row_start + uniform_randint(row_count)];
6058
n_id_set.insert(w);
6159
n_ids.push_back(w);
6260
}
6361
} else {
6462
std::unordered_set<int64_t> perm;
6563
for (int64_t j = row_count - num_neighbors; j < row_count; j++) {
66-
if (!perm.insert(rand() % j).second) {
64+
if (!perm.insert(uniform_randint(j)).second) {
6765
perm.insert(j);
6866
}
6967
}

csrc/cpu/hgt_sample_cpu.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,6 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
105105
const c10::Dict<node_t, vector<int64_t>> &num_samples_dict,
106106
const int64_t num_hops) {
107107

108-
srand(time(NULL) + 1000 * getpid()); // Initialize random seed.
109-
110108
// Create a mapping to convert single string relations to edge type triplets:
111109
unordered_map<rel_t, edge_t> to_edge_type;
112110
for (const auto &kv : colptr_dict) {

csrc/cpu/neighbor_sample_cpu.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
1515
sample(const torch::Tensor &colptr, const torch::Tensor &row,
1616
const torch::Tensor &input_node, const vector<int64_t> num_neighbors) {
1717

18-
srand(time(NULL) + 1000 * getpid()); // Initialize random seed.
19-
2018
// Initialize some data structures for the sampling process:
2119
vector<int64_t> samples;
2220
unordered_map<int64_t, int64_t> to_local_node;
@@ -59,7 +57,7 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
5957
}
6058
} else if (replace) {
6159
for (int64_t j = 0; j < num_samples; j++) {
62-
const int64_t offset = col_start + rand() % col_count;
60+
const int64_t offset = col_start + uniform_randint(col_count);
6361
const int64_t &v = row_data[offset];
6462
const auto res = to_local_node.insert({v, samples.size()});
6563
if (res.second)
@@ -73,7 +71,7 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
7371
} else {
7472
unordered_set<int64_t> rnd_indices;
7573
for (int64_t j = col_count - num_samples; j < col_count; j++) {
76-
int64_t rnd = rand() % j;
74+
int64_t rnd = uniform_randint(j);
7775
if (!rnd_indices.insert(rnd).second) {
7876
rnd = j;
7977
rnd_indices.insert(j);
@@ -127,8 +125,6 @@ hetero_sample(const vector<node_t> &node_types,
127125
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
128126
const int64_t num_hops) {
129127

130-
srand(time(NULL) + 1000 * getpid()); // Initialize random seed.
131-
132128
// Create a mapping to convert single string relations to edge type triplets:
133129
unordered_map<rel_t, edge_t> to_edge_type;
134130
for (const auto &k : edge_types)
@@ -180,8 +176,10 @@ hetero_sample(const vector<node_t> &node_types,
180176
auto &src_samples = samples_dict.at(src_node_type);
181177
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
182178

183-
const auto *colptr_data = ((torch::Tensor)colptr_dict.at(rel_type)).data_ptr<int64_t>();
184-
const auto *row_data = ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
179+
const auto *colptr_data =
180+
((torch::Tensor)colptr_dict.at(rel_type)).data_ptr<int64_t>();
181+
const auto *row_data =
182+
((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
185183

186184
auto &rows = rows_dict.at(rel_type);
187185
auto &cols = cols_dict.at(rel_type);
@@ -212,7 +210,7 @@ hetero_sample(const vector<node_t> &node_types,
212210
}
213211
} else if (replace) {
214212
for (int64_t j = 0; j < num_samples; j++) {
215-
const int64_t offset = col_start + rand() % col_count;
213+
const int64_t offset = col_start + uniform_randint(col_count);
216214
const int64_t &v = row_data[offset];
217215
const auto res = to_local_src_node.insert({v, src_samples.size()});
218216
if (res.second)
@@ -226,7 +224,7 @@ hetero_sample(const vector<node_t> &node_types,
226224
} else {
227225
unordered_set<int64_t> rnd_indices;
228226
for (int64_t j = col_count - num_samples; j < col_count; j++) {
229-
int64_t rnd = rand() % j;
227+
int64_t rnd = uniform_randint(j);
230228
if (!rnd_indices.insert(rnd).second) {
231229
rnd = j;
232230
rnd_indices.insert(j);
@@ -262,7 +260,8 @@ hetero_sample(const vector<node_t> &node_types,
262260
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
263261

264262
const auto *colptr_data = ((torch::Tensor)kv.value()).data_ptr<int64_t>();
265-
const auto *row_data = ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
263+
const auto *row_data =
264+
((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
266265

267266
auto &rows = rows_dict.at(rel_type);
268267
auto &cols = cols_dict.at(rel_type);

csrc/cpu/sample_cpu.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
1515
CHECK_CPU(idx);
1616
CHECK_INPUT(idx.dim() == 1);
1717

18-
srand(time(NULL) + 1000 * getpid()); // Initialize random seed.
19-
2018
auto rowptr_data = rowptr.data_ptr<int64_t>();
2119
auto col_data = col.data_ptr<int64_t>();
2220
auto idx_data = idx.data_ptr<int64_t>();
@@ -69,7 +67,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
6967

7068
if (row_count > 0) {
7169
for (int64_t j = 0; j < num_neighbors; j++) {
72-
e = row_start + rand() % row_count;
70+
e = row_start + uniform_randint(row_count);
7371
c = col_data[e];
7472

7573
if (n_id_map.count(c) == 0) {
@@ -96,7 +94,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
9694
} else { // See: https://www.nowherenearithaca.com/2013/05/
9795
// robert-floyds-tiny-and-beautiful.html
9896
for (int64_t j = row_count - num_neighbors; j < row_count; j++) {
99-
if (!perm.insert(rand() % j).second)
97+
if (!perm.insert(uniform_randint(j)).second)
10098
perm.insert(j);
10199
}
102100
}

csrc/cpu/utils.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ from_vector(const std::unordered_map<key_t, std::vector<scalar_t>> &vec_dict,
3535
return out_dict;
3636
}
3737

38+
inline int64_t uniform_randint(int64_t low, int64_t high) {
39+
CHECK_LT(low, high);
40+
auto options = torch::TensorOptions().dtype(torch::kInt64);
41+
auto ret = torch::randint(low, high, {1}, options);
42+
auto ptr = ret.data_ptr<int64_t>();
43+
return *ptr;
44+
}
45+
46+
inline int64_t uniform_randint(int64_t high) {
47+
return uniform_randint(0, high);
48+
}
49+
3850
inline torch::Tensor
3951
choice(int64_t population, int64_t num_samples, bool replace = false,
4052
torch::optional<torch::Tensor> weight = torch::nullopt) {
@@ -52,7 +64,7 @@ choice(int64_t population, int64_t num_samples, bool replace = false,
5264
const auto out = torch::empty(num_samples, at::kLong);
5365
auto *out_data = out.data_ptr<int64_t>();
5466
for (int64_t i = 0; i < num_samples; i++) {
55-
out_data[i] = rand() % population;
67+
out_data[i] = uniform_randint(population);
5668
}
5769
return out;
5870

@@ -64,7 +76,7 @@ choice(int64_t population, int64_t num_samples, bool replace = false,
6476
auto *out_data = out.data_ptr<int64_t>();
6577
std::unordered_set<int64_t> samples;
6678
for (int64_t i = population - num_samples; i < population; i++) {
67-
int64_t sample = rand() % i;
79+
int64_t sample = uniform_randint(i);
6880
if (!samples.insert(sample).second) {
6981
sample = i;
7082
samples.insert(sample);
@@ -86,7 +98,7 @@ uniform_choice(const int64_t population, const int64_t num_samples,
8698

8799
if (replace) {
88100
for (int64_t i = 0; i < num_samples; i++) {
89-
const int64_t &v = idx_data[rand() % population];
101+
const int64_t &v = idx_data[uniform_randint(population)];
90102
if (to_local_node->insert({v, samples->size()}).second)
91103
samples->push_back(v);
92104
}
@@ -99,7 +111,7 @@ uniform_choice(const int64_t population, const int64_t num_samples,
99111
} else {
100112
std::unordered_set<int64_t> indices;
101113
for (int64_t i = population - num_samples; i < population; i++) {
102-
int64_t j = rand() % i;
114+
int64_t j = uniform_randint(i);
103115
if (!indices.insert(j).second) {
104116
j = i;
105117
indices.insert(j);

csrc/extensions.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,2 @@
11
#include "macros.h"
22
#include <torch/extension.h>
3-
4-
// for getpid()
5-
#ifdef _WIN32
6-
#include <process.h>
7-
#else
8-
#include <unistd.h>
9-
#endif

test/test_neighbor_sample.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,18 @@ def test_neighbor_sample():
2525
assert out[0].tolist() == [1, 0]
2626
assert out[1].tolist() == [1]
2727
assert out[2].tolist() == [0]
28+
29+
30+
def test_neighbor_sample_seed():
31+
colptr = torch.tensor([0, 3, 6, 9])
32+
row = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
33+
input_nodes = torch.tensor([0, 1])
34+
35+
torch.manual_seed(42)
36+
out1 = neighbor_sample(colptr, row, input_nodes, [1, 1], True, False)
37+
38+
torch.manual_seed(42)
39+
out2 = neighbor_sample(colptr, row, input_nodes, [1, 1], True, False)
40+
41+
for data1, data2 in zip(out1, out2):
42+
assert data1.tolist() == data2.tolist()

0 commit comments

Comments
 (0)