1- #include < torch/extension .h>
1+ #include < torch/script .h>
22
33#include " compat.h"
44#include " index_info.h"
55
66#include < vector>
77
// Asserts that the tensor expression `x` reports a CPU device. The assertion
// message is the argument's source text followed by " must be CPU tensor".
// `x` is parenthesized so compound expressions (e.g. `cond ? a : b`) expand
// to a call on the whole expression rather than on its last operand.
#define CHECK_CPU(x) AT_ASSERTM((x).device().is_cpu(), #x " must be CPU tensor")
99
// Reduction kinds; used as the REDUCE template argument of Reducer and
// compared against when deciding whether an arg-index output is needed.
// Enumerator values are implicit (SUM = 0 ... MAX = 3).
enum ReductionType {
  SUM,
  MEAN,
  MIN,
  MAX
};
1111
@@ -74,9 +74,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
7474 }
7575};
7676
77- std::tuple<at ::Tensor, at ::optional<at ::Tensor>>
78- segment_csr (at ::Tensor src, at ::Tensor indptr, at::optional<at::Tensor> out_opt ,
79- std::string reduce) {
77+ std::tuple<torch ::Tensor, torch ::optional<torch ::Tensor>>
78+ segment_csr (torch ::Tensor src, torch ::Tensor indptr,
79+ torch::optional<torch::Tensor> out_opt, std::string reduce) {
8080 CHECK_CPU (src);
8181 CHECK_CPU (indptr);
8282 if (out_opt.has_value ())
@@ -94,7 +94,7 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
9494 src = src.contiguous ();
9595 auto reduce_dim = indptr.dim () - 1 ;
9696
97- at ::Tensor out;
97+ torch ::Tensor out;
9898 if (out_opt.has_value ()) {
9999 out = out_opt.value ().contiguous ();
100100 for (int i = 0 ; i < out.dim (); i++)
@@ -105,13 +105,13 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
105105 } else {
106106 sizes = src.sizes ().vec ();
107107 sizes[reduce_dim] = indptr.size (reduce_dim) - 1 ;
108- out = at ::empty (sizes, src.options ());
108+ out = torch ::empty (sizes, src.options ());
109109 }
110110
111- at ::optional<at ::Tensor> arg_out = at ::nullopt ;
111+ torch ::optional<torch ::Tensor> arg_out = torch ::nullopt ;
112112 int64_t *arg_out_data = nullptr ;
113113 if (reduce2REDUCE.at (reduce) == MIN || reduce2REDUCE.at (reduce) == MAX) {
114- arg_out = at ::full_like (out, src.size (reduce_dim), indptr.options ());
114+ arg_out = torch ::full_like (out, src.size (reduce_dim), indptr.options ());
115115 arg_out_data = arg_out.value ().DATA_PTR <int64_t >();
116116 }
117117
@@ -156,8 +156,8 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
156156 return std::make_tuple (out, arg_out);
157157}
158158
159- std::tuple<at ::Tensor, at ::optional<at ::Tensor>>
160- segment_coo (at ::Tensor src, at ::Tensor index, at ::Tensor out,
159+ std::tuple<torch ::Tensor, torch ::optional<torch ::Tensor>>
160+ segment_coo (torch ::Tensor src, torch ::Tensor index, torch ::Tensor out,
161161 std::string reduce) {
162162 CHECK_CPU (src);
163163 CHECK_CPU (index);
@@ -180,10 +180,10 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
180180 if (i != reduce_dim)
181181 AT_ASSERTM (src.size (i) == out.size (i), " Input mismatch" );
182182
183- at ::optional<at ::Tensor> arg_out = at ::nullopt ;
183+ torch ::optional<torch ::Tensor> arg_out = torch ::nullopt ;
184184 int64_t *arg_out_data = nullptr ;
185185 if (reduce2REDUCE.at (reduce) == MIN || reduce2REDUCE.at (reduce) == MAX) {
186- arg_out = at ::full_like (out, src.size (reduce_dim), index.options ());
186+ arg_out = torch ::full_like (out, src.size (reduce_dim), index.options ());
187187 arg_out_data = arg_out.value ().DATA_PTR <int64_t >();
188188 }
189189
@@ -251,7 +251,6 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
251251 return std::make_tuple (out, arg_out);
252252}
253253
254- PYBIND11_MODULE (TORCH_EXTENSION_NAME, m) {
255- m.def (" segment_csr" , &segment_csr, " Segment CSR (CPU)" );
256- m.def (" segment_coo" , &segment_coo, " Segment COO (CPU)" );
257- }
254+ static auto registry =
255+ torch::RegisterOperators (" torch_scatter_cpu::segment_csr" , &segment_csr)
256+ .op(" torch_scatter_cpu::segment_coo" , &segment_coo);
0 commit comments