
Commit 5e2d0f1

scatter cpu

1 parent 64772d7 · commit 5e2d0f1

14 files changed: +570 -19 lines changed

README.md

Lines changed: 2 additions & 2 deletions

````diff
@@ -45,11 +45,11 @@ All included operations are broadcastable, work on varying data types, and are i
 
 ## Installation
 
-Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
+Ensure that at least PyTorch 1.3.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
 
 ```
 $ python -c "import torch; print(torch.__version__)"
->>> 1.1.0
+>>> 1.3.0
 
 $ echo $PATH
 >>> /usr/local/cuda/bin:...
````

csrc/cpu/scatter_cpu.cpp

Lines changed: 82 additions & 0 deletions (new file)

```cpp
#include "scatter_cpu.h"

#include "index_info.h"
#include "reducer.h"
#include "utils.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size, std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  if (optional_out.has_value())
    CHECK_CPU(optional_out.value());

  CHECK_INPUT(src.dim() == index.dim());
  for (auto i = 0; i < index.dim() - 1; i++)
    CHECK_INPUT(src.size(i) >= index.size(i));

  if (dim < 0)
    dim = src.dim() + dim;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (auto i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
  } else {
    auto sizes = src.sizes().vec();
    if (dim_size.has_value())
      sizes[dim] = dim_size.value();
    else
      sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
    out = torch::empty(sizes, src.options());
  }

  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = torch::full_like(out, src.size(dim), index.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
  }

  auto B = 1;
  for (auto i = 0; i < dim; i++)
    B *= src.size(i);
  auto E = src.size(dim);
  auto K = src.numel() / (B * E);
  auto N = out.size(dim);

  auto index_info = getTensorInfo<int64_t>(index);
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    int64_t i, idx;
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (!optional_out.has_value())
        out.fill_(Reducer<scalar_t, REDUCE>::init());

      for (auto b = 0; b < B; b++) {
        for (auto e = 0; e < E; e++) {
          for (auto k = 0; k < K; k++) {
            i = b * E * K + e * K + k;
            idx = index_info.data[IndexToOffset<int64_t>::get(i, index_info)];
            Reducer<scalar_t, REDUCE>::update(
                out_data + b * N * K + idx * K + k, src_data[i],
                arg_out_data + b * N * K + idx * K + k, e);
          }
        }
      }

      if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
        out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
    });
  });

  return std::make_tuple(out, arg_out);
}
```
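For reference, the kernel above treats `src` as a `B × E × K` block: `B` collapses the dimensions before `dim`, `E` is the extent of `dim` itself, `K` collapses the trailing dimensions, and the output uses extent `N = out.size(dim)` in place of `E`. The following is a minimal Python sketch (not part of the commit) that mirrors this loop for `reduce="sum"` on a small 2-D example; variable names match the C++ above:

```python
import torch

src = torch.tensor([[1., 2., 3., 4.],
                    [5., 6., 7., 8.]])
index = torch.tensor([[0, 1, 0, 1],
                      [1, 1, 0, 0]])
dim = 1

B = src.shape[:dim].numel()   # product of dims before `dim`
E = src.size(dim)             # extent of the scatter dimension
K = src.numel() // (B * E)    # product of dims after `dim`
N = int(index.max()) + 1      # inferred output extent along `dim`

out = torch.zeros(B, N, K)
src_flat = src.reshape(B, E, K)
idx_flat = index.reshape(B, E, K)
for b in range(B):
    for e in range(E):
        for k in range(K):
            out[b, idx_flat[b, e, k], k] += src_flat[b, e, k]

print(out.view(2, 2))
# tensor([[ 4.,  6.],
#         [15., 11.]])
# same result as: torch.zeros(2, 2).scatter_add_(dim, index, src)
```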

csrc/cpu/scatter_cpu.h

Lines changed: 8 additions & 0 deletions (new file)

```cpp
#pragma once

#include <torch/extension.h>

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size, std::string reduce);
```
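When `optional_out` is absent, the implementation allocates the output itself, sizing it along `dim` from `dim_size` if given and from `index.max() + 1` otherwise (see the `sizes[dim]` branch in `scatter_cpu.cpp` above). A small Python sketch of that shape rule, using a hypothetical helper name:

```python
import torch

def scatter_out_shape(src, index, dim, dim_size=None):
    # Mirrors the allocation branch of scatter_cpu: copy src's shape and
    # override the scatter dimension.
    sizes = list(src.shape)
    sizes[dim] = dim_size if dim_size is not None else int(index.max()) + 1
    return sizes

src = torch.randn(2, 4)
index = torch.tensor([[0, 1, 0, 1],
                      [1, 1, 0, 0]])
print(scatter_out_shape(src, index, 1))     # [2, 2] -- inferred from index
print(scatter_out_shape(src, index, 1, 8))  # [2, 8] -- explicit dim_size
```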

csrc/cpu/segment_coo_cpu.cpp

Lines changed: 2 additions & 2 deletions

```diff
@@ -16,7 +16,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
   CHECK_INPUT(src.dim() >= index.dim());
 
   auto sizes = index.sizes().vec();
-  for (int i = 0; i < index.dim(); i++)
+  for (auto i = 0; i < index.dim(); i++)
     sizes[i] = src.size(i);
   index = index.expand(sizes);
 
@@ -27,7 +27,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
   torch::Tensor out;
   if (optional_out.has_value()) {
     out = optional_out.value().contiguous();
-    for (int i = 0; i < out.dim(); i++)
+    for (auto i = 0; i < out.dim(); i++)
       if (i != dim)
         CHECK_INPUT(src.size(i) == out.size(i));
   } else {
```

csrc/cpu/segment_csr_cpu.cpp

Lines changed: 2 additions & 2 deletions

```diff
@@ -27,7 +27,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
   torch::Tensor out;
   if (optional_out.has_value()) {
     out = optional_out.value().contiguous();
-    for (int i = 0; i < out.dim(); i++)
+    for (auto i = 0; i < out.dim(); i++)
       if (i != dim)
         CHECK_INPUT(src.size(i) == out.size(i));
     CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1);
@@ -126,7 +126,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
 
     std::vector<scalar_t> vals(K);
     int64_t row_start, row_end;
-    for (int n = 0; n < N; n++) {
+    for (auto n = 0; n < N; n++) {
       auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
       row_start = indptr_info.data[offset];
       row_end = indptr_info.data[offset + stride];
```

csrc/cuda/reducer.cuh

Lines changed: 2 additions & 2 deletions

```diff
@@ -106,9 +106,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
       atomMul(address, val);
     else if (REDUCE == DIV)
       atomDiv(address, val);
-    else if (REDUCE == MIN && val < *address)
+    else if (REDUCE == MIN)
       atomMin(address, val);
-    else if (REDUCE == MAX && val > *address)
+    else if (REDUCE == MAX)
       atomMax(address, val);
   }
 };
```

csrc/cuda/scatter_cuda.cu

Lines changed: 8 additions & 0 deletions (new file)

```cpp
#include "scatter_cuda.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
             torch::optional<torch::Tensor> optional_out,
             torch::optional<int64_t> dim_size, std::string reduce) {
  return std::make_tuple(src, optional_out);
}
```

Note that `scatter_cuda` is only a stub here that returns its inputs unchanged; consistent with the commit message, this commit adds the CPU path only.

csrc/cuda/scatter_cuda.h

Lines changed: 8 additions & 0 deletions (new file)

```cpp
#pragma once

#include <torch/extension.h>

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
             torch::optional<torch::Tensor> optional_out,
             torch::optional<int64_t> dim_size, std::string reduce);
```

csrc/scatter.cpp

Lines changed: 213 additions & 0 deletions (new file)

```cpp
#include <torch/script.h>

#include "cpu/scatter_cpu.h"

#ifdef WITH_CUDA
#include "cuda/scatter_cuda.h"
#endif

torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
  if (dim < 0)
    dim = other.dim() + dim;
  if (src.dim() == 1)
    for (auto i = 0; i < dim; i++)
      src = src.unsqueeze(0);
  for (auto i = src.dim(); i < other.dim(); i++)
    src = src.unsqueeze(-1);
  src = src.expand(other.sizes().vec());
  return src;
}

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_fw(torch::Tensor src, torch::Tensor index, int64_t dim,
           torch::optional<torch::Tensor> optional_out,
           torch::optional<int64_t> dim_size, std::string reduce) {
  if (src.device().is_cuda()) {
#ifdef WITH_CUDA
    return scatter_cuda(src, index, dim, optional_out, dim_size, reduce);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return scatter_cpu(src, index, dim, optional_out, dim_size, reduce);
  }
}

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

class ScatterSum : public torch::autograd::Function<ScatterSum> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, int64_t dim,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    ctx->saved_data["dim"] = dim;
    ctx->saved_data["src_shape"] = src.sizes();
    index = broadcast(index, src, dim);
    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
    auto out = std::get<0>(result);
    ctx->save_for_backward({index});
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto dim = ctx->saved_data["dim"].toInt();
    auto src_shape = ctx->saved_data["src_shape"].toIntVector();
    auto grad_in = torch::gather(grad_out, dim, index, false);
    return {grad_in, Variable(), Variable(), Variable(), Variable()};
  }
};

class ScatterMean : public torch::autograd::Function<ScatterMean> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, int64_t dim,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    ctx->saved_data["dim"] = dim;
    ctx->saved_data["src_shape"] = src.sizes();

    auto old_index = index;

    index = broadcast(index, src, dim);
    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
    auto out = std::get<0>(result);

    auto ones = torch::ones(old_index.sizes(), src.options());
    result = scatter_fw(ones, old_index,
                        old_index.dim() <= dim ? old_index.dim() - 1 : dim,
                        torch::nullopt, out.size(dim), "sum");
    auto count = std::get<0>(result);
    count.clamp_(1);
    count = broadcast(count, out, dim);
    out.div_(count);

    ctx->save_for_backward({index, count});
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto count = saved[1];
    auto dim = ctx->saved_data["dim"].toInt();
    auto src_shape = ctx->saved_data["src_shape"].toIntVector();
    count = torch::gather(count, dim, index, false);
    auto grad_in = torch::gather(grad_out, dim, index, false);
    grad_in.div_(count);
    return {grad_in, Variable(), Variable(), Variable(), Variable()};
  }
};

class ScatterMin : public torch::autograd::Function<ScatterMin> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, int64_t dim,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    ctx->saved_data["dim"] = dim;
    ctx->saved_data["src_shape"] = src.sizes();

    index = broadcast(index, src, dim);
    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "min");
    auto out = std::get<0>(result);
    auto arg_out = std::get<1>(result).value();
    ctx->save_for_backward({index, arg_out});
    ctx->mark_non_differentiable({arg_out});
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out, arg_out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto arg_out = saved[1];
    auto dim = ctx->saved_data["dim"].toInt();
    auto src_shape = ctx->saved_data["src_shape"].toIntVector();
    src_shape[dim] += 1;
    auto grad_in = torch::zeros(src_shape, grad_out.options());
    grad_in.scatter_(dim, arg_out, grad_out);
    grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
    return {grad_in, Variable(), Variable(), Variable(), Variable()};
  }
};

class ScatterMax : public torch::autograd::Function<ScatterMax> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, int64_t dim,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    ctx->saved_data["dim"] = dim;
    ctx->saved_data["src_shape"] = src.sizes();

    index = broadcast(index, src, dim);
    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "max");
    auto out = std::get<0>(result);
    auto arg_out = std::get<1>(result).value();
    ctx->save_for_backward({index, arg_out});
    ctx->mark_non_differentiable({arg_out});
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out, arg_out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto arg_out = saved[1];
    auto dim = ctx->saved_data["dim"].toInt();
    auto src_shape = ctx->saved_data["src_shape"].toIntVector();
    src_shape[dim] += 1;
    auto grad_in = torch::zeros(src_shape, grad_out.options());
    grad_in.scatter_(dim, arg_out, grad_out);
    grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
    return {grad_in, Variable(), Variable(), Variable(), Variable()};
  }
};

torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
                          torch::optional<torch::Tensor> optional_out,
                          torch::optional<int64_t> dim_size) {
  return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
}

torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
                           torch::optional<torch::Tensor> optional_out,
                           torch::optional<int64_t> dim_size) {
  return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0];
}

std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size) {
  auto result = ScatterMin::apply(src, index, dim, optional_out, dim_size);
  return std::make_tuple(result[0], result[1]);
}

std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size) {
  auto result = ScatterMax::apply(src, index, dim, optional_out, dim_size);
  return std::make_tuple(result[0], result[1]);
}

static auto registry = torch::RegisterOperators()
                           .op("torch_scatter::scatter_sum", &scatter_sum)
                           .op("torch_scatter::scatter_mean", &scatter_mean)
                           .op("torch_scatter::scatter_min", &scatter_min)
                           .op("torch_scatter::scatter_max", &scatter_max);
```
