Skip to content

Commit 14efb6f

Browse files
DaskExecutor (#943)
* adds DaskScheduler * upgrades * no ci for 3.8 * lazy import dask * updates * drop py3.7 support * lockfile * fixes? * future * fixes * fixes * fixes * fixes for 3.8 * update cassettes * ci * changelog and DaskExecutor.from_kwargs * fixes as_completed * refactor Executor type to protocol * fixes protocol * fixes * Update webknossos/Changelog.md Co-authored-by: Philipp Otto <[email protected]> * Changelog --------- Co-authored-by: Philipp Otto <[email protected]>
1 parent 1cb7101 commit 14efb6f

File tree

60 files changed

+3532
-5710
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+3532
-5710
lines changed

.github/workflows/ci.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,15 @@ jobs:
9494
poetry install
9595
9696
- name: Check typing
97-
if: ${{ matrix.executors == 'multiprocessing' }}
97+
if: ${{ matrix.executors == 'multiprocessing' && matrix.python-version == '3.11' }}
9898
run: ./typecheck.sh
9999

100100
- name: Check formatting
101-
if: ${{ matrix.executors == 'multiprocessing' }}
101+
if: ${{ matrix.executors == 'multiprocessing' && matrix.python-version == '3.11' }}
102102
run: ./format.sh check
103103

104104
- name: Lint code
105-
if: ${{ matrix.executors == 'multiprocessing' }}
105+
if: ${{ matrix.executors == 'multiprocessing' && matrix.python-version == '3.11' }}
106106
run: ./lint.sh
107107

108108
- name: Run multiprocessing tests
@@ -160,15 +160,15 @@ jobs:
160160
poetry install --extras all
161161
162162
- name: Check formatting
163-
if: matrix.group == 1
163+
if: ${{ matrix.group == 1 && matrix.python-version == '3.11' }}
164164
run: ./format.sh check
165165

166166
- name: Lint code
167-
if: matrix.group == 1
167+
if: ${{ matrix.group == 1 && matrix.python-version == '3.11' }}
168168
run: ./lint.sh
169169

170170
- name: Check typing
171-
if: matrix.group == 1
171+
if: ${{ matrix.group == 1 && matrix.python-version == '3.11' }}
172172
run: ./typecheck.sh
173173

174174
- name: Python tests

cluster_tools/Changelog.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,14 @@ For upgrade instructions, please check the respective *Breaking Changes* section
1010
[Commits](https://github.com/scalableminds/webknossos-libs/compare/v0.13.7...HEAD)
1111

1212
### Breaking Changes
13+
- Dropped support for Python 3.7. [#943](https://github.com/scalableminds/webknossos-libs/pull/943)
14+
- Please use `Executor.as_completed` instead of `concurrent.futures.as_completed` because the latter will not work for `DaskExecutor` futures. [#943](https://github.com/scalableminds/webknossos-libs/pull/943)
1315

1416
### Added
17+
- Added `DaskExecutor` (only Python >= 3.9). [#943](https://github.com/scalableminds/webknossos-libs/pull/943)
1518

1619
### Changed
20+
- The exported `Executor` type is now implemented as a protocol. [#943](https://github.com/scalableminds/webknossos-libs/pull/943)
1721

1822
### Fixed
1923

cluster_tools/cluster_tools/__init__.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Any, Union, overload
2-
3-
from typing_extensions import Literal
1+
from typing import Any, Literal, overload
42

3+
from cluster_tools.executor_protocol import Executor
4+
from cluster_tools.executors.dask import DaskExecutor
55
from cluster_tools.executors.debug_sequential import DebugSequentialExecutor
66
from cluster_tools.executors.multiprocessing_ import MultiprocessingExecutor
77
from cluster_tools.executors.pickle_ import PickleExecutor
@@ -70,6 +70,11 @@ def get_executor(
7070
...
7171

7272

73+
@overload
74+
def get_executor(environment: Literal["dask"], **kwargs: Any) -> DaskExecutor:
75+
...
76+
77+
7378
@overload
7479
def get_executor(
7580
environment: Literal["multiprocessing"], **kwargs: Any
@@ -105,6 +110,11 @@ def get_executor(environment: str, **kwargs: Any) -> "Executor":
105110
return PBSExecutor(**kwargs)
106111
elif environment == "kubernetes":
107112
return KubernetesExecutor(**kwargs)
113+
elif environment == "dask":
114+
if "client" in kwargs:
115+
return DaskExecutor(kwargs["client"])
116+
else:
117+
return DaskExecutor.from_kwargs(**kwargs)
108118
elif environment == "multiprocessing":
109119
global did_start_test_multiprocessing
110120
if not did_start_test_multiprocessing:
@@ -119,6 +129,3 @@ def get_executor(environment: str, **kwargs: Any) -> "Executor":
119129
elif environment == "test_pickling":
120130
return PickleExecutor(**kwargs)
121131
raise Exception("Unknown executor: {}".format(environment))
122-
123-
124-
Executor = Union[ClusterExecutor, MultiprocessingExecutor]
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from concurrent.futures import Future
2+
from os import PathLike
3+
from typing import (
4+
Callable,
5+
ContextManager,
6+
Iterable,
7+
Iterator,
8+
List,
9+
Optional,
10+
Protocol,
11+
TypeVar,
12+
)
13+
14+
from typing_extensions import ParamSpec
15+
16+
_T = TypeVar("_T")
17+
_P = ParamSpec("_P")
18+
_S = TypeVar("_S")
19+
20+
21+
class Executor(Protocol, ContextManager["Executor"]):
22+
@classmethod
23+
def as_completed(cls, futures: List["Future[_T]"]) -> Iterator["Future[_T]"]:
24+
...
25+
26+
def submit(
27+
self,
28+
__fn: Callable[_P, _T],
29+
/,
30+
*args: _P.args,
31+
**kwargs: _P.kwargs,
32+
) -> "Future[_T]":
33+
...
34+
35+
def map_unordered(self, fn: Callable[[_S], _T], args: Iterable[_S]) -> Iterator[_T]:
36+
...
37+
38+
def map_to_futures(
39+
self,
40+
fn: Callable[[_S], _T],
41+
args: Iterable[_S],
42+
output_pickle_path_getter: Optional[Callable[[_S], PathLike]] = None,
43+
) -> List["Future[_T]"]:
44+
...
45+
46+
def map(
47+
self,
48+
fn: Callable[[_S], _T],
49+
iterables: Iterable[_S],
50+
timeout: Optional[float] = None,
51+
chunksize: Optional[int] = None,
52+
) -> Iterator[_T]:
53+
...
54+
55+
def forward_log(self, fut: "Future[_T]") -> _T:
56+
...
57+
58+
def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
59+
...
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import os
2+
from concurrent import futures
3+
from concurrent.futures import Future
4+
from functools import partial
5+
from typing import (
6+
TYPE_CHECKING,
7+
Any,
8+
Callable,
9+
Iterable,
10+
Iterator,
11+
List,
12+
Optional,
13+
TypeVar,
14+
cast,
15+
)
16+
17+
from typing_extensions import ParamSpec
18+
19+
from cluster_tools._utils.warning import enrich_future_with_uncaught_warning
20+
from cluster_tools.executors.multiprocessing_ import CFutDict, MultiprocessingExecutor
21+
22+
if TYPE_CHECKING:
23+
from distributed import Client
24+
25+
_T = TypeVar("_T")
26+
_P = ParamSpec("_P")
27+
_S = TypeVar("_S")
28+
29+
30+
class DaskExecutor(futures.Executor):
31+
client: "Client"
32+
33+
def __init__(
34+
self,
35+
client: "Client",
36+
) -> None:
37+
self.client = client
38+
39+
@classmethod
40+
def from_kwargs(
41+
cls,
42+
**kwargs: Any,
43+
) -> "DaskExecutor":
44+
from distributed import Client
45+
46+
return cls(Client(**kwargs))
47+
48+
@classmethod
49+
def as_completed(cls, futures: List["Future[_T]"]) -> Iterator["Future[_T]"]:
50+
from distributed import as_completed
51+
52+
return as_completed(futures)
53+
54+
def submit( # type: ignore[override]
55+
self,
56+
__fn: Callable[_P, _T],
57+
*args: _P.args,
58+
**kwargs: _P.kwargs,
59+
) -> "Future[_T]":
60+
if "__cfut_options" in kwargs:
61+
output_pickle_path = cast(CFutDict, kwargs["__cfut_options"])[
62+
"output_pickle_path"
63+
]
64+
del kwargs["__cfut_options"]
65+
66+
__fn = partial(
67+
MultiprocessingExecutor._execute_and_persist_function,
68+
output_pickle_path,
69+
__fn,
70+
)
71+
fut = self.client.submit(partial(__fn, *args, **kwargs))
72+
73+
enrich_future_with_uncaught_warning(fut)
74+
return fut
75+
76+
def map_unordered(self, fn: Callable[[_S], _T], args: Iterable[_S]) -> Iterator[_T]:
77+
futs: List["Future[_T]"] = self.map_to_futures(fn, args)
78+
79+
# Return a separate generator to avoid that map_unordered
80+
# is executed lazily (otherwise, jobs would be submitted
81+
# lazily, as well).
82+
def result_generator() -> Iterator:
83+
for fut in self.as_completed(futs):
84+
yield fut.result()
85+
86+
return result_generator()
87+
88+
def map_to_futures(
89+
self,
90+
fn: Callable[[_S], _T],
91+
args: Iterable[_S], # TODO change: allow more than one arg per call
92+
output_pickle_path_getter: Optional[Callable[[_S], os.PathLike]] = None,
93+
) -> List["Future[_T]"]:
94+
if output_pickle_path_getter is not None:
95+
futs = [
96+
self.submit( # type: ignore[call-arg]
97+
fn,
98+
arg,
99+
__cfut_options={
100+
"output_pickle_path": output_pickle_path_getter(arg)
101+
},
102+
)
103+
for arg in args
104+
]
105+
else:
106+
futs = [self.submit(fn, arg) for arg in args]
107+
108+
return futs
109+
110+
def map( # type: ignore[override]
111+
self,
112+
fn: Callable[[_S], _T],
113+
iterables: Iterable[_S],
114+
timeout: Optional[float] = None,
115+
chunksize: Optional[int] = None,
116+
) -> Iterator[_T]:
117+
if chunksize is None:
118+
chunksize = 1
119+
return super().map(fn, iterables, timeout=timeout, chunksize=chunksize)
120+
121+
def forward_log(self, fut: "Future[_T]") -> _T:
122+
return fut.result()
123+
124+
def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
125+
if wait:
126+
self.client.close(timeout=60 * 60 * 24)
127+
else:
128+
self.client.close()

cluster_tools/cluster_tools/executors/multiprocessing_.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
List,
1717
Optional,
1818
Tuple,
19+
TypedDict,
1920
TypeVar,
2021
cast,
2122
)
2223

23-
from typing_extensions import ParamSpec, TypedDict
24+
from typing_extensions import ParamSpec
2425

2526
from cluster_tools._utils import pickling
2627
from cluster_tools._utils.multiprocessing_logging_handler import (
@@ -85,6 +86,10 @@ def __init__(
8586
else:
8687
self._mp_logging_handler_pool = _MultiprocessingLoggingHandlerPool()
8788

89+
@classmethod
90+
def as_completed(cls, futs: List["Future[_T]"]) -> Iterator["Future[_T]"]:
91+
return futures.as_completed(futs)
92+
8893
def submit( # type: ignore[override]
8994
self,
9095
__fn: Callable[_P, _T],
@@ -143,6 +148,17 @@ def submit( # type: ignore[override]
143148
enrich_future_with_uncaught_warning(fut)
144149
return fut
145150

151+
def map( # type: ignore[override]
152+
self,
153+
fn: Callable[[_S], _T],
154+
iterables: Iterable[_S],
155+
timeout: Optional[float] = None,
156+
chunksize: Optional[int] = None,
157+
) -> Iterator[_T]:
158+
if chunksize is None:
159+
chunksize = 1
160+
return super().map(fn, iterables, timeout=timeout, chunksize=chunksize)
161+
146162
def _submit_via_io(
147163
self,
148164
__fn: Callable[_P, _T],
Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11
from concurrent.futures import Future
2-
from typing import Any, Callable, TypeVar
2+
from functools import partial
3+
from typing import Callable, TypeVar
4+
5+
from typing_extensions import ParamSpec
36

47
from cluster_tools._utils import pickling
58
from cluster_tools.executors.multiprocessing_ import MultiprocessingExecutor
69

710
# The module name includes a _-suffix to avoid name clashes with the standard library pickle module.
811

912
_T = TypeVar("_T")
13+
_P = ParamSpec("_P")
14+
_S = TypeVar("_S")
1015

1116

12-
def _pickle_identity(obj: _T) -> _T:
17+
def _pickle_identity(obj: _S) -> _S:
1318
return pickling.loads(pickling.dumps(obj))
1419

1520

16-
def _pickle_identity_executor(fn: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
21+
def _pickle_identity_executor(
22+
fn: Callable[_P, _T],
23+
*args: _P.args,
24+
**kwargs: _P.kwargs,
25+
) -> _T:
1726
result = fn(*args, **kwargs)
1827
return _pickle_identity(result)
1928

@@ -27,13 +36,16 @@ class PickleExecutor(MultiprocessingExecutor):
2736

2837
def submit( # type: ignore[override]
2938
self,
30-
fn: Callable[..., _T],
31-
*args: Any,
32-
**kwargs: Any,
39+
fn: Callable[_P, _T],
40+
/,
41+
*args: _P.args,
42+
**kwargs: _P.kwargs,
3343
) -> "Future[_T]":
3444
(fn_pickled, args_pickled, kwargs_pickled) = _pickle_identity(
3545
(fn, args, kwargs)
3646
)
3747
return super().submit(
38-
_pickle_identity_executor, fn_pickled, *args_pickled, **kwargs_pickled
48+
partial(_pickle_identity_executor, fn_pickled),
49+
*args_pickled,
50+
**kwargs_pickled,
3951
)

0 commit comments

Comments
 (0)