@@ -114,16 +114,13 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
114114 from_vector<int64_t >(cols), from_vector<int64_t >(edges));
115115}
116116
117- bool satisfy_time_constraint (
118- const c10::Dict<node_t , torch::Tensor> &node_time_dict,
119- const node_t &src_node_type, const int64_t &dst_time,
120- const int64_t &src_node) {
121- // whether src -> dst obeys the time constraint
122- try {
117+ inline bool satisfy_time (const c10::Dict<node_t , torch::Tensor> &node_time_dict,
118+ const node_t &src_node_type, const int64_t &dst_time,
119+ const int64_t &src_node) {
120+ try { // Check whether src -> dst obeys the time constraint:
123121 auto src_time = node_time_dict.at (src_node_type).data_ptr <int64_t >();
124122 return dst_time < src_time[src_node];
125- } catch (int err) {
126- // if the node type does not have timestamp, fall back to normal sampling
123+ } catch (int err) { // If no time is given, fall back to normal sampling:
127124 return true ;
128125 }
129126}
@@ -137,8 +134,9 @@ hetero_sample(const vector<node_t> &node_types,
137134 const c10::Dict<rel_t , torch::Tensor> &row_dict,
138135 const c10::Dict<node_t , torch::Tensor> &input_node_dict,
139136 const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
140- const int64_t num_hops,
141- const c10::Dict<node_t , torch::Tensor> &node_time_dict) {
137+ const c10::Dict<node_t , torch::Tensor> &node_time_dict,
138+ const int64_t num_hops) {
139+
142140 // Create a mapping to convert single string relations to edge type triplets:
143141 unordered_map<rel_t , edge_t > to_edge_type;
144142 for (const auto &k : edge_types)
@@ -155,8 +153,6 @@ hetero_sample(const vector<node_t> &node_types,
155153
156154 unordered_map<node_t , vector<int64_t >> samples_dict;
157155 unordered_map<node_t , unordered_map<int64_t , int64_t >> to_local_node_dict;
158- // The timestamp of the center node whose neighborhood that the sampled node
159- // belongs to. It maps node_type to empty vector in non-temporal sampling.
160156 unordered_map<node_t , vector<int64_t >> root_time_dict;
161157 for (const auto &node_type : node_types) {
162158 samples_dict[node_type];
@@ -169,10 +165,7 @@ hetero_sample(const vector<node_t> &node_types,
169165 const auto &node_type = kv.key ();
170166 const torch::Tensor &input_node = kv.value ();
171167 const auto *input_node_data = input_node.data_ptr <int64_t >();
172- // dummy value. will be reset to root time if is_temporal==true
173168 int64_t *node_time_data;
174- // root_time[i] stores the timestamp of the computation tree root
175- // of the node samples[i]
176169 if (temporal) {
177170 torch::Tensor node_time = node_time_dict.at (node_type);
178171 node_time_data = node_time.data_ptr <int64_t >();
@@ -185,9 +178,8 @@ hetero_sample(const vector<node_t> &node_types,
185178 const auto &v = input_node_data[i];
186179 samples.push_back (v);
187180 to_local_node.insert ({v, i});
188- if (temporal) {
181+ if (temporal)
189182 root_time.push_back (node_time_data[v]);
190- }
191183 }
192184 }
193185
@@ -217,11 +209,12 @@ hetero_sample(const vector<node_t> &node_types,
217209
218210 const auto &begin = slice_dict.at (dst_node_type).first ;
219211 const auto &end = slice_dict.at (dst_node_type).second ;
220- if (begin == end) {
212+
213+ if (begin == end)
221214 continue ;
222- }
223- // for temporal sampling, sampled src node cannot have timestamp greater
224- // than its corresponding dst_root_time
215+
216+ // For temporal sampling, sampled nodes cannot have a timestamp greater
217+ // than the timestamp of the root nodes.
225218 const auto &dst_root_time = root_time_dict.at (dst_node_type);
226219 auto &src_root_time = root_time_dict.at (src_node_type);
227220
@@ -236,16 +229,13 @@ hetero_sample(const vector<node_t> &node_types,
236229 continue ;
237230
238231 if ((num_samples < 0 ) || (!replace && (num_samples >= col_count))) {
239- // select all neighbors
232+ // Select all neighbors:
240233 for (int64_t offset = col_start; offset < col_end; offset++) {
241234 const int64_t &v = row_data[offset];
242- bool time_constraint = true ;
243235 if (temporal) {
244- time_constraint = satisfy_time_constraint (
245- node_time_dict, src_node_type, dst_time, v) ;
236+ if (! satisfy_time (node_time_dict, src_node_type, dst_time, v))
237+ continue ;
246238 }
247- if (!time_constraint)
248- continue ;
249239 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
250240 if (res.second ) {
251241 src_samples.push_back (v);
@@ -259,18 +249,16 @@ hetero_sample(const vector<node_t> &node_types,
259249 }
260250 }
261251 } else if (replace) {
262- // sample with replacement
252+ // Sample with replacement:
263253 int64_t num_neighbors = 0 ;
264254 while (num_neighbors < num_samples) {
265255 const int64_t offset = col_start + uniform_randint (col_count);
266256 const int64_t &v = row_data[offset];
267- bool time_constraint = true ;
268257 if (temporal) {
269- time_constraint = satisfy_time_constraint (
270- node_time_dict, src_node_type, dst_time, v);
258+ // TODO Infinity loop if no neighbor satisfies time constraint:
259+ if (!satisfy_time (node_time_dict, src_node_type, dst_time, v))
260+ continue ;
271261 }
272- if (!time_constraint)
273- continue ;
274262 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
275263 if (res.second ) {
276264 src_samples.push_back (v);
@@ -285,7 +273,7 @@ hetero_sample(const vector<node_t> &node_types,
285273 num_neighbors += 1 ;
286274 }
287275 } else {
288- // sample without replacement
276+ // Sample without replacement:
289277 unordered_set<int64_t > rnd_indices;
290278 for (int64_t j = col_count - num_samples; j < col_count; j++) {
291279 int64_t rnd = uniform_randint (j);
@@ -295,13 +283,10 @@ hetero_sample(const vector<node_t> &node_types,
295283 }
296284 const int64_t offset = col_start + rnd;
297285 const int64_t &v = row_data[offset];
298- bool time_constraint = true ;
299286 if (temporal) {
300- time_constraint = satisfy_time_constraint (
301- node_time_dict, src_node_type, dst_time, v) ;
287+ if (! satisfy_time (node_time_dict, src_node_type, dst_time, v))
288+ continue ;
302289 }
303- if (!time_constraint)
304- continue ;
305290 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
306291 if (res.second ) {
307292 src_samples.push_back (v);
@@ -364,22 +349,6 @@ hetero_sample(const vector<node_t> &node_types,
364349 from_vector<rel_t , int64_t >(edges_dict));
365350}
366351
367- template <bool replace, bool directed>
368- tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
369- c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
370- hetero_sample_random (
371- const vector<node_t > &node_types, const vector<edge_t > &edge_types,
372- const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
373- const c10::Dict<rel_t , torch::Tensor> &row_dict,
374- const c10::Dict<node_t , torch::Tensor> &input_node_dict,
375- const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
376- const int64_t num_hops) {
377- c10::Dict<node_t , torch::Tensor> empty_dict;
378- return hetero_sample<replace, directed, false >(
379- node_types, edge_types, colptr_dict, row_dict, input_node_dict,
380- num_neighbors_dict, num_hops, empty_dict);
381- }
382-
383352} // namespace
384353
385354tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -409,28 +378,30 @@ hetero_neighbor_sample_cpu(
409378 const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
410379 const int64_t num_hops, const bool replace, const bool directed) {
411380
381+ c10::Dict<node_t , torch::Tensor> node_time_dict; // Empty dictionary.
382+
412383 if (replace && directed) {
413- return hetero_sample_random <true , true >(node_types, edge_types, colptr_dict,
414- row_dict, input_node_dict,
415- num_neighbors_dict, num_hops);
384+ return hetero_sample <true , true , false >(
385+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
386+ num_neighbors_dict, node_time_dict , num_hops);
416387 } else if (replace && !directed) {
417- return hetero_sample_random <true , false >(
388+ return hetero_sample <true , false , false >(
418389 node_types, edge_types, colptr_dict, row_dict, input_node_dict,
419- num_neighbors_dict, num_hops);
390+ num_neighbors_dict, node_time_dict, num_hops);
420391 } else if (!replace && directed) {
421- return hetero_sample_random <false , true >(
392+ return hetero_sample <false , true , false >(
422393 node_types, edge_types, colptr_dict, row_dict, input_node_dict,
423- num_neighbors_dict, num_hops);
394+ num_neighbors_dict, node_time_dict, num_hops);
424395 } else {
425- return hetero_sample_random< false , false >(
396+ return hetero_sample< false , false , false >(
426397 node_types, edge_types, colptr_dict, row_dict, input_node_dict,
427- num_neighbors_dict, num_hops);
398+ num_neighbors_dict, node_time_dict, num_hops);
428399 }
429400}
430401
431402tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
432403 c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
433- hetero_neighbor_temporal_sample_cpu (
404+ hetero_temporal_neighbor_sample_cpu (
434405 const vector<node_t > &node_types, const vector<edge_t > &edge_types,
435406 const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
436407 const c10::Dict<rel_t , torch::Tensor> &row_dict,
@@ -442,18 +413,18 @@ hetero_neighbor_temporal_sample_cpu(
442413 if (replace && directed) {
443414 return hetero_sample<true , true , true >(
444415 node_types, edge_types, colptr_dict, row_dict, input_node_dict,
445- num_neighbors_dict, num_hops, node_time_dict );
416+ num_neighbors_dict, node_time_dict, num_hops );
446417 } else if (replace && !directed) {
447418 return hetero_sample<true , false , true >(
448419 node_types, edge_types, colptr_dict, row_dict, input_node_dict,
449- num_neighbors_dict, num_hops, node_time_dict );
420+ num_neighbors_dict, node_time_dict, num_hops );
450421 } else if (!replace && directed) {
451422 return hetero_sample<false , true , true >(
452423 node_types, edge_types, colptr_dict, row_dict, input_node_dict,
453- num_neighbors_dict, num_hops, node_time_dict );
424+ num_neighbors_dict, node_time_dict, num_hops );
454425 } else {
455426 return hetero_sample<false , false , true >(
456427 node_types, edge_types, colptr_dict, row_dict, input_node_dict,
457- num_neighbors_dict, num_hops, node_time_dict );
428+ num_neighbors_dict, node_time_dict, num_hops );
458429 }
459430}
0 commit comments