Skip to content

Commit 336e2b6

Browse files
authored
More dask features (#959)
* upgrades to mypy 1.6 * pr feedback * changelog * adds sigint, mem and cpus support * changelog * weakref handle_kill * test dask in CI * typing * ci * ci * fix tests
1 parent 2662300 commit 336e2b6

File tree

5 files changed

+183
-14
lines changed

5 files changed

+183
-14
lines changed

.github/workflows/ci.yml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
strategy:
3838
max-parallel: 4
3939
matrix:
40-
executors: [multiprocessing, slurm, kubernetes]
40+
executors: [multiprocessing, slurm, kubernetes, dask]
4141
python-version: ["3.11", "3.10", "3.9", "3.8"]
4242
defaults:
4343
run:
@@ -88,7 +88,7 @@ jobs:
8888
./kind load docker-image scalableminds/cluster-tools:latest
8989
9090
- name: Install dependencies (without docker)
91-
if: ${{ matrix.executors == 'multiprocessing' || matrix.executors == 'kubernetes' }}
91+
if: ${{ matrix.executors != 'slurm' }}
9292
run: |
9393
pip install -r ../requirements.txt
9494
poetry install
@@ -130,6 +130,12 @@ jobs:
130130
cd tests
131131
PYTEST_EXECUTORS=kubernetes poetry run python -m pytest -sv test_all.py test_kubernetes.py
132132
133+
- name: Run dask tests
134+
if: ${{ matrix.executors == 'dask' && matrix.python-version != '3.8' }}
135+
run: |
136+
cd tests
137+
PYTEST_EXECUTORS=dask poetry run python -m pytest -sv test_all.py
138+
133139
webknossos_linux:
134140
needs: changes
135141
if: |

cluster_tools/Changelog.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@ For upgrade instructions, please check the respective *Breaking Changes* section
1212
### Breaking Changes
1313

1414
### Added
15+
- Added SIGINT handling to `DaskExecutor`. [#959](https://github.com/scalableminds/webknossos-libs/pull/959)
16+
- Added support for resources (e.g. mem, cpus) to `DaskExecutor`. [#959](https://github.com/scalableminds/webknossos-libs/pull/959)
17+
- The cluster address for the `DaskExecutor` can be configured via the `DASK_ADDRESS` env var. [#959](https://github.com/scalableminds/webknossos-libs/pull/959)
1518

1619
### Changed
20+
- Tasks using the `DaskExecutor` are now run in their own process. This is required so that the dask worker does not hold the GIL for long periods and can keep communicating with the scheduler. Environment variables are propagated to the task processes. [#959](https://github.com/scalableminds/webknossos-libs/pull/959)
1721

1822
### Fixed
1923

cluster_tools/cluster_tools/executors/dask.py

Lines changed: 145 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import os
2+
import re
3+
import signal
4+
import traceback
25
from concurrent import futures
36
from concurrent.futures import Future
47
from functools import partial
8+
from multiprocessing import Queue, get_context
59
from typing import (
610
TYPE_CHECKING,
711
Any,
@@ -11,9 +15,11 @@
1115
Iterator,
1216
List,
1317
Optional,
18+
Set,
1419
TypeVar,
1520
cast,
1621
)
22+
from weakref import ReferenceType, ref
1723

1824
from typing_extensions import ParamSpec
1925

@@ -28,23 +34,119 @@
2834
_S = TypeVar("_S")
2935

3036

37+
def _run_in_nanny(
38+
queue: Queue, __fn: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
39+
) -> None:
40+
try:
41+
__env = cast(Dict[str, str], kwargs.pop("__env"))
42+
for key, value in __env.items():
43+
os.environ[key] = value
44+
45+
ret = __fn(*args, **kwargs)
46+
queue.put({"value": ret})
47+
except Exception as exc:
48+
queue.put({"exception": exc})
49+
50+
51+
def _run_with_nanny(
    __fn: Callable[_P, _T],
    *args: _P.args,
    **kwargs: _P.kwargs,
) -> _T:
    """Run ``__fn`` in a freshly spawned child process and return its result.

    The child (see ``_run_in_nanny``) reports back through a multiprocessing
    queue. Exceptions raised by ``__fn`` in the child are re-raised here.

    Raises:
        RuntimeError: if the child process exits without reporting a result
            (e.g. it crashed hard or was killed).
    """
    from queue import Empty  # local import keeps the module's import surface unchanged

    mp_context = get_context("spawn")
    q = mp_context.Queue()
    p = mp_context.Process(target=_run_in_nanny, args=(q, __fn) + args, kwargs=kwargs)
    p.start()
    # Read the result *before* joining. Joining first can deadlock: a child
    # flushing a large result blocks on the queue's underlying pipe until the
    # parent drains it (see "Joining processes that use queues" in the
    # multiprocessing docs).
    while True:
        try:
            ret = q.get(timeout=0.1)
            break
        except Empty:
            if not p.is_alive():
                p.join()
                # Final drain: the result may have arrived between the
                # timed-out get and the liveness check.
                try:
                    ret = q.get(timeout=0.1)
                except Empty:
                    raise RuntimeError(
                        f"Child process exited with code {p.exitcode} "
                        "before returning a result"
                    ) from None
                break
    p.join()
    if "exception" in ret:
        raise ret["exception"]
    else:
        return ret["value"]
66+
67+
68+
def _parse_mem(size: str) -> int:
69+
units = {"": 1, "K": 2**10, "M": 2**20, "G": 2**30, "T": 2**40}
70+
m = re.match(r"^([\d\.]+)\s*([kmgtKMGT]{0,1})$", str(size).strip())
71+
assert m is not None, f"Could not parse {size}"
72+
number, unit = float(m.group(1)), m.group(2).upper()
73+
assert unit in units
74+
return int(number * units[unit])
75+
76+
77+
def _handle_kill_through_weakref(
78+
executor_ref: "ReferenceType[DaskExecutor]",
79+
existing_sigint_handler: Any,
80+
signum: Optional[int],
81+
frame: Any,
82+
) -> None:
83+
executor = executor_ref()
84+
if executor is None:
85+
return
86+
executor.handle_kill(existing_sigint_handler, signum, frame)
87+
88+
3189
class DaskExecutor(futures.Executor):
    """Executor that runs workloads on a dask cluster.

    It can be constructed from an existing dask `Client` or from a
    declarative configuration. The address of the dask scheduler can be part
    of the configuration or supplied via the `DASK_ADDRESS` environment
    variable.

    There is support for resource-based scheduling. As default, `mem` and
    `cpus-per-task` are supported. To make use of them, the dask workers
    should be started with:
    `python -m dask worker --no-nanny --nthreads 6 tcp://... --resources "mem=1073741824 cpus=8"`
    """

    # Dask client used to submit work to the cluster.
    client: "Client"
    # Futures that were submitted but have not completed yet.
    pending_futures: Set[Future]
    # Optional dask resource requirements attached to each submission.
    job_resources: Optional[Dict[str, Any]]
    # Guards the SIGINT cleanup against running more than once.
    is_shutting_down = False
33108

34109
def __init__(
    self, client: "Client", job_resources: Optional[Dict[str, Any]] = None
) -> None:
    """Wrap an existing dask ``client``.

    ``job_resources`` may contain dask worker resources. ``mem`` (a
    human-readable size string) and ``cpus-per-task`` are normalized to the
    numeric ``mem``/``cpus`` resources that dask expects.
    """
    self.client = client
    self.pending_futures = set()
    self.job_resources = job_resources

    if self.job_resources is not None:
        # dask expects `mem` as a plain number, so parse strings like "2G".
        if "mem" in self.job_resources:
            self.job_resources["mem"] = _parse_mem(self.job_resources["mem"])
        # Translate the slurm-style key into the `cpus` resource name.
        if "cpus-per-task" in self.job_resources:
            self.job_resources["cpus"] = int(self.job_resources.pop("cpus-per-task"))

    # Clean up if a SIGINT signal is received. However, do not interfere with
    # the existing signal handler of the process or the shutdown of the main
    # process which sends SIGTERM signals to terminate all child processes.
    # Registering through a weak reference keeps the handler from pinning
    # this executor in memory.
    existing_sigint_handler = signal.getsignal(signal.SIGINT)
    signal.signal(
        signal.SIGINT,
        partial(_handle_kill_through_weakref, ref(self), existing_sigint_handler),
    )
39134

40135
@classmethod
def from_config(
    cls,
    job_resources: Dict[str, str],
    **_kwargs: Any,
) -> "DaskExecutor":
    """Build a ``DaskExecutor`` from a declarative configuration.

    The scheduler address is taken from ``job_resources["address"]`` when
    present, otherwise from the ``DASK_ADDRESS`` environment variable. All
    remaining entries are forwarded as job resources.
    """
    from distributed import Client

    # Work on a copy so the caller's dict is left untouched.
    resources = dict(job_resources)
    address = resources.pop("address", None)
    if address is None:
        address = os.environ.get("DASK_ADDRESS")

    return cls(Client(address=address), job_resources=resources)
48150

49151
@classmethod
50152
def as_completed(cls, futures: List["Future[_T]"]) -> Iterator["Future[_T]"]:
@@ -72,7 +174,20 @@ def submit( # type: ignore[override]
72174
__fn,
73175
),
74176
)
75-
fut = self.client.submit(partial(__fn, *args, **kwargs))
177+
178+
kwargs["__env"] = os.environ.copy()
179+
180+
# We run the functions in dask as a separate process to not hold the
181+
# GIL for too long, because dask workers need to be able to communicate
182+
# with the scheduler regularly.
183+
__fn = partial(_run_with_nanny, __fn)
184+
185+
fut = self.client.submit(
186+
partial(__fn, *args, **kwargs), pure=False, resources=self.job_resources
187+
)
188+
189+
self.pending_futures.add(fut)
190+
fut.add_done_callback(self.pending_futures.remove)
76191

77192
enrich_future_with_uncaught_warning(fut)
78193
return fut
@@ -125,8 +240,32 @@ def map( # type: ignore[override]
125240
def forward_log(self, fut: "Future[_T]") -> _T:
    """Block until ``fut`` is done and return its result.

    Any exception raised by the task is re-raised here; no additional log
    handling is performed for dask futures.
    """
    return fut.result()
127242

243+
def handle_kill(
    self,
    existing_sigint_handler: Any,
    signum: Optional[int],
    frame: Any,
) -> None:
    """Cancel all pending futures on SIGINT, then chain to the previous handler.

    Re-entrant invocations are ignored via the ``is_shutting_down`` flag.
    """
    if self.is_shutting_down:
        return

    self.is_shutting_down = True

    self.client.cancel(list(self.pending_futures))

    # Chain to whatever handler was installed before this executor, unless it
    # is the default handler or a non-callable such as signal.SIG_IGN.
    should_chain = (
        existing_sigint_handler != signal.default_int_handler  # pylint: disable=comparison-with-callable
        and callable(existing_sigint_handler)
    )
    if should_chain:
        existing_sigint_handler(signum, frame)
262+
128263
def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
    """Shut down the executor and close the dask client.

    Args:
        wait: If true, block until all pending futures have finished before
            closing the client.
        cancel_futures: Accepted for signature compatibility with
            ``concurrent.futures.Executor.shutdown``; pending futures are not
            explicitly cancelled here.
    """
    # NOTE: removed leftover debug output (`print(f"{wait=} ...")` and
    # `traceback.print_stack()`) that had been committed by accident.
    if wait:
        # Iterate over a copy: done-callbacks remove entries from
        # `pending_futures` while we drain.
        for fut in list(self.pending_futures):
            fut.result()
        self.client.close(timeout=60 * 60)  # 1 hour
    else:
        self.client.close()

cluster_tools/cluster_tools/schedulers/cluster_executor.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Union,
2424
cast,
2525
)
26+
from weakref import ReferenceType, ref
2627

2728
from typing_extensions import ParamSpec
2829

@@ -45,6 +46,18 @@
4546
_S = TypeVar("_S")
4647

4748

49+
def _handle_kill_through_weakref(
50+
executor_ref: "ReferenceType[ClusterExecutor]",
51+
existing_sigint_handler: Any,
52+
signum: Optional[int],
53+
frame: Any,
54+
) -> None:
55+
executor = executor_ref()
56+
if executor is None:
57+
return
58+
executor.handle_kill(existing_sigint_handler, signum, frame)
59+
60+
4861
def join_messages(strings: List[str]) -> str:
    """Concatenate the non-blank entries of ``strings``, separated by spaces.

    Each entry is stripped first; entries that are empty after stripping are
    dropped entirely.
    """
    stripped = (s.strip() for s in strings)
    return " ".join(part for part in stripped if part)
5063

@@ -130,7 +143,10 @@ def __init__(
130143
# shutdown of the main process which sends SIGTERM signals to terminate all
131144
# child processes.
132145
existing_sigint_handler = signal.getsignal(signal.SIGINT)
133-
signal.signal(signal.SIGINT, partial(self.handle_kill, existing_sigint_handler))
146+
signal.signal(
147+
signal.SIGINT,
148+
partial(_handle_kill_through_weakref, ref(self), existing_sigint_handler),
149+
)
134150

135151
self.meta_data = {}
136152
assert not (

cluster_tools/tests/test_all.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from distributed import LocalCluster
1515

1616
import cluster_tools
17-
from cluster_tools.executors.dask import DaskExecutor
1817

1918

2019
# "Worker" functions.
@@ -79,10 +78,14 @@ def get_executors(with_debug_sequential: bool = False) -> List[cluster_tools.Exe
7978
executors.append(cluster_tools.get_executor("sequential"))
8079
if "dask" in executor_keys:
8180
if not _dask_cluster:
82-
from distributed import LocalCluster
81+
from distributed import LocalCluster, Worker
8382

84-
_dask_cluster = LocalCluster()
85-
executors.append(cluster_tools.get_executor("dask", address=_dask_cluster))
83+
_dask_cluster = LocalCluster(
84+
worker_class=Worker, resources={"mem": 20e9, "cpus": 4}, nthreads=6
85+
)
86+
executors.append(
87+
cluster_tools.get_executor("dask", job_resources={"address": _dask_cluster})
88+
)
8689
if "test_pickling" in executor_keys:
8790
executors.append(cluster_tools.get_executor("test_pickling"))
8891
if "pbs" in executor_keys:
@@ -328,7 +331,8 @@ def run_map(executor: cluster_tools.Executor) -> None:
328331
assert list(result) == [4, 9, 16]
329332

330333
for exc in get_executors():
331-
run_map(exc)
334+
if not isinstance(exc, cluster_tools.DaskExecutor):
335+
run_map(exc)
332336

333337

334338
def test_executor_args() -> None:

0 commit comments

Comments
 (0)