#include <torch/extension.h>

#include "compat.h"
#include "index_info.h"

#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

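// Supported reduction modes. The dispatch macro below maps the runtime
// `reduce` string onto a `REDUCE` constant that can be used as a template
// argument of `Reducer`; unrecognized strings simply fall through.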
enum ReductionType { ADD, MEAN, MIN, MAX };

#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    if (reduce == "add") {                                                     \
      const ReductionType REDUCE = ADD;                                        \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "mean") {                                             \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "min") {                                              \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "max") {                                              \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
  }()

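// Encapsulates the per-reduction behaviour: the neutral initial value, the
// element-wise update (tracking the arg position for min/max), and the final
// write-back, which divides by the segment size for "mean" and falls back to
// zero for empty min/max segments.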
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline scalar_t init() {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

  static inline void update(scalar_t *val, scalar_t new_val) {
    if (REDUCE == ADD || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
    }
  }

  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
                            int64_t new_arg) {
    if (REDUCE == ADD || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  static inline void write(scalar_t *address, scalar_t val,
                           int64_t *arg_address, int64_t arg, int count) {
    if (REDUCE == ADD) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (count > 0 ? count : (scalar_t)1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }
};

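// Reduces `src` along the last dimension of `indptr`, where consecutive
// `indptr` entries delimit the segments (CSR layout). Illustrative 1-D
// example (values chosen here for exposition, not taken from the sources):
// src = [1, 2, 3, 4, 5, 6], indptr = [0, 2, 5, 6] with reduce = "add"
// yields out = [1+2, 3+4+5, 6] = [3, 12, 6]; with reduce = "max" it yields
// out = [2, 5, 6] and arg_out = [1, 4, 5].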
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
            std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(indptr);
  if (out_opt.has_value())
    CHECK_CPU(out_opt.value());

  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");

  // Broadcasting `indptr` via `expand`.
  auto sizes = indptr.sizes().vec();
  for (int i = 0; i < indptr.dim() - 1; i++) {
    sizes[i] = src.size(i);
  }
  indptr = indptr.expand(sizes);

  src = src.contiguous();
  auto reduce_dim = indptr.dim() - 1;

  at::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != reduce_dim)
        AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
               "Input mismatch");
  } else {
    sizes = src.sizes().vec();
    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
    out = at::empty(sizes, src.options());
  }

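  // For "min"/"max", additionally return the winning positions. The tensor is
  // pre-filled with E = src.size(reduce_dim), so empty segments keep that
  // sentinel value.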
  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

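  // N: total number of output segments across all leading (batch) dimensions,
  // K: number of trailing entries reduced independently per segment,
  // E: extent of `src` along the reduced dimension.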
  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(reduce_dim);

  auto indptr_info = getTensorInfo<int64_t>(indptr);
  auto stride = indptr_info.strides[indptr_info.dims - 1];
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    scalar_t val;
    int64_t row_start, row_end, arg;
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
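      // Each segment n is delimited by two consecutive `indptr` entries; every
      // one of the K feature columns is reduced independently over the
      // elements in [row_start, row_end).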
      for (int n = 0; n < N; n++) {
        int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
        row_start = indptr_info.data[offset];
        row_end = indptr_info.data[offset + stride];

        offset = (n / (indptr.size(-1) - 1)) * E * K;
        for (int k = 0; k < K; k++) {
          val = Reducer<scalar_t, REDUCE>::init();
          for (int64_t e = row_start; e < row_end; e++) {
            Reducer<scalar_t, REDUCE>::update(
                &val, src_data[offset + e * K + k], &arg, e);
          }
          Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, val,
                                           arg_out_data + n * K + k, arg,
                                           row_end - row_start);
        }
      }
    });
  });

  return std::make_tuple(out, arg_out);
}

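// Reduces `src` into `out` along the last dimension of `index` (COO layout).
// The loop below treats consecutive equal indices as one segment, which
// assumes `index` is sorted along that dimension. Illustrative 1-D example
// (values chosen here for exposition, not taken from the sources):
// src = [1, 2, 3, 4], index = [0, 0, 1, 2] and a zero-initialized `out` of
// length 3 with reduce = "add" yields out = [1+2, 3, 4] = [3, 3, 4].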
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
            std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_CPU(out);

  AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");

  // Broadcasting `index` via `expand`.
  auto sizes = index.sizes().vec();
  for (int i = 0; i < index.dim(); i++) {
    sizes[i] = src.size(i);
  }
  index = index.expand(sizes);

  src = src.contiguous();
  out = out.contiguous();
  auto reduce_dim = index.dim() - 1;

  for (int i = 0; i < out.dim(); i++)
    if (i != reduce_dim)
      AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");

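  // As in segment_csr, track argmin/argmax positions for "min"/"max", with
  // E_2 = src.size(reduce_dim) as the sentinel for untouched entries.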
  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

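  // E_1: number of rows of `index` (product of its leading dimensions),
  // E_2: extent of the reduced dimension, K: trailing entries of `src` per
  // index value, N: extent of `out` along the reduced dimension.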
  auto E_1 = index.numel() / src.size(reduce_dim);
  auto E_2 = src.size(reduce_dim);
  auto K = src.numel() / index.numel();
  auto N = out.size(reduce_dim);

  auto index_info = getTensorInfo<int64_t>(index);
  auto stride = index_info.strides[index_info.dims - 1];
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    scalar_t val;
    int64_t idx, next_idx, row_start, arg;
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
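      // Accumulate runs of equal, consecutive indices in `val` (seeded from
      // the current `out` value) and flush the result to `out` whenever the
      // index changes or the end of the row is reached.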
      for (int e_1 = 0; e_1 < E_1; e_1++) {
        int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);

        for (int k = 0; k < K; k++) {
          idx = index_info.data[offset];
          row_start = 0;
          val = out_data[e_1 * N * K + k];

          for (int e_2 = 0; e_2 < E_2; e_2++) {
            Reducer<scalar_t, REDUCE>::update(
                &val, src_data[e_1 * E_2 * K + e_2 * K + k], &arg, e_2);

            if (e_2 == E_2 - 1) {
              Reducer<scalar_t, REDUCE>::write(
                  out_data + e_1 * N * K + idx * K + k, val,
                  arg_out_data + e_1 * N * K + idx * K + k, arg,
                  e_2 + 1 - row_start);
            } else {
              next_idx = index_info.data[offset + (e_2 + 1) * stride];

              if (idx != next_idx) {
                Reducer<scalar_t, REDUCE>::write(
                    out_data + e_1 * N * K + idx * K + k, val,
                    arg_out_data + e_1 * N * K + idx * K + k, arg,
                    e_2 + 1 - row_start);

                row_start = e_2 + 1;
                val = out_data[e_1 * N * K + next_idx * K + k];
              }

              idx = next_idx;
            }
          }
        }
      }
    });
  });

  return std::make_tuple(out, arg_out);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {