@@ -8,8 +8,6 @@
 #endif
 
 torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
-  if (dim < 0)
-    dim = other.dim() + dim;
   if (src.dim() == 1)
     for (auto i = 0; i < dim; i++)
       src = src.unsqueeze(0);
@@ -43,6 +41,7 @@ class ScatterSum : public torch::autograd::Function<ScatterSum> {
                                Variable index, int64_t dim,
                                torch::optional<Variable> optional_out,
                                torch::optional<int64_t> dim_size) {
+    dim = dim < 0 ? src.dim() + dim : dim;
     ctx->saved_data["dim"] = dim;
     ctx->saved_data["src_shape"] = src.sizes();
     index = broadcast(index, src, dim);
@@ -116,6 +115,7 @@ class ScatterMin : public torch::autograd::Function<ScatterMin> {
                                Variable index, int64_t dim,
                                torch::optional<Variable> optional_out,
                                torch::optional<int64_t> dim_size) {
+    dim = dim < 0 ? src.dim() + dim : dim;
     ctx->saved_data["dim"] = dim;
     ctx->saved_data["src_shape"] = src.sizes();
 
@@ -151,6 +151,7 @@ class ScatterMax : public torch::autograd::Function<ScatterMax> {
                                Variable index, int64_t dim,
                                torch::optional<Variable> optional_out,
                                torch::optional<int64_t> dim_size) {
+    dim = dim < 0 ? src.dim() + dim : dim;
     ctx->saved_data["dim"] = dim;
     ctx->saved_data["src_shape"] = src.sizes();
 
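
Read together, the hunks relocate negative-dim handling: broadcast no longer rewrites dim, and each forward normalizes it against src.dim() before it is stored in ctx->saved_data["dim"] and passed on to broadcast. Below is a minimal standalone sketch of the normalization itself (plain C++, no libtorch; the helper name and the rank values are hypothetical, chosen only for illustration):

#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the ternary added in each forward():
// a negative dim counts back from the last axis of src.
int64_t normalize_dim(int64_t dim, int64_t src_rank) {
  return dim < 0 ? src_rank + dim : dim;
}

int main() {
  // For a rank-3 src tensor, dim = -1 resolves to axis 2.
  assert(normalize_dim(-1, 3) == 2);
  // Non-negative dims pass through unchanged.
  assert(normalize_dim(0, 3) == 0);
  return 0;
}

The effect visible in the diff is that ctx->saved_data["dim"] now always holds a non-negative index, and callers such as these forward functions pass an already-normalized dim to broadcast.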