|
2 | 2 |
|
3 | 3 | #include <torch/extension.h> |
4 | 4 |
|
5 | | -#if (defined __cpp_inline_variables) || __cplusplus >= 201703L |
6 | | -#define SCATTER_INLINE_VARIABLE inline |
7 | | -#else |
8 | | -#ifdef _MSC_VER |
9 | | -#define SCATTER_INLINE_VARIABLE __declspec(selectany) |
10 | | -#else |
11 | | -#define SCATTER_INLINE_VARIABLE __attribute__((weak)) |
12 | | -#endif |
13 | | -#endif |
| 5 | +#include "macros.h" |
14 | 6 |
|
15 | 7 | namespace scatter { |
16 | | -int64_t cuda_version() noexcept; |
| 8 | +SCATTER_API int64_t cuda_version() noexcept; |
17 | 9 |
|
18 | 10 | namespace detail { |
19 | 11 | SCATTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version(); |
20 | 12 | } // namespace detail |
21 | 13 | } // namespace scatter |
22 | 14 |
|
23 | | -torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, |
24 | | - torch::optional<torch::Tensor> optional_out, |
25 | | - torch::optional<int64_t> dim_size); |
| 15 | +SCATTER_API torch::Tensor |
| 16 | +scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, |
| 17 | + torch::optional<torch::Tensor> optional_out, |
| 18 | + torch::optional<int64_t> dim_size); |
26 | 19 |
|
27 | | -torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, |
28 | | - torch::optional<torch::Tensor> optional_out, |
29 | | - torch::optional<int64_t> dim_size); |
| 20 | +SCATTER_API torch::Tensor |
| 21 | +scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, |
| 22 | + torch::optional<torch::Tensor> optional_out, |
| 23 | + torch::optional<int64_t> dim_size); |
30 | 24 |
|
31 | | -std::tuple<torch::Tensor, torch::Tensor> |
| 25 | +SCATTER_API std::tuple<torch::Tensor, torch::Tensor> |
32 | 26 | scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim, |
33 | 27 | torch::optional<torch::Tensor> optional_out, |
34 | 28 | torch::optional<int64_t> dim_size); |
35 | 29 |
|
36 | | -std::tuple<torch::Tensor, torch::Tensor> |
| 30 | +SCATTER_API std::tuple<torch::Tensor, torch::Tensor> |
37 | 31 | scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim, |
38 | 32 | torch::optional<torch::Tensor> optional_out, |
39 | 33 | torch::optional<int64_t> dim_size); |
40 | 34 |
|
41 | | -torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index, |
42 | | - torch::optional<torch::Tensor> optional_out, |
43 | | - torch::optional<int64_t> dim_size); |
| 35 | +SCATTER_API torch::Tensor |
| 36 | +segment_sum_coo(torch::Tensor src, torch::Tensor index, |
| 37 | + torch::optional<torch::Tensor> optional_out, |
| 38 | + torch::optional<int64_t> dim_size); |
44 | 39 |
|
45 | | -torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index, |
46 | | - torch::optional<torch::Tensor> optional_out, |
47 | | - torch::optional<int64_t> dim_size); |
| 40 | +SCATTER_API torch::Tensor |
| 41 | +segment_mean_coo(torch::Tensor src, torch::Tensor index, |
| 42 | + torch::optional<torch::Tensor> optional_out, |
| 43 | + torch::optional<int64_t> dim_size); |
48 | 44 |
|
49 | | -std::tuple<torch::Tensor, torch::Tensor> |
| 45 | +SCATTER_API std::tuple<torch::Tensor, torch::Tensor> |
50 | 46 | segment_min_coo(torch::Tensor src, torch::Tensor index, |
51 | 47 | torch::optional<torch::Tensor> optional_out, |
52 | 48 | torch::optional<int64_t> dim_size); |
53 | 49 |
|
54 | | -std::tuple<torch::Tensor, torch::Tensor> |
| 50 | +SCATTER_API std::tuple<torch::Tensor, torch::Tensor> |
55 | 51 | segment_max_coo(torch::Tensor src, torch::Tensor index, |
56 | 52 | torch::optional<torch::Tensor> optional_out, |
57 | 53 | torch::optional<int64_t> dim_size); |
58 | 54 |
|
59 | | -torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index, |
60 | | - torch::optional<torch::Tensor> optional_out); |
| 55 | +SCATTER_API torch::Tensor |
| 56 | +gather_coo(torch::Tensor src, torch::Tensor index, |
| 57 | + torch::optional<torch::Tensor> optional_out); |
61 | 58 |
|
62 | | -torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr, |
63 | | - torch::optional<torch::Tensor> optional_out); |
| 59 | +SCATTER_API torch::Tensor |
| 60 | +segment_sum_csr(torch::Tensor src, torch::Tensor indptr, |
| 61 | + torch::optional<torch::Tensor> optional_out); |
64 | 62 |
|
65 | | -torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr, |
66 | | - torch::optional<torch::Tensor> optional_out); |
| 63 | +SCATTER_API torch::Tensor |
| 64 | +segment_mean_csr(torch::Tensor src, torch::Tensor indptr, |
| 65 | + torch::optional<torch::Tensor> optional_out); |
67 | 66 |
|
68 | | -std::tuple<torch::Tensor, torch::Tensor> |
| 67 | +SCATTER_API std::tuple<torch::Tensor, torch::Tensor> |
69 | 68 | segment_min_csr(torch::Tensor src, torch::Tensor indptr, |
70 | 69 | torch::optional<torch::Tensor> optional_out); |
71 | 70 |
|
72 | | -std::tuple<torch::Tensor, torch::Tensor> |
| 71 | +SCATTER_API std::tuple<torch::Tensor, torch::Tensor> |
73 | 72 | segment_max_csr(torch::Tensor src, torch::Tensor indptr, |
74 | 73 | torch::optional<torch::Tensor> optional_out); |
75 | 74 |
|
76 | | -torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr, |
77 | | - torch::optional<torch::Tensor> optional_out); |
| 75 | +SCATTER_API torch::Tensor |
| 76 | +gather_csr(torch::Tensor src, torch::Tensor indptr, |
| 77 | + torch::optional<torch::Tensor> optional_out); |
0 commit comments