Skip to content

Commit d670cdd

Browse files
RexYingZecheng Zhangrusty1s
authored
Temporal sample disable undirected (#247)
* disable undirected for temporal sampling * disjoint sampling for temporal * fix repeated node index * compile fix * Update csrc/cpu/neighbor_sample_cpu.cpp Co-authored-by: Zecheng Zhang <[email protected]> * Update csrc/cpu/neighbor_sample_cpu.cpp Co-authored-by: Zecheng Zhang <[email protected]> * Update csrc/cpu/neighbor_sample_cpu.cpp Co-authored-by: Zecheng Zhang <[email protected]> * comments on directed to be true * add directed in API * comments * minor function signature fix * Update csrc/cpu/neighbor_sample_cpu.cpp * Update csrc/neighbor_sample.cpp Co-authored-by: Zecheng Zhang <[email protected]> Co-authored-by: Matthias Fey <[email protected]>
1 parent 7dbc51c commit d670cdd

File tree

1 file changed

+52
-35
lines changed

1 file changed

+52
-35
lines changed

csrc/cpu/neighbor_sample_cpu.cpp

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -238,17 +238,24 @@ hetero_sample(const vector<node_t> &node_types,
238238
if (temporal) {
239239
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
240240
continue;
241-
}
242-
const auto res = to_local_src_node.insert({v, src_samples.size()});
243-
if (res.second) {
241+
// force disjoint of computation tree
242+
// note that the sampling always needs to have directed=True
243+
// for temporal case
244+
// to_local_src_node is not used for temporal / directed case
244245
src_samples.push_back(v);
245-
if (temporal)
246-
src_root_time.push_back(dst_time);
247-
}
248-
if (directed) {
246+
src_root_time.push_back(dst_time);
249247
cols.push_back(i);
250-
rows.push_back(res.first->second);
248+
rows.push_back(src_samples.size() - 1);
251249
edges.push_back(offset);
250+
} else {
251+
const auto res = to_local_src_node.insert({v, src_samples.size()});
252+
if (res.second)
253+
src_samples.push_back(v);
254+
if (directed) {
255+
cols.push_back(i);
256+
rows.push_back(res.first->second);
257+
edges.push_back(offset);
258+
}
252259
}
253260
}
254261
} else if (replace) {
@@ -261,17 +268,23 @@ hetero_sample(const vector<node_t> &node_types,
261268
// TODO Infinity loop if no neighbor satisfies time constraint:
262269
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
263270
continue;
264-
}
265-
const auto res = to_local_src_node.insert({v, src_samples.size()});
266-
if (res.second) {
271+
// force disjoint of computation tree
272+
// note that the sampling always needs to have directed=True
273+
// for temporal case
267274
src_samples.push_back(v);
268-
if (temporal)
269-
src_root_time.push_back(dst_time);
270-
}
271-
if (directed) {
275+
src_root_time.push_back(dst_time);
272276
cols.push_back(i);
273-
rows.push_back(res.first->second);
277+
rows.push_back(src_samples.size() - 1);
274278
edges.push_back(offset);
279+
} else {
280+
const auto res = to_local_src_node.insert({v, src_samples.size()});
281+
if (res.second)
282+
src_samples.push_back(v);
283+
if (directed) {
284+
cols.push_back(i);
285+
rows.push_back(res.first->second);
286+
edges.push_back(offset);
287+
}
275288
}
276289
num_neighbors += 1;
277290
}
@@ -289,17 +302,23 @@ hetero_sample(const vector<node_t> &node_types,
289302
if (temporal) {
290303
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
291304
continue;
292-
}
293-
const auto res = to_local_src_node.insert({v, src_samples.size()});
294-
if (res.second) {
305+
// force disjoint of computation tree
306+
// note that the sampling always needs to have directed=True
307+
// for temporal case
295308
src_samples.push_back(v);
296-
if (temporal)
297-
src_root_time.push_back(dst_time);
298-
}
299-
if (directed) {
309+
src_root_time.push_back(dst_time);
300310
cols.push_back(i);
301-
rows.push_back(res.first->second);
311+
rows.push_back(src_samples.size() - 1);
302312
edges.push_back(offset);
313+
} else {
314+
const auto res = to_local_src_node.insert({v, src_samples.size()});
315+
if (res.second)
316+
src_samples.push_back(v);
317+
if (directed) {
318+
cols.push_back(i);
319+
rows.push_back(res.first->second);
320+
edges.push_back(offset);
321+
}
303322
}
304323
}
305324
}
@@ -412,21 +431,19 @@ hetero_temporal_neighbor_sample_cpu(
412431
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
413432
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
414433
const int64_t num_hops, const bool replace, const bool directed) {
415-
416-
if (replace && directed) {
434+
AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling")
435+
if (replace) {
436+
// We assume that directed = True for temporal sampling
437+
// The current implementation uses disjoint computation trees
438+
// to tackle the case of the same node sampled having different
439+
// root time constraint.
440+
// In future, we could extend to directed = False case,
441+
// allowing additional edges within each computation tree.
417442
return hetero_sample<true, true, true>(
418443
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
419444
num_neighbors_dict, node_time_dict, num_hops);
420-
} else if (replace && !directed) {
421-
return hetero_sample<true, false, true>(
422-
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
423-
num_neighbors_dict, node_time_dict, num_hops);
424-
} else if (!replace && directed) {
425-
return hetero_sample<false, true, true>(
426-
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
427-
num_neighbors_dict, node_time_dict, num_hops);
428445
} else {
429-
return hetero_sample<false, false, true>(
446+
return hetero_sample<false, true, true>(
430447
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
431448
num_neighbors_dict, node_time_dict, num_hops);
432449
}

0 commit comments

Comments
 (0)