Commit 6281557

moved extensions to torch.ops

1 parent 0a221ab · commit 6281557
22 files changed: +255 −216 lines
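
Every binding file below makes the same move: the pybind11 extension-module entry point is replaced by TorchScript operator registration, so the functions become reachable from Python via torch.ops (e.g. torch.ops.torch_scatter_cpu.gather_csr) instead of as module attributes. A minimal sketch of the pattern with a toy op; the name my_ops::add_one is illustrative, not part of this commit:

#include <torch/script.h>

// Before: exported as a Python extension module.
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("add_one", &add_one); }

torch::Tensor add_one(torch::Tensor x) { return x + 1; }

// After: registered with the dispatcher when the shared library loads;
// callable from Python as torch.ops.my_ops.add_one(x).
static auto registry = torch::RegisterOperators("my_ops::add_one", &add_one);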

cpu/dim_apply.h

Lines changed: 2 additions & 2 deletions

@@ -19,7 +19,7 @@
   auto TENSOR3##_stride = TENSOR3.stride(DIM); \
 \
   auto dims = TENSOR1.dim(); \
-  auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
+  auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong)); \
   auto counter = zeros.DATA_PTR<int64_t>(); \
   bool has_finished = false; \
 \
@@ -78,7 +78,7 @@
   auto TENSOR4##_stride = TENSOR4.stride(DIM); \
 \
   auto dims = TENSOR1.dim(); \
-  auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
+  auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong)); \
   auto counter = zeros.DATA_PTR<int64_t>(); \
   bool has_finished = false; \
 \
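
The only change here is at::zeros/at::kLong becoming torch::zeros/torch::kLong, i.e. the macro now spells the factory call through the public C++ frontend namespace; the created tensor is identical. A small self-contained sketch of the call, assuming nothing beyond libtorch:

#include <torch/torch.h>

int main() {
  auto t = torch::randn({3, 4});
  // options() carries t's device and layout; dtype() overrides the scalar
  // type, so this yields an int64 zero tensor with t.dim() elements.
  auto zeros = torch::zeros(t.dim(), t.options().dtype(torch::kLong));
}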

cpu/gather.cpp

Lines changed: 13 additions & 14 deletions

@@ -1,14 +1,14 @@
-#include <torch/extension.h>
+#include <torch/script.h>
 
 #include "compat.h"
 #include "index_info.h"
 
 #include <vector>
 
-#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
+#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
 
-at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
-                      at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
+                         torch::optional<torch::Tensor> out_opt) {
   CHECK_CPU(src);
   CHECK_CPU(indptr);
   if (out_opt.has_value())
@@ -23,7 +23,7 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
              "Input mismatch");
 
-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < out.dim(); i++)
@@ -32,7 +32,7 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   } else {
     auto sizes = src.sizes().vec();
     sizes[gather_dim] = *indptr.flatten()[-1].DATA_PTR<int64_t>();
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }
 
   auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1));
@@ -68,8 +68,8 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   return out;
 }
 
-at::Tensor gather_coo(at::Tensor src, at::Tensor index,
-                      at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
+                         torch::optional<torch::Tensor> out_opt) {
   CHECK_CPU(src);
   CHECK_CPU(index);
   if (out_opt.has_value())
@@ -82,7 +82,7 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
   src = src.contiguous();
   auto gather_dim = index.dim() - 1;
 
-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < index.dim(); i++)
@@ -92,7 +92,7 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
   } else {
     auto sizes = src.sizes().vec();
     sizes[gather_dim] = index.size(gather_dim);
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }
 
   auto E_1 = index.numel() / out.size(gather_dim);
@@ -139,7 +139,6 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
   return out;
 }
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("gather_csr", &gather_csr, "Gather CSR (CPU)");
-  m.def("gather_coo", &gather_coo, "Gather COO (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cpu::gather_csr", &gather_csr)
+        .op("torch_scatter_cpu::gather_coo", &gather_coo);
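
The sizing logic above (output length along gather_dim equals the last indptr entry, and src must have indptr.size(gather_dim) - 1 entries along that dimension) implies the assumed CSR-gather semantics: each src slice i is repeated indptr[i+1] - indptr[i] times. A worked example under that assumption:

// indptr = {0, 3, 5}: slice 0 fills output positions [0, 3),
//                     slice 1 fills output positions [3, 5).
// src = {10, 20}  ->  gather_csr(src, indptr) = {10, 10, 10, 20, 20},
// whose length 5 is exactly indptr[-1], matching sizes[gather_dim] above.

After this commit the op is called from Python as torch.ops.torch_scatter_cpu.gather_csr rather than as an attribute of a pybind11 module.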

cpu/index_info.h

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ template <typename scalar_t> struct TensorInfo {
 };
 
 template <typename scalar_t>
-TensorInfo<scalar_t> getTensorInfo(const at::Tensor &tensor) {
+TensorInfo<scalar_t> getTensorInfo(const torch::Tensor &tensor) {
   int sizes[MAX_TENSORINFO_DIMS];
   int strides[MAX_TENSORINFO_DIMS];
 
cpu/scatter.cpp

Lines changed: 13 additions & 14 deletions

@@ -1,10 +1,10 @@
-#include <torch/extension.h>
+#include <torch/script.h>
 
 #include "dim_apply.h"
 
-#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
+#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
 
-void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                  int64_t dim) {
   CHECK_CPU(src);
   CHECK_CPU(index);
@@ -20,7 +20,7 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
   });
 }
 
-void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                  int64_t dim) {
   CHECK_CPU(src);
   CHECK_CPU(index);
@@ -36,8 +36,8 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
   });
 }
 
-void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
-                 at::Tensor arg, int64_t dim) {
+void scatter_max(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                 torch::Tensor arg, int64_t dim) {
   CHECK_CPU(src);
   CHECK_CPU(index);
   CHECK_CPU(out);
@@ -56,8 +56,8 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
   });
 }
 
-void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
-                 at::Tensor arg, int64_t dim) {
+void scatter_min(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                 torch::Tensor arg, int64_t dim) {
   CHECK_CPU(src);
   CHECK_CPU(index);
   CHECK_CPU(out);
@@ -77,9 +77,8 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
   });
 }
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("scatter_mul", &scatter_mul, "Scatter Mul (CPU)");
-  m.def("scatter_div", &scatter_div, "Scatter Div (CPU)");
-  m.def("scatter_max", &scatter_max, "Scatter Max (CPU)");
-  m.def("scatter_min", &scatter_min, "Scatter Min (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cpu::scatter_mul", &scatter_mul)
+        .op("torch_scatter_cpu::scatter_div", &scatter_div)
+        .op("torch_scatter_cpu::scatter_max", &scatter_max)
+        .op("torch_scatter_cpu::scatter_min", &scatter_min);
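
All four functions return void: out, and arg for the max/min variants, are preallocated by the caller and updated in place. A comment sketch of the assumed scatter_max contract along dim = 0:

// src   = {1., 5., 3.}
// index = {0,  0,  1 }
// scatter_max folds src into out at the positions named by index:
//   out[0] = max(out[0], 1., 5.)   out[1] = max(out[1], 3.)
// arg records which src position supplied each maximum, presumably so
// gradients can be routed back to the winning elements.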

cpu/segment.cpp

Lines changed: 16 additions & 17 deletions

@@ -1,11 +1,11 @@
-#include <torch/extension.h>
+#include <torch/script.h>
 
 #include "compat.h"
 #include "index_info.h"
 
 #include <vector>
 
-#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
+#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
 
 enum ReductionType { SUM, MEAN, MIN, MAX };
 
@@ -74,9 +74,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
   }
 };
 
-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
-            std::string reduce) {
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_csr(torch::Tensor src, torch::Tensor indptr,
+            torch::optional<torch::Tensor> out_opt, std::string reduce) {
   CHECK_CPU(src);
   CHECK_CPU(indptr);
   if (out_opt.has_value())
@@ -94,7 +94,7 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
   src = src.contiguous();
   auto reduce_dim = indptr.dim() - 1;
 
-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < out.dim(); i++)
@@ -105,13 +105,13 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
   } else {
     sizes = src.sizes().vec();
     sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }
 
-  at::optional<at::Tensor> arg_out = at::nullopt;
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
+    arg_out = torch::full_like(out, src.size(reduce_dim), indptr.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
 
@@ -156,8 +156,8 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
   return std::make_tuple(out, arg_out);
 }
 
-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_coo(torch::Tensor src, torch::Tensor index, torch::Tensor out,
             std::string reduce) {
   CHECK_CPU(src);
   CHECK_CPU(index);
@@ -180,10 +180,10 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
     if (i != reduce_dim)
       AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
 
-  at::optional<at::Tensor> arg_out = at::nullopt;
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
+    arg_out = torch::full_like(out, src.size(reduce_dim), index.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
 
@@ -251,7 +251,6 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
   return std::make_tuple(out, arg_out);
 }
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("segment_csr", &segment_csr, "Segment CSR (CPU)");
-  m.def("segment_coo", &segment_coo, "Segment COO (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cpu::segment_csr", &segment_csr)
+        .op("torch_scatter_cpu::segment_coo", &segment_coo);
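
segment_csr and segment_coo now return std::tuple<torch::Tensor, torch::optional<torch::Tensor>>: the second slot carries arg indices for the "min"/"max" reductions and stays torch::nullopt otherwise, initialized to src.size(reduce_dim) as a one-past-the-end sentinel. A minimal sketch of that return convention, detached from the reduction machinery above:

#include <torch/torch.h>

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
reduce_stub(torch::Tensor out, bool wants_arg, int64_t sentinel) {
  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  if (wants_arg)
    // Mirrors the full_like(out, src.size(reduce_dim), ...) call above:
    // every slot starts at the sentinel until a real argmin/argmax lands.
    arg_out = torch::full_like(out, sentinel, torch::kLong);
  return std::make_tuple(out, arg_out);
}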

cuda/gather.cpp

Lines changed: 14 additions & 14 deletions

@@ -1,31 +1,31 @@
-#include <torch/extension.h>
+#include <torch/script.h>
 
-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
+#define CHECK_CUDA(x) \
+  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
 
-at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
-                           at::optional<at::Tensor> out_opt);
-at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
-                           at::optional<at::Tensor> out_opt);
+torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
+                              torch::optional<torch::Tensor> out_opt);
+torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
+                              torch::optional<torch::Tensor> out_opt);
 
-at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
-                      at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
+                         torch::optional<torch::Tensor> out_opt) {
   CHECK_CUDA(src);
   CHECK_CUDA(indptr);
   if (out_opt.has_value())
     CHECK_CUDA(out_opt.value());
   return gather_csr_cuda(src, indptr, out_opt);
 }
 
-at::Tensor gather_coo(at::Tensor src, at::Tensor index,
-                      at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
+                         torch::optional<torch::Tensor> out_opt) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
   if (out_opt.has_value())
     CHECK_CUDA(out_opt.value());
   return gather_coo_cuda(src, index, out_opt);
 }
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("gather_csr", &gather_csr, "Gather CSR (CUDA)");
-  m.def("gather_coo", &gather_coo, "Gather COO (CUDA)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cuda::gather_csr", &gather_csr)
+        .op("torch_scatter_cuda::gather_coo", &gather_coo);
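
This translation unit is compiled by the host compiler and only checks devices before forwarding; the *_cuda definitions live in gather_kernel.cu, which nvcc builds. A minimal sketch of that split with a hypothetical op:

// my_op.cpp (host compiler): declare the CUDA entry point, check, forward.
torch::Tensor my_op_cuda(torch::Tensor x); // defined in my_op_kernel.cu

torch::Tensor my_op(torch::Tensor x) {
  AT_ASSERTM(x.device().is_cuda(), "x must be CUDA tensor");
  return my_op_cuda(x); // the .cu side launches the actual kernels
}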

cuda/gather_kernel.cu

Lines changed: 11 additions & 8 deletions

@@ -1,7 +1,7 @@
-#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
+#include <torch/extension.h>
 
 #include "compat.cuh"
 #include "indptr.cuh"
@@ -58,9 +58,10 @@ __global__ void gather_csr_broadcast_kernel(
   }
 }
 
-at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
-                           at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
+                              torch::optional<torch::Tensor> out_opt) {
 
+  cudaSetDevice(src.get_device());
   AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
   for (int i = 0; i < indptr.dim() - 1; i++)
     AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
@@ -70,7 +71,7 @@ at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
   AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
              "Input mismatch");
 
-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < out.dim(); i++)
@@ -152,8 +153,10 @@ __global__ void gather_coo_broadcast_kernel(
   }
 }
 
-at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
-                           at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
+                              torch::optional<torch::Tensor> out_opt) {
+
+  cudaSetDevice(src.get_device());
 
   AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
   for (int i = 0; i < index.dim() - 1; i++)
@@ -162,7 +165,7 @@ at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
   src = src.contiguous();
   auto gather_dim = index.dim() - 1;
 
-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < index.dim(); i++)
@@ -172,7 +175,7 @@ at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
   } else {
     auto sizes = src.sizes().vec();
     sizes[gather_dim] = index.size(gather_dim);
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }
 
   auto E = index.numel();
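
The added cudaSetDevice(src.get_device()) calls make each entry point launch its kernels on the device that holds src; without them, launches target whatever device happens to be current, which breaks on multi-GPU machines. A sketch of the idiom:

#include <cuda_runtime.h>
#include <torch/extension.h>

void launch_on_input_device(const torch::Tensor &src) {
  // get_device() returns src's CUDA device index; cudaSetDevice makes it
  // the current device for all subsequent launches on this thread.
  cudaSetDevice(src.get_device());
  // ... allocate outputs and launch kernels here ...
}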

cuda/index.cuh

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 #pragma once
 
-#include <ATen/ATen.h>
 #include <ATen/cuda/detail/TensorInfo.cuh>
+#include <torch/extension.h>
 
 template <typename scalar1, typename scalar2, int64_t Dims>
 struct IndexToScatterOffsets3 {

cuda/indptr.cuh

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 #pragma once
 
-#include <ATen/ATen.h>
 #include <ATen/cuda/detail/TensorInfo.cuh>
+#include <torch/extension.h>
 
 // We need our own `IndexToOffset` implementation since we do not want to
 // access the last element of the `indexptr`.
