@@ -114,30 +114,41 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
114114 from_vector<int64_t >(cols), from_vector<int64_t >(edges));
115115}
116116
117- template <bool replace, bool directed>
117+ bool satisfy_time_constraint (const c10::Dict<node_t , torch::Tensor> &node_time_dict,
118+ const std::string &src_node_type,
119+ const int64_t &dst_time,
120+ const int64_t &sampled_node) {
121+ // whether src -> dst obeys the time constraint
122+ try {
123+ const auto *src_time = node_time_dict.at (src_node_type).data_ptr <int64_t >();
124+ return dst_time < src_time[sampled_node];
125+ }
126+ catch (int err) {
127+ // if the node type does not have timestamp, fall back to normal sampling
128+ return true ;
129+ }
130+ }
131+
132+
133+ template <bool replace, bool directed, bool temporal>
118134tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
119135 c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
120136hetero_sample (const vector<node_t > &node_types,
121- const vector<edge_t > &edge_types,
122- const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
123- const c10::Dict<rel_t , torch::Tensor> &row_dict,
124- const c10::Dict<node_t , torch::Tensor> &input_node_dict,
125- const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
126- const int64_t num_hops) {
137+ const vector<edge_t > &edge_types,
138+ const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
139+ const c10::Dict<rel_t , torch::Tensor> &row_dict,
140+ const c10::Dict<node_t , torch::Tensor> &input_node_dict,
141+ const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
142+ const int64_t num_hops,
143+ const c10::Dict<node_t , torch::Tensor> &node_time_dict) {
144+ // bool temporal = (!node_time_dict.empty());
127145
128146 // Create a mapping to convert single string relations to edge type triplets:
129147 unordered_map<rel_t , edge_t > to_edge_type;
130148 for (const auto &k : edge_types)
131149 to_edge_type[get<0 >(k) + " __" + get<1 >(k) + " __" + get<2 >(k)] = k;
132150
133151 // Initialize some data structures for the sampling process:
134- unordered_map<node_t , vector<int64_t >> samples_dict;
135- unordered_map<node_t , unordered_map<int64_t , int64_t >> to_local_node_dict;
136- for (const auto &node_type : node_types) {
137- samples_dict[node_type];
138- to_local_node_dict[node_type];
139- }
140-
141152 unordered_map<rel_t , vector<int64_t >> rows_dict, cols_dict, edges_dict;
142153 for (const auto &kv : colptr_dict) {
143154 const auto &rel_type = kv.key ();
@@ -146,18 +157,40 @@ hetero_sample(const vector<node_t> &node_types,
146157 edges_dict[rel_type];
147158 }
148159
160+ unordered_map<node_t , vector<int64_t >> samples_dict;
161+ unordered_map<node_t , unordered_map<int64_t , int64_t >> to_local_node_dict;
162+ // The timestamp of the center node whose neighborhood that the sampled node
163+ // belongs to. It maps node_type to empty vector in non-temporal sampling.
164+ unordered_map<node_t , vector<int64_t >> root_time_dict;
165+ for (const auto &node_type : node_types) {
166+ samples_dict[node_type];
167+ to_local_node_dict[node_type];
168+ root_time_dict[node_type];
169+ }
170+
149171 // Add the input nodes to the output nodes:
150172 for (const auto &kv : input_node_dict) {
151173 const auto &node_type = kv.key ();
152174 const torch::Tensor &input_node = kv.value ();
153175 const auto *input_node_data = input_node.data_ptr <int64_t >();
176+ // dummy value. will be reset to root time if is_temporal==true
177+ auto *node_time_data = input_node.data_ptr <int64_t >();
178+ // root_time[i] stores the timestamp of the computation tree root
179+ // of the node samples[i]
180+ if (temporal) {
181+ node_time_data = node_time_dict.at (node_type).data_ptr <int64_t >();
182+ }
154183
155184 auto &samples = samples_dict.at (node_type);
156185 auto &to_local_node = to_local_node_dict.at (node_type);
186+ auto &root_time = root_time_dict.at (node_type);
157187 for (int64_t i = 0 ; i < input_node.numel (); i++) {
158188 const auto &v = input_node_data[i];
159189 samples.push_back (v);
160190 to_local_node.insert ({v, i});
191+ if (temporal) {
192+ root_time.push_back (node_time_data[v]);
193+ }
161194 }
162195 }
163196
@@ -187,8 +220,17 @@ hetero_sample(const vector<node_t> &node_types,
187220
188221 const auto &begin = slice_dict.at (dst_node_type).first ;
189222 const auto &end = slice_dict.at (dst_node_type).second ;
223+ if (begin == end){
224+ continue ;
225+ }
226+ // for temporal sampling, sampled src node cannot have timestamp greater
227+ // than its corresponding dst_root_time
228+ const auto &dst_root_time = root_time_dict.at (dst_node_type);
229+ auto &src_root_time = root_time_dict.at (src_node_type);
230+
190231 for (int64_t i = begin; i < end; i++) {
191232 const auto &w = dst_samples[i];
233+ const auto &dst_time = dst_root_time[i];
192234 const auto &col_start = colptr_data[w];
193235 const auto &col_end = colptr_data[w + 1 ];
194236 const auto col_count = col_end - col_start;
@@ -197,31 +239,56 @@ hetero_sample(const vector<node_t> &node_types,
197239 continue ;
198240
199241 if ((num_samples < 0 ) || (!replace && (num_samples >= col_count))) {
242+ // select all neighbors
200243 for (int64_t offset = col_start; offset < col_end; offset++) {
201244 const int64_t &v = row_data[offset];
245+ bool time_constraint = true ;
246+ if (temporal) {
247+ time_constraint = satisfy_time_constraint (
248+ node_time_dict, src_node_type, dst_time, v);
249+ }
250+ if (!time_constraint)
251+ continue ;
202252 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
203- if (res.second )
253+ if (res.second ) {
204254 src_samples.push_back (v);
255+ if (temporal)
256+ src_root_time.push_back (dst_time);
257+ }
205258 if (directed) {
206259 cols.push_back (i);
207260 rows.push_back (res.first ->second );
208261 edges.push_back (offset);
209262 }
210263 }
211264 } else if (replace) {
212- for (int64_t j = 0 ; j < num_samples; j++) {
265+ // sample with replacement
266+ int64_t num_neighbors = 0 ;
267+ while (num_neighbors < num_samples) {
213268 const int64_t offset = col_start + uniform_randint (col_count);
214269 const int64_t &v = row_data[offset];
270+ bool time_constraint = true ;
271+ if (temporal) {
272+ time_constraint = satisfy_time_constraint (
273+ node_time_dict, src_node_type, dst_time, v);
274+ }
275+ if (!time_constraint)
276+ continue ;
215277 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
216- if (res.second )
278+ if (res.second ) {
217279 src_samples.push_back (v);
280+ if (temporal)
281+ src_root_time.push_back (dst_time);
282+ }
218283 if (directed) {
219284 cols.push_back (i);
220285 rows.push_back (res.first ->second );
221286 edges.push_back (offset);
222287 }
288+ num_neighbors += 1 ;
223289 }
224290 } else {
291+ // sample without replacement
225292 unordered_set<int64_t > rnd_indices;
226293 for (int64_t j = col_count - num_samples; j < col_count; j++) {
227294 int64_t rnd = uniform_randint (j);
@@ -231,9 +298,19 @@ hetero_sample(const vector<node_t> &node_types,
231298 }
232299 const int64_t offset = col_start + rnd;
233300 const int64_t &v = row_data[offset];
301+ bool time_constraint = true ;
302+ if (temporal) {
303+ time_constraint = satisfy_time_constraint (
304+ node_time_dict, src_node_type, dst_time, v);
305+ }
306+ if (!time_constraint)
307+ continue ;
234308 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
235- if (res.second )
309+ if (res.second ) {
236310 src_samples.push_back (v);
311+ if (temporal)
312+ src_root_time.push_back (dst_time);
313+ }
237314 if (directed) {
238315 cols.push_back (i);
239316 rows.push_back (res.first ->second );
@@ -290,6 +367,27 @@ hetero_sample(const vector<node_t> &node_types,
290367 from_vector<rel_t , int64_t >(edges_dict));
291368}
292369
370+ template <bool replace, bool directed>
371+ tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
372+ c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
373+ hetero_sample_random (const vector<node_t > &node_types,
374+ const vector<edge_t > &edge_types,
375+ const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
376+ const c10::Dict<rel_t , torch::Tensor> &row_dict,
377+ const c10::Dict<node_t , torch::Tensor> &input_node_dict,
378+ const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
379+ const int64_t num_hops) {
380+ c10::Dict<node_t , torch::Tensor> empty_dict;
381+ return hetero_sample<replace, directed, false >(node_types,
382+ edge_types,
383+ colptr_dict,
384+ row_dict,
385+ input_node_dict,
386+ num_neighbors_dict,
387+ num_hops,
388+ empty_dict);
389+ }
390+
293391} // namespace
294392
295393tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -320,20 +418,58 @@ hetero_neighbor_sample_cpu(
320418 const int64_t num_hops, const bool replace, const bool directed) {
321419
322420 if (replace && directed) {
323- return hetero_sample<true , true >(node_types, edge_types, colptr_dict,
324- row_dict, input_node_dict,
325- num_neighbors_dict, num_hops);
421+ return hetero_sample_random<true , true >(
422+ node_types, edge_types, colptr_dict,
423+ row_dict, input_node_dict,
424+ num_neighbors_dict, num_hops);
425+ } else if (replace && !directed) {
426+ return hetero_sample_random<true , false >(
427+ node_types, edge_types, colptr_dict,
428+ row_dict, input_node_dict,
429+ num_neighbors_dict, num_hops);
430+ } else if (!replace && directed) {
431+ return hetero_sample_random<false , true >(
432+ node_types, edge_types, colptr_dict,
433+ row_dict, input_node_dict,
434+ num_neighbors_dict, num_hops);
435+ } else {
436+ return hetero_sample_random<false , false >(
437+ node_types, edge_types, colptr_dict,
438+ row_dict, input_node_dict,
439+ num_neighbors_dict, num_hops);
440+ }
441+ }
442+
443+ tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
444+ c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
445+ hetero_neighbor_temporal_sample_cpu (
446+ const vector<node_t > &node_types, const vector<edge_t > &edge_types,
447+ const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
448+ const c10::Dict<rel_t , torch::Tensor> &row_dict,
449+ const c10::Dict<node_t , torch::Tensor> &input_node_dict,
450+ const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
451+ const c10::Dict<node_t , torch::Tensor> &node_time_dict,
452+ const int64_t num_hops, const bool replace, const bool directed) {
453+
454+ if (replace && directed) {
455+ return hetero_sample<true , true , true >(
456+ node_types, edge_types, colptr_dict,
457+ row_dict, input_node_dict,
458+ num_neighbors_dict, num_hops, node_time_dict);
326459 } else if (replace && !directed) {
327- return hetero_sample<true , false >(node_types, edge_types, colptr_dict,
328- row_dict, input_node_dict,
329- num_neighbors_dict, num_hops);
460+ return hetero_sample<true , false , true >(
461+ node_types, edge_types, colptr_dict,
462+ row_dict, input_node_dict,
463+ num_neighbors_dict, num_hops, node_time_dict);
330464 } else if (!replace && directed) {
331- return hetero_sample<false , true >(node_types, edge_types, colptr_dict,
332- row_dict, input_node_dict,
333- num_neighbors_dict, num_hops);
465+ return hetero_sample<false , true , true >(
466+ node_types, edge_types, colptr_dict,
467+ row_dict, input_node_dict,
468+ num_neighbors_dict, num_hops, node_time_dict);
334469 } else {
335- return hetero_sample<false , false >(node_types, edge_types, colptr_dict,
336- row_dict, input_node_dict,
337- num_neighbors_dict, num_hops);
470+ return hetero_sample<false , false , true >(
471+ node_types, edge_types, colptr_dict,
472+ row_dict, input_node_dict,
473+ num_neighbors_dict, num_hops, node_time_dict);
338474 }
339475}
0 commit comments