Skip to content

Commit 3ec1eac

Browse files
Fix issue 193 (#203)
* Fixes for build with the new pytorch * Fixes for build with the new pytorch * Fixes for build with the new pytorch
1 parent d987d29 commit 3ec1eac

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

csrc/cpu/neighbor_sample_cpu.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ hetero_sample(const vector<node_t> &node_types,
153153
// Add the input nodes to the output nodes:
154154
for (const auto &kv : input_node_dict) {
155155
const auto &node_type = kv.key();
156-
const auto &input_node = kv.value();
156+
const torch::Tensor &input_node = kv.value();
157157
const auto *input_node_data = input_node.data_ptr<int64_t>();
158158

159159
auto &samples = samples_dict.at(node_type);
@@ -180,8 +180,8 @@ hetero_sample(const vector<node_t> &node_types,
180180
auto &src_samples = samples_dict.at(src_node_type);
181181
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
182182

183-
const auto *colptr_data = colptr_dict.at(rel_type).data_ptr<int64_t>();
184-
const auto *row_data = row_dict.at(rel_type).data_ptr<int64_t>();
183+
const auto *colptr_data = ((torch::Tensor)colptr_dict.at(rel_type)).data_ptr<int64_t>();
184+
const auto *row_data = ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
185185

186186
auto &rows = rows_dict.at(rel_type);
187187
auto &cols = cols_dict.at(rel_type);
@@ -261,8 +261,8 @@ hetero_sample(const vector<node_t> &node_types,
261261
const auto &dst_samples = samples_dict.at(dst_node_type);
262262
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
263263

264-
const auto *colptr_data = kv.value().data_ptr<int64_t>();
265-
const auto *row_data = row_dict.at(rel_type).data_ptr<int64_t>();
264+
const auto *colptr_data = ((torch::Tensor)kv.value()).data_ptr<int64_t>();
265+
const auto *row_data = ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
266266

267267
auto &rows = rows_dict.at(rel_type);
268268
auto &cols = cols_dict.at(rel_type);

0 commit comments

Comments
 (0)