Skip to content

Commit f056396

Browse files
committed
fixed negative dims
1 parent d3aabdf commit f056396

File tree

3 files changed

+3
-8
lines changed

3 files changed

+3
-8
lines changed

csrc/cpu/scatter_cpu.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
1717
for (auto i = 0; i < index.dim() - 1; i++)
1818
CHECK_INPUT(src.size(i) >= index.size(i));
1919

20-
if (dim < 0)
21-
dim = src.dim() + dim;
22-
2320
src = src.contiguous();
2421

2522
torch::Tensor out;

csrc/cuda/scatter_cuda.cu

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
6969
for (auto i = 0; i < index.dim() - 1; i++)
7070
CHECK_INPUT(src.size(i) >= index.size(i));
7171

72-
if (dim < 0)
73-
dim = src.dim() + dim;
74-
7572
src = src.contiguous();
7673

7774
torch::Tensor out;

csrc/scatter.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
#endif
99

1010
torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
11-
if (dim < 0)
12-
dim = other.dim() + dim;
1311
if (src.dim() == 1)
1412
for (auto i = 0; i < dim; i++)
1513
src = src.unsqueeze(0);
@@ -43,6 +41,7 @@ class ScatterSum : public torch::autograd::Function<ScatterSum> {
4341
Variable index, int64_t dim,
4442
torch::optional<Variable> optional_out,
4543
torch::optional<int64_t> dim_size) {
44+
dim = dim < 0 ? src.dim() + dim : dim;
4645
ctx->saved_data["dim"] = dim;
4746
ctx->saved_data["src_shape"] = src.sizes();
4847
index = broadcast(index, src, dim);
@@ -116,6 +115,7 @@ class ScatterMin : public torch::autograd::Function<ScatterMin> {
116115
Variable index, int64_t dim,
117116
torch::optional<Variable> optional_out,
118117
torch::optional<int64_t> dim_size) {
118+
dim = dim < 0 ? src.dim() + dim : dim;
119119
ctx->saved_data["dim"] = dim;
120120
ctx->saved_data["src_shape"] = src.sizes();
121121

@@ -151,6 +151,7 @@ class ScatterMax : public torch::autograd::Function<ScatterMax> {
151151
Variable index, int64_t dim,
152152
torch::optional<Variable> optional_out,
153153
torch::optional<int64_t> dim_size) {
154+
dim = dim < 0 ? src.dim() + dim : dim;
154155
ctx->saved_data["dim"] = dim;
155156
ctx->saved_data["src_shape"] = src.sizes();
156157

0 commit comments

Comments
 (0)