Skip to content

Commit ff3be8e

Browse files
committed
pytorch 1.4 support: toIntVector -> to IntList
1 parent 7ef77d9 commit ff3be8e

File tree

4 files changed

+29
-14
lines changed

4 files changed

+29
-14
lines changed

csrc/scatter.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <torch/script.h>
22

33
#include "cpu/scatter_cpu.h"
4+
#include "utils.h"
45

56
#ifdef WITH_CUDA
67
#include "cuda/scatter_cuda.h"
@@ -58,7 +59,7 @@ class ScatterSum : public torch::autograd::Function<ScatterSum> {
5859
auto saved = ctx->get_saved_variables();
5960
auto index = saved[0];
6061
auto dim = ctx->saved_data["dim"].toInt();
61-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
62+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
6263
auto grad_in = torch::gather(grad_out, dim, index, false);
6364
return {grad_in, Variable(), Variable(), Variable(), Variable()};
6465
}
@@ -100,7 +101,7 @@ class ScatterMean : public torch::autograd::Function<ScatterMean> {
100101
auto index = saved[0];
101102
auto count = saved[1];
102103
auto dim = ctx->saved_data["dim"].toInt();
103-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
104+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
104105
count = torch::gather(count, dim, index, false);
105106
auto grad_in = torch::gather(grad_out, dim, index, false);
106107
grad_in.div_(count);
@@ -134,7 +135,7 @@ class ScatterMin : public torch::autograd::Function<ScatterMin> {
134135
auto index = saved[0];
135136
auto arg_out = saved[1];
136137
auto dim = ctx->saved_data["dim"].toInt();
137-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
138+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
138139
src_shape[dim] += 1;
139140
auto grad_in = torch::zeros(src_shape, grad_out.options());
140141
grad_in.scatter_(dim, arg_out, grad_out);
@@ -169,7 +170,7 @@ class ScatterMax : public torch::autograd::Function<ScatterMax> {
169170
auto index = saved[0];
170171
auto arg_out = saved[1];
171172
auto dim = ctx->saved_data["dim"].toInt();
172-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
173+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
173174
src_shape[dim] += 1;
174175
auto grad_in = torch::zeros(src_shape, grad_out.options());
175176
grad_in.scatter_(dim, arg_out, grad_out);

csrc/segment_coo.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <torch/script.h>
22

33
#include "cpu/segment_coo_cpu.h"
4+
#include "utils.h"
45

56
#ifdef WITH_CUDA
67
#include "cuda/segment_coo_cuda.h"
@@ -57,7 +58,7 @@ class SegmentSumCOO : public torch::autograd::Function<SegmentSumCOO> {
5758
auto grad_out = grad_outs[0];
5859
auto saved = ctx->get_saved_variables();
5960
auto index = saved[0];
60-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
61+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
6162
auto grad_in = torch::empty(src_shape, grad_out.options());
6263
gather_coo_fw(grad_out, index, grad_in);
6364
return {grad_in, Variable(), Variable(), Variable()};
@@ -85,7 +86,7 @@ class SegmentMeanCOO : public torch::autograd::Function<SegmentMeanCOO> {
8586
auto saved = ctx->get_saved_variables();
8687
auto index = saved[0];
8788
auto count = saved[1];
88-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
89+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
8990
auto grad_in = torch::empty(src_shape, grad_out.options());
9091
gather_coo_fw(grad_out, index, grad_in);
9192
count = gather_coo_fw(count, index, torch::nullopt);
@@ -118,7 +119,7 @@ class SegmentMinCOO : public torch::autograd::Function<SegmentMinCOO> {
118119
auto saved = ctx->get_saved_variables();
119120
auto index = saved[0];
120121
auto arg_out = saved[1];
121-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
122+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
122123
src_shape[index.dim() - 1] += 1;
123124
auto grad_in = torch::zeros(src_shape, grad_out.options());
124125
grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
@@ -150,7 +151,7 @@ class SegmentMaxCOO : public torch::autograd::Function<SegmentMaxCOO> {
150151
auto saved = ctx->get_saved_variables();
151152
auto index = saved[0];
152153
auto arg_out = saved[1];
153-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
154+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
154155
src_shape[index.dim() - 1] += 1;
155156
auto grad_in = torch::zeros(src_shape, grad_out.options());
156157
grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
@@ -177,7 +178,7 @@ class GatherCOO : public torch::autograd::Function<GatherCOO> {
177178
auto grad_out = grad_outs[0];
178179
auto saved = ctx->get_saved_variables();
179180
auto index = saved[0];
180-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
181+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
181182

182183
auto grad_in = torch::zeros(src_shape, grad_out.options());
183184
segment_coo_fw(grad_out, index, grad_in, torch::nullopt, "sum");

csrc/segment_csr.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <torch/script.h>
22

33
#include "cpu/segment_csr_cpu.h"
4+
#include "utils.h"
45

56
#ifdef WITH_CUDA
67
#include "cuda/segment_csr_cuda.h"
@@ -55,7 +56,7 @@ class SegmentSumCSR : public torch::autograd::Function<SegmentSumCSR> {
5556
auto grad_out = grad_outs[0];
5657
auto saved = ctx->get_saved_variables();
5758
auto indptr = saved[0];
58-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
59+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
5960
auto grad_in = torch::empty(src_shape, grad_out.options());
6061
gather_csr_fw(grad_out, indptr, grad_in);
6162
return {grad_in, Variable(), Variable()};
@@ -79,7 +80,7 @@ class SegmentMeanCSR : public torch::autograd::Function<SegmentMeanCSR> {
7980
auto grad_out = grad_outs[0];
8081
auto saved = ctx->get_saved_variables();
8182
auto indptr = saved[0];
82-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
83+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
8384
auto grad_in = torch::empty(src_shape, grad_out.options());
8485
gather_csr_fw(grad_out, indptr, grad_in);
8586
auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1);
@@ -114,7 +115,7 @@ class SegmentMinCSR : public torch::autograd::Function<SegmentMinCSR> {
114115
auto saved = ctx->get_saved_variables();
115116
auto indptr = saved[0];
116117
auto arg_out = saved[1];
117-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
118+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
118119
src_shape[indptr.dim() - 1] += 1;
119120
auto grad_in = torch::zeros(src_shape, grad_out.options());
120121
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
@@ -145,7 +146,7 @@ class SegmentMaxCSR : public torch::autograd::Function<SegmentMaxCSR> {
145146
auto saved = ctx->get_saved_variables();
146147
auto indptr = saved[0];
147148
auto arg_out = saved[1];
148-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
149+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
149150
src_shape[indptr.dim() - 1] += 1;
150151
auto grad_in = torch::zeros(src_shape, grad_out.options());
151152
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
@@ -172,7 +173,7 @@ class GatherCSR : public torch::autograd::Function<GatherCSR> {
172173
auto grad_out = grad_outs[0];
173174
auto saved = ctx->get_saved_variables();
174175
auto indptr = saved[0];
175-
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
176+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
176177

177178
auto grad_in = torch::empty(src_shape, grad_out.options());
178179
segment_csr_fw(grad_out, indptr, grad_in, "sum");

csrc/utils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#pragma once
2+
3+
#include <torch/script.h>
4+
#include <vector>
5+
6+
inline std::vector<int64_t> list2vec(const c10::List<int64_t> list) {
7+
std::vector<int64_t> result;
8+
result.reserve(list.size());
9+
for (size_t i = 0; i < list.size(); i++)
10+
result.push_back(list[i]);
11+
return result;
12+
}

0 commit comments

Comments
 (0)