@@ -238,17 +238,24 @@ hetero_sample(const vector<node_t> &node_types,
238238 if (temporal) {
239239 if (!satisfy_time (node_time_dict, src_node_type, dst_time, v))
240240 continue ;
241- }
242- const auto res = to_local_src_node.insert ({v, src_samples.size ()});
243- if (res.second ) {
241+ // force disjoint of computation tree
242+ // note that the sampling always needs to have directed=True
243+ // for temporal case
244+ // to_local_src_node is not used for temporal / directed case
244245 src_samples.push_back (v);
245- if (temporal)
246- src_root_time.push_back (dst_time);
247- }
248- if (directed) {
246+ src_root_time.push_back (dst_time);
249247 cols.push_back (i);
250- rows.push_back (res. first -> second );
248+ rows.push_back (src_samples. size () - 1 );
251249 edges.push_back (offset);
250+ } else {
251+ const auto res = to_local_src_node.insert ({v, src_samples.size ()});
252+ if (res.second )
253+ src_samples.push_back (v);
254+ if (directed) {
255+ cols.push_back (i);
256+ rows.push_back (res.first ->second );
257+ edges.push_back (offset);
258+ }
252259 }
253260 }
254261 } else if (replace) {
@@ -261,17 +268,23 @@ hetero_sample(const vector<node_t> &node_types,
261268 // TODO Infinity loop if no neighbor satisfies time constraint:
262269 if (!satisfy_time (node_time_dict, src_node_type, dst_time, v))
263270 continue ;
264- }
265- const auto res = to_local_src_node. insert ({v, src_samples. size ()});
266- if (res. second ) {
271+ // force disjoint of computation tree
272+ // note that the sampling always needs to have directed=True
273+ // for temporal case
267274 src_samples.push_back (v);
268- if (temporal)
269- src_root_time.push_back (dst_time);
270- }
271- if (directed) {
275+ src_root_time.push_back (dst_time);
272276 cols.push_back (i);
273- rows.push_back (res. first -> second );
277+ rows.push_back (src_samples. size () - 1 );
274278 edges.push_back (offset);
279+ } else {
280+ const auto res = to_local_src_node.insert ({v, src_samples.size ()});
281+ if (res.second )
282+ src_samples.push_back (v);
283+ if (directed) {
284+ cols.push_back (i);
285+ rows.push_back (res.first ->second );
286+ edges.push_back (offset);
287+ }
275288 }
276289 num_neighbors += 1 ;
277290 }
@@ -289,17 +302,23 @@ hetero_sample(const vector<node_t> &node_types,
289302 if (temporal) {
290303 if (!satisfy_time (node_time_dict, src_node_type, dst_time, v))
291304 continue ;
292- }
293- const auto res = to_local_src_node. insert ({v, src_samples. size ()});
294- if (res. second ) {
305+ // force disjoint of computation tree
306+ // note that the sampling always needs to have directed=True
307+ // for temporal case
295308 src_samples.push_back (v);
296- if (temporal)
297- src_root_time.push_back (dst_time);
298- }
299- if (directed) {
309+ src_root_time.push_back (dst_time);
300310 cols.push_back (i);
301- rows.push_back (res. first -> second );
311+ rows.push_back (src_samples. size () - 1 );
302312 edges.push_back (offset);
313+ } else {
314+ const auto res = to_local_src_node.insert ({v, src_samples.size ()});
315+ if (res.second )
316+ src_samples.push_back (v);
317+ if (directed) {
318+ cols.push_back (i);
319+ rows.push_back (res.first ->second );
320+ edges.push_back (offset);
321+ }
303322 }
304323 }
305324 }
@@ -412,21 +431,19 @@ hetero_temporal_neighbor_sample_cpu(
412431 const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
413432 const c10::Dict<node_t , torch::Tensor> &node_time_dict,
414433 const int64_t num_hops, const bool replace, const bool directed) {
415-
416- if (replace && directed) {
434+ AT_ASSERTM (directed, " Temporal sampling requires 'directed' sampling" )
435+ if (replace) {
436+ // We assume that directed = True for temporal sampling
437+ // The current implementation uses disjoint computation trees
438+ // to tackle the case of the same node sampled having different
439+ // root time constraint.
440+ // In future, we could extend to directed = False case,
441+ // allowing additional edges within each computation tree.
417442 return hetero_sample<true , true , true >(
418443 node_types, edge_types, colptr_dict, row_dict, input_node_dict,
419444 num_neighbors_dict, node_time_dict, num_hops);
420- } else if (replace && !directed) {
421- return hetero_sample<true , false , true >(
422- node_types, edge_types, colptr_dict, row_dict, input_node_dict,
423- num_neighbors_dict, node_time_dict, num_hops);
424- } else if (!replace && directed) {
425- return hetero_sample<false , true , true >(
426- node_types, edge_types, colptr_dict, row_dict, input_node_dict,
427- num_neighbors_dict, node_time_dict, num_hops);
428445 } else {
429- return hetero_sample<false , false , true >(
446+ return hetero_sample<false , true , true >(
430447 node_types, edge_types, colptr_dict, row_dict, input_node_dict,
431448 num_neighbors_dict, node_time_dict, num_hops);
432449 }
0 commit comments