@@ -153,7 +153,7 @@ hetero_sample(const vector<node_t> &node_types,
153153 // Add the input nodes to the output nodes:
154154 for (const auto &kv : input_node_dict) {
155155 const auto &node_type = kv.key ();
156- const auto &input_node = kv.value ();
156+ const torch::Tensor &input_node = kv.value ();
157157 const auto *input_node_data = input_node.data_ptr <int64_t >();
158158
159159 auto &samples = samples_dict.at (node_type);
@@ -180,8 +180,8 @@ hetero_sample(const vector<node_t> &node_types,
180180 auto &src_samples = samples_dict.at (src_node_type);
181181 auto &to_local_src_node = to_local_node_dict.at (src_node_type);
182182
183- const auto *colptr_data = colptr_dict.at (rel_type).data_ptr <int64_t >();
184- const auto *row_data = row_dict.at (rel_type).data_ptr <int64_t >();
183+ const auto *colptr_data = ((torch::Tensor) colptr_dict.at (rel_type) ).data_ptr <int64_t >();
184+ const auto *row_data = ((torch::Tensor) row_dict.at (rel_type) ).data_ptr <int64_t >();
185185
186186 auto &rows = rows_dict.at (rel_type);
187187 auto &cols = cols_dict.at (rel_type);
@@ -261,8 +261,8 @@ hetero_sample(const vector<node_t> &node_types,
261261 const auto &dst_samples = samples_dict.at (dst_node_type);
262262 auto &to_local_src_node = to_local_node_dict.at (src_node_type);
263263
264- const auto *colptr_data = kv.value ().data_ptr <int64_t >();
265- const auto *row_data = row_dict.at (rel_type).data_ptr <int64_t >();
264+ const auto *colptr_data = ((torch::Tensor) kv.value () ).data_ptr <int64_t >();
265+ const auto *row_data = ((torch::Tensor) row_dict.at (rel_type) ).data_ptr <int64_t >();
266266
267267 auto &rows = rows_dict.at (rel_type);
268268 auto &cols = cols_dict.at (rel_type);
0 commit comments