#include <torch/extension.h>

+#include "compat.h"
+#include "index_info.h"
+
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
@@ -8,8 +11,59 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
  CHECK_CPU(indptr);
  if (out_opt.has_value())
    CHECK_CPU(out_opt.value());
-  AT_ASSERTM(false, "Not yet implemented");
-  return src;
+
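+  // src must match indptr on all leading dimensions and hold exactly
+  // indptr.size(-1) - 1 rows along the gather dimension (the last indptr dim).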
+  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
+  for (int i = 0; i < indptr.dim() - 1; i++)
+    AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
+
+  src = src.contiguous();
+  auto gather_dim = indptr.dim() - 1;
+  AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
+             "Input mismatch");
+
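+  // Reuse the provided output buffer if given; otherwise allocate one whose
+  // gather dimension is as long as the last offset stored in indptr.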
+  at::Tensor out;
+  if (out_opt.has_value()) {
+    out = out_opt.value().contiguous();
+    for (int i = 0; i < out.dim(); i++)
+      if (i != gather_dim)
+        AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
+  } else {
+    auto sizes = src.sizes().vec();
+    sizes[gather_dim] = *indptr.flatten()[-1].DATA_PTR<int64_t>();
+    out = at::empty(sizes, src.options());
+  }
+
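+  // N: total number of rows across all batches, K: features per row,
+  // E: length of the expanded gather dimension in the output.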
+  auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1));
+  auto K = src.numel() / N;
+  auto E = out.size(gather_dim);
+
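+  // Raw data/stride view of indptr so offsets work for arbitrary layouts.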
+  auto indptr_info = getTensorInfo<int64_t>(indptr);
+  auto stride = indptr_info.strides[indptr_info.dims - 1];
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
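+    // Per-row feature buffer (a variable-length array, relying on the
+    // GCC/Clang VLA extension).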
+    scalar_t vals[K];
+    int64_t row_start, row_end;
+    for (int n = 0; n < N; n++) {
+      int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
+      row_start = indptr_info.data[offset];
+      row_end = indptr_info.data[offset + stride];
+
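+      // Cache the K features of row n once, then replicate them below.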
+      for (int k = 0; k < K; k++) {
+        vals[k] = src_data[n * K + k];
+      }
+
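+      // Jump to the start of the current batch in out and write the cached
+      // row into every position of its indptr range [row_start, row_end).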
+      offset = (n / (indptr.size(-1) - 1)) * E * K;
+      for (int64_t e = row_start; e < row_end; e++) {
+        for (int k = 0; k < K; k++) {
+          out_data[offset + e * K + k] = vals[k];
+        }
+      }
+    }
+  });
+
+  return out;
}

at::Tensor gather_coo(at::Tensor src, at::Tensor index,
@@ -18,8 +72,69 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
  CHECK_CPU(index);
  if (out_opt.has_value())
    CHECK_CPU(out_opt.value());
-  AT_ASSERTM(false, "Not yet implemented");
-  return src;
+
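+  // index must match src on all leading dimensions; the gather runs along
+  // the last dimension of index.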
+  AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
+  for (int i = 0; i < index.dim() - 1; i++)
+    AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch");
+
+  src = src.contiguous();
+  auto gather_dim = index.dim() - 1;
+
+  at::Tensor out;
+  if (out_opt.has_value()) {
+    out = out_opt.value().contiguous();
+    for (int i = 0; i < index.dim(); i++)
+      AT_ASSERTM(out.size(i) == index.size(i), "Input mismatch");
+    for (int i = index.dim() + 1; i < src.dim(); i++)
+      AT_ASSERTM(out.size(i) == src.size(i), "Input mismatch");
+  } else {
+    auto sizes = src.sizes().vec();
+    sizes[gather_dim] = index.size(gather_dim);
+    out = at::empty(sizes, src.options());
+  }
+
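+  // E_1: number of index batches, E_2: entries per batch,
+  // K: features per entry, N: size of src along the gather dimension.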
+  auto E_1 = index.numel() / out.size(gather_dim);
+  auto E_2 = index.size(gather_dim);
+  auto K = out.numel() / index.numel();
+  auto N = src.size(gather_dim);
+
+  auto index_info = getTensorInfo<int64_t>(index);
+  auto stride = index_info.strides[index_info.dims - 1];
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
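+    // Feature buffer for the currently gathered source row (VLA, as above).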
+    scalar_t vals[K];
+    int64_t idx, next_idx;
+    for (int e_1 = 0; e_1 < E_1; e_1++) {
+      int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
+      idx = index_info.data[offset];
+
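+      // Cache the source row for the first index of this batch.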
+      for (int k = 0; k < K; k++) {
+        vals[k] = src_data[e_1 * N * K + idx * K + k];
+      }
+
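+      // Replicate the cached row for consecutive equal indices, refreshing
+      // the cache only when the index value changes.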
+      for (int e_2 = 0; e_2 < E_2; e_2++) {
+        for (int k = 0; k < K; k++) {
+          out_data[e_1 * E_2 * K + e_2 * K + k] = vals[k];
+        }
+
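+        // Peek at the next index; entries must be sorted in ascending order
+        // for this single pass to work.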
+        if (e_2 < E_2 - 1) {
+          next_idx = index_info.data[offset + (e_2 + 1) * stride];
+          assert(idx <= next_idx);
+
+          if (idx != next_idx) {
+            idx = next_idx;
+            for (int k = 0; k < K; k++) {
+              vals[k] = src_data[e_1 * N * K + idx * K + k];
+            }
+          }
+        }
+      }
+    }
+  });
+
+  return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {