Skip to content

Commit a24b836

Browse files
GericoVirusty1s
andauthored
Export symbols for Windows DLL and MSVC build compatibility (#294)
* Mark exported symbols with SCATTER_API so they are available in the DLL on Windows * macro definition for python package installation * move macros definitions to separate file * unlink? * update * linting Co-authored-by: Matthias Fey <[email protected]>
1 parent 06321b5 commit a24b836

File tree

7 files changed

+101
-63
lines changed

7 files changed

+101
-63
lines changed

csrc/macros.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#pragma once
2+
3+
#ifdef _WIN32
4+
#if defined(torchscatter_EXPORTS)
5+
#define SCATTER_API __declspec(dllexport)
6+
#else
7+
#define SCATTER_API __declspec(dllimport)
8+
#endif
9+
#else
10+
#define SCATTER_API
11+
#endif
12+
13+
#if (defined __cpp_inline_variables) || __cplusplus >= 201703L
14+
#define SCATTER_INLINE_VARIABLE inline
15+
#else
16+
#ifdef _MSC_VER
17+
#define SCATTER_INLINE_VARIABLE __declspec(selectany)
18+
#else
19+
#define SCATTER_INLINE_VARIABLE __attribute__((weak))
20+
#endif
21+
#endif

csrc/scatter.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <torch/script.h>
33

44
#include "cpu/scatter_cpu.h"
5+
#include "macros.h"
56
#include "utils.h"
67

78
#ifdef WITH_CUDA
@@ -226,9 +227,10 @@ class ScatterMax : public torch::autograd::Function<ScatterMax> {
226227
}
227228
};
228229

229-
torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
230-
torch::optional<torch::Tensor> optional_out,
231-
torch::optional<int64_t> dim_size) {
230+
SCATTER_API torch::Tensor
231+
scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
232+
torch::optional<torch::Tensor> optional_out,
233+
torch::optional<int64_t> dim_size) {
232234
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
233235
}
234236

@@ -238,21 +240,22 @@ torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
238240
return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0];
239241
}
240242

241-
torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
242-
torch::optional<torch::Tensor> optional_out,
243-
torch::optional<int64_t> dim_size) {
243+
SCATTER_API torch::Tensor
244+
scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
245+
torch::optional<torch::Tensor> optional_out,
246+
torch::optional<int64_t> dim_size) {
244247
return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0];
245248
}
246249

247-
std::tuple<torch::Tensor, torch::Tensor>
250+
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
248251
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
249252
torch::optional<torch::Tensor> optional_out,
250253
torch::optional<int64_t> dim_size) {
251254
auto result = ScatterMin::apply(src, index, dim, optional_out, dim_size);
252255
return std::make_tuple(result[0], result[1]);
253256
}
254257

255-
std::tuple<torch::Tensor, torch::Tensor>
258+
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
256259
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
257260
torch::optional<torch::Tensor> optional_out,
258261
torch::optional<int64_t> dim_size) {

csrc/scatter.h

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,76 +2,76 @@
22

33
#include <torch/extension.h>
44

5-
#if (defined __cpp_inline_variables) || __cplusplus >= 201703L
6-
#define SCATTER_INLINE_VARIABLE inline
7-
#else
8-
#ifdef _MSC_VER
9-
#define SCATTER_INLINE_VARIABLE __declspec(selectany)
10-
#else
11-
#define SCATTER_INLINE_VARIABLE __attribute__((weak))
12-
#endif
13-
#endif
5+
#include "macros.h"
146

157
namespace scatter {
16-
int64_t cuda_version() noexcept;
8+
SCATTER_API int64_t cuda_version() noexcept;
179

1810
namespace detail {
1911
SCATTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
2012
} // namespace detail
2113
} // namespace scatter
2214

23-
torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
24-
torch::optional<torch::Tensor> optional_out,
25-
torch::optional<int64_t> dim_size);
15+
SCATTER_API torch::Tensor
16+
scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
17+
torch::optional<torch::Tensor> optional_out,
18+
torch::optional<int64_t> dim_size);
2619

27-
torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
28-
torch::optional<torch::Tensor> optional_out,
29-
torch::optional<int64_t> dim_size);
20+
SCATTER_API torch::Tensor
21+
scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
22+
torch::optional<torch::Tensor> optional_out,
23+
torch::optional<int64_t> dim_size);
3024

31-
std::tuple<torch::Tensor, torch::Tensor>
25+
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
3226
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
3327
torch::optional<torch::Tensor> optional_out,
3428
torch::optional<int64_t> dim_size);
3529

36-
std::tuple<torch::Tensor, torch::Tensor>
30+
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
3731
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
3832
torch::optional<torch::Tensor> optional_out,
3933
torch::optional<int64_t> dim_size);
4034

41-
torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index,
42-
torch::optional<torch::Tensor> optional_out,
43-
torch::optional<int64_t> dim_size);
35+
SCATTER_API torch::Tensor
36+
segment_sum_coo(torch::Tensor src, torch::Tensor index,
37+
torch::optional<torch::Tensor> optional_out,
38+
torch::optional<int64_t> dim_size);
4439

45-
torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index,
46-
torch::optional<torch::Tensor> optional_out,
47-
torch::optional<int64_t> dim_size);
40+
SCATTER_API torch::Tensor
41+
segment_mean_coo(torch::Tensor src, torch::Tensor index,
42+
torch::optional<torch::Tensor> optional_out,
43+
torch::optional<int64_t> dim_size);
4844

49-
std::tuple<torch::Tensor, torch::Tensor>
45+
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
5046
segment_min_coo(torch::Tensor src, torch::Tensor index,
5147
torch::optional<torch::Tensor> optional_out,
5248
torch::optional<int64_t> dim_size);
5349

54-
std::tuple<torch::Tensor, torch::Tensor>
50+
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
5551
segment_max_coo(torch::Tensor src, torch::Tensor index,
5652
torch::optional<torch::Tensor> optional_out,
5753
torch::optional<int64_t> dim_size);
5854

59-
torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
60-
torch::optional<torch::Tensor> optional_out);
55+
SCATTER_API torch::Tensor
56+
gather_coo(torch::Tensor src, torch::Tensor index,
57+
torch::optional<torch::Tensor> optional_out);
6158

62-
torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
63-
torch::optional<torch::Tensor> optional_out);
59+
SCATTER_API torch::Tensor
60+
segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
61+
torch::optional<torch::Tensor> optional_out);
6462

65-
torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
66-
torch::optional<torch::Tensor> optional_out);
63+
SCATTER_API torch::Tensor
64+
segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
65+
torch::optional<torch::Tensor> optional_out);
6766

68-
std::tuple<torch::Tensor, torch::Tensor>
67+
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
6968
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
7069
torch::optional<torch::Tensor> optional_out);
7170

72-
std::tuple<torch::Tensor, torch::Tensor>
71+
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
7372
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
7473
torch::optional<torch::Tensor> optional_out);
7574

76-
torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
77-
torch::optional<torch::Tensor> optional_out);
75+
SCATTER_API torch::Tensor
76+
gather_csr(torch::Tensor src, torch::Tensor indptr,
77+
torch::optional<torch::Tensor> optional_out);

csrc/segment_coo.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <torch/script.h>
33

44
#include "cpu/segment_coo_cpu.h"
5+
#include "macros.h"
56
#include "utils.h"
67

78
#ifdef WITH_CUDA
@@ -195,36 +196,39 @@ class GatherCOO : public torch::autograd::Function<GatherCOO> {
195196
}
196197
};
197198

198-
torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index,
199-
torch::optional<torch::Tensor> optional_out,
200-
torch::optional<int64_t> dim_size) {
199+
SCATTER_API torch::Tensor
200+
segment_sum_coo(torch::Tensor src, torch::Tensor index,
201+
torch::optional<torch::Tensor> optional_out,
202+
torch::optional<int64_t> dim_size) {
201203
return SegmentSumCOO::apply(src, index, optional_out, dim_size)[0];
202204
}
203205

204-
torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index,
205-
torch::optional<torch::Tensor> optional_out,
206-
torch::optional<int64_t> dim_size) {
206+
SCATTER_API torch::Tensor
207+
segment_mean_coo(torch::Tensor src, torch::Tensor index,
208+
torch::optional<torch::Tensor> optional_out,
209+
torch::optional<int64_t> dim_size) {
207210
return SegmentMeanCOO::apply(src, index, optional_out, dim_size)[0];
208211
}
209212

210-
std::tuple<torch::Tensor, torch::Tensor>
213+
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
211214
segment_min_coo(torch::Tensor src, torch::Tensor index,
212215
torch::optional<torch::Tensor> optional_out,
213216
torch::optional<int64_t> dim_size) {
214217
auto result = SegmentMinCOO::apply(src, index, optional_out, dim_size);
215218
return std::make_tuple(result[0], result[1]);
216219
}
217220

218-
std::tuple<torch::Tensor, torch::Tensor>
221+
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
219222
segment_max_coo(torch::Tensor src, torch::Tensor index,
220223
torch::optional<torch::Tensor> optional_out,
221224
torch::optional<int64_t> dim_size) {
222225
auto result = SegmentMaxCOO::apply(src, index, optional_out, dim_size);
223226
return std::make_tuple(result[0], result[1]);
224227
}
225228

226-
torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
227-
torch::optional<torch::Tensor> optional_out) {
229+
SCATTER_API torch::Tensor
230+
gather_coo(torch::Tensor src, torch::Tensor index,
231+
torch::optional<torch::Tensor> optional_out) {
228232
return GatherCOO::apply(src, index, optional_out)[0];
229233
}
230234

csrc/segment_csr.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <torch/script.h>
33

44
#include "cpu/segment_csr_cpu.h"
5+
#include "macros.h"
56
#include "utils.h"
67

78
#ifdef WITH_CUDA
@@ -192,32 +193,35 @@ class GatherCSR : public torch::autograd::Function<GatherCSR> {
192193
}
193194
};
194195

195-
torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
196-
torch::optional<torch::Tensor> optional_out) {
196+
SCATTER_API torch::Tensor
197+
segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
198+
torch::optional<torch::Tensor> optional_out) {
197199
return SegmentSumCSR::apply(src, indptr, optional_out)[0];
198200
}
199201

200-
torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
201-
torch::optional<torch::Tensor> optional_out) {
202+
SCATTER_API torch::Tensor
203+
segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
204+
torch::optional<torch::Tensor> optional_out) {
202205
return SegmentMeanCSR::apply(src, indptr, optional_out)[0];
203206
}
204207

205-
std::tuple<torch::Tensor, torch::Tensor>
208+
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
206209
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
207210
torch::optional<torch::Tensor> optional_out) {
208211
auto result = SegmentMinCSR::apply(src, indptr, optional_out);
209212
return std::make_tuple(result[0], result[1]);
210213
}
211214

212-
std::tuple<torch::Tensor, torch::Tensor>
215+
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
213216
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
214217
torch::optional<torch::Tensor> optional_out) {
215218
auto result = SegmentMaxCSR::apply(src, indptr, optional_out);
216219
return std::make_tuple(result[0], result[1]);
217220
}
218221

219-
torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
220-
torch::optional<torch::Tensor> optional_out) {
222+
SCATTER_API torch::Tensor
223+
gather_csr(torch::Tensor src, torch::Tensor indptr,
224+
torch::optional<torch::Tensor> optional_out) {
221225
return GatherCSR::apply(src, indptr, optional_out)[0];
222226
}
223227

csrc/version.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include <Python.h>
22
#include <torch/script.h>
3+
#include "scatter.h"
4+
#include "macros.h"
35

46
#ifdef WITH_CUDA
57
#include <cuda.h>
@@ -14,7 +16,7 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
1416
#endif
1517

1618
namespace scatter {
17-
int64_t cuda_version() noexcept {
19+
SCATTER_API int64_t cuda_version() noexcept {
1820
#ifdef WITH_CUDA
1921
return CUDA_VERSION;
2022
#else

setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ def get_extensions():
3535

3636
for main, suffix in product(main_files, suffices):
3737
define_macros = []
38+
39+
if sys.platform == 'win32':
40+
define_macros += [('torchscatter_EXPORTS', None)]
41+
3842
extra_compile_args = {'cxx': ['-O2']}
3943
if not os.name == 'nt': # Not on Windows:
4044
extra_compile_args['cxx'] += ['-Wno-sign-compare']

0 commit comments

Comments
 (0)