Skip to content

Commit 79535f3

Browse files
RexYingrusty1s
andauthored
Temporal sampling (#202)
* temporal sample * hetero_neighbor_sample temporal * api * remove redundant function * refactor * remove catch output * debug compile * testing env * debug * debug * debug * debug * debug * revert * node time data should not be const * Update csrc/cpu/neighbor_sample_cpu.cpp Co-authored-by: Matthias Fey <[email protected]> * Update csrc/cpu/neighbor_sample_cpu.cpp Co-authored-by: Matthias Fey <[email protected]> * Update csrc/cpu/neighbor_sample_cpu.cpp Co-authored-by: Matthias Fey <[email protected]> * Update csrc/cpu/neighbor_sample_cpu.cpp Co-authored-by: Matthias Fey <[email protected]> * temporal template * compilation fixes Co-authored-by: Matthias Fey <[email protected]>
1 parent eafcfe0 commit 79535f3

File tree

3 files changed

+196
-31
lines changed

3 files changed

+196
-31
lines changed

csrc/cpu/neighbor_sample_cpu.cpp

Lines changed: 166 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -114,30 +114,41 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
114114
from_vector<int64_t>(cols), from_vector<int64_t>(edges));
115115
}
116116

117-
template <bool replace, bool directed>
117+
bool satisfy_time_constraint(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
118+
const std::string &src_node_type,
119+
const int64_t &dst_time,
120+
const int64_t &sampled_node) {
121+
// whether src -> dst obeys the time constraint
122+
try {
123+
const auto *src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>();
124+
return dst_time < src_time[sampled_node];
125+
}
126+
catch (int err) {
127+
// if the node type does not have timestamp, fall back to normal sampling
128+
return true;
129+
}
130+
}
131+
132+
133+
template <bool replace, bool directed, bool temporal>
118134
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
119135
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
120136
hetero_sample(const vector<node_t> &node_types,
121-
const vector<edge_t> &edge_types,
122-
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
123-
const c10::Dict<rel_t, torch::Tensor> &row_dict,
124-
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
125-
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
126-
const int64_t num_hops) {
137+
const vector<edge_t> &edge_types,
138+
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
139+
const c10::Dict<rel_t, torch::Tensor> &row_dict,
140+
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
141+
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
142+
const int64_t num_hops,
143+
const c10::Dict<node_t, torch::Tensor> &node_time_dict) {
144+
//bool temporal = (!node_time_dict.empty());
127145

128146
// Create a mapping to convert single string relations to edge type triplets:
129147
unordered_map<rel_t, edge_t> to_edge_type;
130148
for (const auto &k : edge_types)
131149
to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k;
132150

133151
// Initialize some data structures for the sampling process:
134-
unordered_map<node_t, vector<int64_t>> samples_dict;
135-
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
136-
for (const auto &node_type : node_types) {
137-
samples_dict[node_type];
138-
to_local_node_dict[node_type];
139-
}
140-
141152
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
142153
for (const auto &kv : colptr_dict) {
143154
const auto &rel_type = kv.key();
@@ -146,18 +157,40 @@ hetero_sample(const vector<node_t> &node_types,
146157
edges_dict[rel_type];
147158
}
148159

160+
unordered_map<node_t, vector<int64_t>> samples_dict;
161+
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
162+
// The timestamp of the center node whose neighborhood that the sampled node
163+
// belongs to. It maps node_type to empty vector in non-temporal sampling.
164+
unordered_map<node_t, vector<int64_t>> root_time_dict;
165+
for (const auto &node_type : node_types) {
166+
samples_dict[node_type];
167+
to_local_node_dict[node_type];
168+
root_time_dict[node_type];
169+
}
170+
149171
// Add the input nodes to the output nodes:
150172
for (const auto &kv : input_node_dict) {
151173
const auto &node_type = kv.key();
152174
const torch::Tensor &input_node = kv.value();
153175
const auto *input_node_data = input_node.data_ptr<int64_t>();
176+
// dummy value. will be reset to root time if is_temporal==true
177+
auto *node_time_data = input_node.data_ptr<int64_t>();
178+
// root_time[i] stores the timestamp of the computation tree root
179+
// of the node samples[i]
180+
if (temporal) {
181+
node_time_data = node_time_dict.at(node_type).data_ptr<int64_t>();
182+
}
154183

155184
auto &samples = samples_dict.at(node_type);
156185
auto &to_local_node = to_local_node_dict.at(node_type);
186+
auto &root_time = root_time_dict.at(node_type);
157187
for (int64_t i = 0; i < input_node.numel(); i++) {
158188
const auto &v = input_node_data[i];
159189
samples.push_back(v);
160190
to_local_node.insert({v, i});
191+
if (temporal) {
192+
root_time.push_back(node_time_data[v]);
193+
}
161194
}
162195
}
163196

@@ -187,8 +220,17 @@ hetero_sample(const vector<node_t> &node_types,
187220

188221
const auto &begin = slice_dict.at(dst_node_type).first;
189222
const auto &end = slice_dict.at(dst_node_type).second;
223+
if (begin == end){
224+
continue;
225+
}
226+
// for temporal sampling, sampled src node cannot have timestamp greater
227+
// than its corresponding dst_root_time
228+
const auto &dst_root_time = root_time_dict.at(dst_node_type);
229+
auto &src_root_time = root_time_dict.at(src_node_type);
230+
190231
for (int64_t i = begin; i < end; i++) {
191232
const auto &w = dst_samples[i];
233+
const auto &dst_time = dst_root_time[i];
192234
const auto &col_start = colptr_data[w];
193235
const auto &col_end = colptr_data[w + 1];
194236
const auto col_count = col_end - col_start;
@@ -197,31 +239,56 @@ hetero_sample(const vector<node_t> &node_types,
197239
continue;
198240

199241
if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
242+
// select all neighbors
200243
for (int64_t offset = col_start; offset < col_end; offset++) {
201244
const int64_t &v = row_data[offset];
245+
bool time_constraint = true;
246+
if (temporal) {
247+
time_constraint = satisfy_time_constraint(
248+
node_time_dict, src_node_type, dst_time, v);
249+
}
250+
if (!time_constraint)
251+
continue;
202252
const auto res = to_local_src_node.insert({v, src_samples.size()});
203-
if (res.second)
253+
if (res.second) {
204254
src_samples.push_back(v);
255+
if (temporal)
256+
src_root_time.push_back(dst_time);
257+
}
205258
if (directed) {
206259
cols.push_back(i);
207260
rows.push_back(res.first->second);
208261
edges.push_back(offset);
209262
}
210263
}
211264
} else if (replace) {
212-
for (int64_t j = 0; j < num_samples; j++) {
265+
// sample with replacement
266+
int64_t num_neighbors = 0;
267+
while (num_neighbors < num_samples) {
213268
const int64_t offset = col_start + uniform_randint(col_count);
214269
const int64_t &v = row_data[offset];
270+
bool time_constraint = true;
271+
if (temporal) {
272+
time_constraint = satisfy_time_constraint(
273+
node_time_dict, src_node_type, dst_time, v);
274+
}
275+
if (!time_constraint)
276+
continue;
215277
const auto res = to_local_src_node.insert({v, src_samples.size()});
216-
if (res.second)
278+
if (res.second) {
217279
src_samples.push_back(v);
280+
if (temporal)
281+
src_root_time.push_back(dst_time);
282+
}
218283
if (directed) {
219284
cols.push_back(i);
220285
rows.push_back(res.first->second);
221286
edges.push_back(offset);
222287
}
288+
num_neighbors += 1;
223289
}
224290
} else {
291+
// sample without replacement
225292
unordered_set<int64_t> rnd_indices;
226293
for (int64_t j = col_count - num_samples; j < col_count; j++) {
227294
int64_t rnd = uniform_randint(j);
@@ -231,9 +298,19 @@ hetero_sample(const vector<node_t> &node_types,
231298
}
232299
const int64_t offset = col_start + rnd;
233300
const int64_t &v = row_data[offset];
301+
bool time_constraint = true;
302+
if (temporal) {
303+
time_constraint = satisfy_time_constraint(
304+
node_time_dict, src_node_type, dst_time, v);
305+
}
306+
if (!time_constraint)
307+
continue;
234308
const auto res = to_local_src_node.insert({v, src_samples.size()});
235-
if (res.second)
309+
if (res.second) {
236310
src_samples.push_back(v);
311+
if (temporal)
312+
src_root_time.push_back(dst_time);
313+
}
237314
if (directed) {
238315
cols.push_back(i);
239316
rows.push_back(res.first->second);
@@ -290,6 +367,27 @@ hetero_sample(const vector<node_t> &node_types,
290367
from_vector<rel_t, int64_t>(edges_dict));
291368
}
292369

370+
template <bool replace, bool directed>
371+
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
372+
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
373+
hetero_sample_random(const vector<node_t> &node_types,
374+
const vector<edge_t> &edge_types,
375+
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
376+
const c10::Dict<rel_t, torch::Tensor> &row_dict,
377+
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
378+
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
379+
const int64_t num_hops) {
380+
c10::Dict<node_t, torch::Tensor> empty_dict;
381+
return hetero_sample<replace, directed, false>(node_types,
382+
edge_types,
383+
colptr_dict,
384+
row_dict,
385+
input_node_dict,
386+
num_neighbors_dict,
387+
num_hops,
388+
empty_dict);
389+
}
390+
293391
} // namespace
294392

295393
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -320,20 +418,58 @@ hetero_neighbor_sample_cpu(
320418
const int64_t num_hops, const bool replace, const bool directed) {
321419

322420
if (replace && directed) {
323-
return hetero_sample<true, true>(node_types, edge_types, colptr_dict,
324-
row_dict, input_node_dict,
325-
num_neighbors_dict, num_hops);
421+
return hetero_sample_random<true, true>(
422+
node_types, edge_types, colptr_dict,
423+
row_dict, input_node_dict,
424+
num_neighbors_dict, num_hops);
425+
} else if (replace && !directed) {
426+
return hetero_sample_random<true, false>(
427+
node_types, edge_types, colptr_dict,
428+
row_dict, input_node_dict,
429+
num_neighbors_dict, num_hops);
430+
} else if (!replace && directed) {
431+
return hetero_sample_random<false, true>(
432+
node_types, edge_types, colptr_dict,
433+
row_dict, input_node_dict,
434+
num_neighbors_dict, num_hops);
435+
} else {
436+
return hetero_sample_random<false, false>(
437+
node_types, edge_types, colptr_dict,
438+
row_dict, input_node_dict,
439+
num_neighbors_dict, num_hops);
440+
}
441+
}
442+
443+
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
444+
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
445+
hetero_neighbor_temporal_sample_cpu(
446+
const vector<node_t> &node_types, const vector<edge_t> &edge_types,
447+
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
448+
const c10::Dict<rel_t, torch::Tensor> &row_dict,
449+
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
450+
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
451+
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
452+
const int64_t num_hops, const bool replace, const bool directed) {
453+
454+
if (replace && directed) {
455+
return hetero_sample<true, true, true>(
456+
node_types, edge_types, colptr_dict,
457+
row_dict, input_node_dict,
458+
num_neighbors_dict, num_hops, node_time_dict);
326459
} else if (replace && !directed) {
327-
return hetero_sample<true, false>(node_types, edge_types, colptr_dict,
328-
row_dict, input_node_dict,
329-
num_neighbors_dict, num_hops);
460+
return hetero_sample<true, false, true>(
461+
node_types, edge_types, colptr_dict,
462+
row_dict, input_node_dict,
463+
num_neighbors_dict, num_hops, node_time_dict);
330464
} else if (!replace && directed) {
331-
return hetero_sample<false, true>(node_types, edge_types, colptr_dict,
332-
row_dict, input_node_dict,
333-
num_neighbors_dict, num_hops);
465+
return hetero_sample<false, true, true>(
466+
node_types, edge_types, colptr_dict,
467+
row_dict, input_node_dict,
468+
num_neighbors_dict, num_hops, node_time_dict);
334469
} else {
335-
return hetero_sample<false, false>(node_types, edge_types, colptr_dict,
336-
row_dict, input_node_dict,
337-
num_neighbors_dict, num_hops);
470+
return hetero_sample<false, false, true>(
471+
node_types, edge_types, colptr_dict,
472+
row_dict, input_node_dict,
473+
num_neighbors_dict, num_hops, node_time_dict);
338474
}
339475
}

csrc/cpu/neighbor_sample_cpu.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,15 @@ hetero_neighbor_sample_cpu(
2222
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
2323
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
2424
const int64_t num_hops, const bool replace, const bool directed);
25+
26+
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
27+
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
28+
hetero_neighbor_temporal_sample_cpu(
29+
const std::vector<node_t> &node_types,
30+
const std::vector<edge_t> &edge_types,
31+
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
32+
const c10::Dict<rel_t, torch::Tensor> &row_dict,
33+
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
34+
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
35+
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
36+
const int64_t num_hops, const bool replace, const bool directed);

csrc/neighbor_sample.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,24 @@ hetero_neighbor_sample(
4040
num_neighbors_dict, num_hops, replace, directed);
4141
}
4242

43+
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
44+
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
45+
hetero_neighbor_temporal_sample(
46+
const std::vector<node_t> &node_types,
47+
const std::vector<edge_t> &edge_types,
48+
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
49+
const c10::Dict<rel_t, torch::Tensor> &row_dict,
50+
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
51+
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
52+
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
53+
const int64_t num_hops, const bool replace, const bool directed) {
54+
return hetero_neighbor_temporal_sample_cpu(
55+
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
56+
num_neighbors_dict, node_time_dict, num_hops, replace, directed);
57+
}
58+
4359
static auto registry =
4460
torch::RegisterOperators()
4561
.op("torch_sparse::neighbor_sample", &neighbor_sample)
46-
.op("torch_sparse::hetero_neighbor_sample", &hetero_neighbor_sample);
62+
.op("torch_sparse::hetero_neighbor_sample", &hetero_neighbor_sample)
63+
.op("torch_sparse::hetero_neighbor_temporal_sample", &hetero_neighbor_temporal_sample);

0 commit comments

Comments
 (0)