@@ -10,6 +10,8 @@ using namespace std;
1010
1111namespace {
1212
13+ typedef phmap::flat_hash_map<pair<int64_t , int64_t >, int64_t > temporarl_edge_dict; // Key: (node id, index of the input/root node it was sampled for); value: position of that pair in the per-type temporal samples vector. NOTE(review): the name is misspelled ("temporarl"); renaming requires updating every use.
14+
1315template <bool replace, bool directed>
1416tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
1517sample (const torch::Tensor &colptr, const torch::Tensor &row,
@@ -146,11 +148,15 @@ hetero_sample(const vector<node_t> &node_types,
146148
147149 // Initialize some data structures for the sampling process:
148150 phmap::flat_hash_map<node_t , vector<int64_t >> samples_dict;
151+ phmap::flat_hash_map<node_t , vector<pair<int64_t , int64_t >>> temp_samples_dict;
149152 phmap::flat_hash_map<node_t , phmap::flat_hash_map<int64_t , int64_t >> to_local_node_dict;
153+ phmap::flat_hash_map<node_t , temporarl_edge_dict> temp_to_local_node_dict;
150154 phmap::flat_hash_map<node_t , vector<int64_t >> root_time_dict;
151155 for (const auto &node_type : node_types) {
152156 samples_dict[node_type];
157+ temp_samples_dict[node_type];
153158 to_local_node_dict[node_type];
159+ temp_to_local_node_dict[node_type];
154160 root_time_dict[node_type];
155161 }
156162
@@ -175,20 +181,33 @@ hetero_sample(const vector<node_t> &node_types,
175181 }
176182
177183 auto &samples = samples_dict.at (node_type);
184+ auto &temp_samples = temp_samples_dict.at (node_type);
178185 auto &to_local_node = to_local_node_dict.at (node_type);
186+ auto &temp_to_local_node = temp_to_local_node_dict.at (node_type);
179187 auto &root_time = root_time_dict.at (node_type);
180188 for (int64_t i = 0 ; i < input_node.numel (); i++) {
181189 const auto &v = input_node_data[i];
182- samples.push_back (v);
183- to_local_node.insert ({v, i});
190+ if (temporal) {
191+ temp_samples.push_back ({v, i});
192+ temp_to_local_node.insert ({{v, i}, i});
193+ } else {
194+ samples.push_back (v);
195+ to_local_node.insert ({v, i});
196+ }
184197 if (temporal)
185198 root_time.push_back (node_time_data[v]);
186199 }
187200 }
188201
189202 phmap::flat_hash_map<node_t , pair<int64_t , int64_t >> slice_dict;
190- for (const auto &kv : samples_dict)
191- slice_dict[kv.first ] = {0 , kv.second .size ()};
203+ if (temporal) {
204+ for (const auto &kv : temp_samples_dict) {
205+ slice_dict[kv.first ] = {0 , kv.second .size ()};
206+ }
207+ } else {
208+ for (const auto &kv : samples_dict)
209+ slice_dict[kv.first ] = {0 , kv.second .size ()};
210+ }
192211
193212 vector<rel_t > all_rel_types;
194213 for (const auto &kv : num_neighbors_dict) {
@@ -203,8 +222,11 @@ hetero_sample(const vector<node_t> &node_types,
203222 const auto &dst_node_type = get<2 >(edge_type);
204223 const auto num_samples = num_neighbors_dict.at (rel_type)[ell];
205224 const auto &dst_samples = samples_dict.at (dst_node_type);
225+ const auto &temp_dst_samples = temp_samples_dict.at (dst_node_type);
206226 auto &src_samples = samples_dict.at (src_node_type);
227+ auto &temp_src_samples = temp_samples_dict.at (src_node_type);
207228 auto &to_local_src_node = to_local_node_dict.at (src_node_type);
229+ auto &temp_to_local_src_node = temp_to_local_node_dict.at (src_node_type);
208230
209231 const torch::Tensor &colptr = colptr_dict.at (rel_type);
210232 const auto *colptr_data = colptr.data_ptr <int64_t >();
@@ -223,7 +245,8 @@ hetero_sample(const vector<node_t> &node_types,
223245 const auto &begin = slice_dict.at (dst_node_type).first ;
224246 const auto &end = slice_dict.at (dst_node_type).second ;
225247 for (int64_t i = begin; i < end; i++) {
226- const auto &w = dst_samples[i];
248+ const auto &w = temporal ? temp_dst_samples[i].first : dst_samples[i];
249+ const int64_t root_w = temporal ? temp_dst_samples[i].second : -1 ;
227250 int64_t dst_time = 0 ;
228251 if (temporal)
229252 dst_time = dst_root_time[i];
@@ -241,15 +264,18 @@ hetero_sample(const vector<node_t> &node_types,
241264 if (temporal) {
242265 if (!satisfy_time (node_time_dict, src_node_type, dst_time, v))
243266 continue ;
244- // force disjoint of computation tree
267+ // force disjoint of computation tree based on source batch idx.
245268 // note that the sampling always needs to have directed=True
246269 // for temporal case
247270 // to_local_src_node is not used for temporal / directed case
248- const int64_t sample_idx = src_samples.size ();
249- src_samples.push_back (v);
250- src_root_time.push_back (dst_time);
271+ const auto res = temp_to_local_src_node.insert ({{v, root_w}, (int64_t )temp_src_samples.size ()});
272+ if (res.second ) {
273+ temp_src_samples.push_back ({v, root_w});
274+ src_root_time.push_back (dst_time);
275+ }
276+
251277 cols.push_back (i);
252- rows.push_back (sample_idx );
278+ rows.push_back (res. first -> second );
253279 edges.push_back (offset);
254280 } else {
255281 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
@@ -272,14 +298,17 @@ hetero_sample(const vector<node_t> &node_types,
272298 // TODO Infinite loop if no neighbor satisfies time constraint:
273299 if (!satisfy_time (node_time_dict, src_node_type, dst_time, v))
274300 continue ;
275- // force disjoint of computation tree
301+ // force disjoint of computation tree based on source batch idx.
276302 // note that the sampling always needs to have directed=True
277303 // for temporal case
278- const int64_t sample_idx = src_samples.size ();
279- src_samples.push_back (v);
280- src_root_time.push_back (dst_time);
304+ const auto res = temp_to_local_src_node.insert ({{v, root_w}, (int64_t )temp_src_samples.size ()});
305+ if (res.second ) {
306+ temp_src_samples.push_back ({v, root_w});
307+ src_root_time.push_back (dst_time);
308+ }
309+
281310 cols.push_back (i);
282- rows.push_back (sample_idx );
311+ rows.push_back (res. first -> second );
283312 edges.push_back (offset);
284313 } else {
285314 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
@@ -307,14 +336,17 @@ hetero_sample(const vector<node_t> &node_types,
307336 if (temporal) {
308337 if (!satisfy_time (node_time_dict, src_node_type, dst_time, v))
309338 continue ;
310- // force disjoint of computation tree
339+ // force disjoint of computation tree based on source batch idx.
311340 // note that the sampling always needs to have directed=True
312341 // for temporal case
313- const int64_t sample_idx = src_samples.size ();
314- src_samples.push_back (v);
315- src_root_time.push_back (dst_time);
342+ const auto res = temp_to_local_src_node.insert ({{v, root_w}, (int64_t )temp_src_samples.size ()});
343+ if (res.second ) {
344+ temp_src_samples.push_back ({v, root_w});
345+ src_root_time.push_back (dst_time);
346+ }
347+
316348 cols.push_back (i);
317- rows.push_back (sample_idx );
349+ rows.push_back (res. first -> second );
318350 edges.push_back (offset);
319351 } else {
320352 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
@@ -331,11 +363,18 @@ hetero_sample(const vector<node_t> &node_types,
331363 }
332364 }
333365
334- for (const auto &kv : samples_dict) {
335- slice_dict[kv.first ] = {slice_dict.at (kv.first ).second , kv.second .size ()};
366+ if (temporal) {
367+ for (const auto &kv : temp_samples_dict) {
368+ slice_dict[kv.first ] = {0 , kv.second .size ()};
369+ }
370+ } else {
371+ for (const auto &kv : samples_dict)
372+ slice_dict[kv.first ] = {0 , kv.second .size ()};
336373 }
337374 }
338375
376+ // Temporal sampling does not support the undirected case (requires directed=true).
377+ assert (!(temporal && !directed));
339378 if (!directed) { // Construct the subgraph among the sampled nodes:
340379 phmap::flat_hash_map<int64_t , int64_t >::iterator iter;
341380 for (const auto &kv : colptr_dict) {
@@ -371,6 +410,18 @@ hetero_sample(const vector<node_t> &node_types,
371410 }
372411 }
373412
413+ // Construct samples dictionary from temporal sample dictionary.
414+ if (temporal) {
415+ for (const auto &kv : temp_samples_dict) {
416+ const auto &node_type = kv.first ;
417+ const auto &samples = kv.second ;
418+ samples_dict[node_type].reserve (samples.size ());
419+ for (const auto &v : samples) {
420+ samples_dict[node_type].push_back (v.first );
421+ }
422+ }
423+ }
424+
374425 return make_tuple (from_vector<node_t , int64_t >(samples_dict),
375426 from_vector<rel_t , int64_t >(rows_dict),
376427 from_vector<rel_t , int64_t >(cols_dict),
0 commit comments