|
5 | 5 |
|
6 | 6 | # This constant used to be 1 in both MPICH and OpenMPI, |
7 | 7 | # but starting with mpi4py version 4, they switched it to -1. |
| 8 | +# Even worse, starting with 4.1, it's not a subtype of int as it had been until then. |
| 9 | +# But it does cast to an int if requested. |
8 | 10 | # Use -1 here, but when we check for it allow 1 as well. |
9 | | -# And if we happen to have mpi4py installed, include whatever it actually has as well. |
| 11 | +# And if we happen to have mpi4py installed, include whatever it actually has as well |
| 12 | +# both for the value and the type. |
10 | 13 | try: |
11 | 14 | from mpi4py.MPI import IN_PLACE |
| 15 | + ALLOWED_IN_PLACE_TYPES = (int, type(IN_PLACE)) |
12 | 16 | except ImportError: |
13 | 17 | IN_PLACE = -1 |
| 18 | + ALLOWED_IN_PLACE_TYPES = (int,) |
14 | 19 | ALLOWED_IN_PLACE = [IN_PLACE, 1, -1] |
15 | 20 |
|
16 | 21 |
|
@@ -119,7 +124,7 @@ def allreduce(self, sendobj, op=None): |
119 | 124 | return d |
120 | 125 |
|
121 | 126 | def Reduce(self, sendbuf, recvbuf, op=None, root=0): |
122 | | - if isinstance(sendbuf, int) and (sendbuf in ALLOWED_IN_PLACE): |
| 127 | + if isinstance(sendbuf, ALLOWED_IN_PLACE_TYPES) and (sendbuf in ALLOWED_IN_PLACE): |
123 | 128 | sendbuf = recvbuf.copy() |
124 | 129 |
|
125 | 130 | if not isinstance(sendbuf, np.ndarray): |
|
0 commit comments