|
1 | 1 | import os |
| 2 | +import re |
| 3 | +import signal |
| 4 | +import traceback |
2 | 5 | from concurrent import futures |
3 | 6 | from concurrent.futures import Future |
4 | 7 | from functools import partial |
| 8 | +from multiprocessing import Queue, get_context |
5 | 9 | from typing import ( |
6 | 10 | TYPE_CHECKING, |
7 | 11 | Any, |
|
11 | 15 | Iterator, |
12 | 16 | List, |
13 | 17 | Optional, |
| 18 | + Set, |
14 | 19 | TypeVar, |
15 | 20 | cast, |
16 | 21 | ) |
| 22 | +from weakref import ReferenceType, ref |
17 | 23 |
|
18 | 24 | from typing_extensions import ParamSpec |
19 | 25 |
|
|
28 | 34 | _S = TypeVar("_S") |
29 | 35 |
|
30 | 36 |
|
| 37 | +def _run_in_nanny( |
| 38 | + queue: Queue, __fn: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs |
| 39 | +) -> None: |
| 40 | + try: |
| 41 | + __env = cast(Dict[str, str], kwargs.pop("__env")) |
| 42 | + for key, value in __env.items(): |
| 43 | + os.environ[key] = value |
| 44 | + |
| 45 | + ret = __fn(*args, **kwargs) |
| 46 | + queue.put({"value": ret}) |
| 47 | + except Exception as exc: |
| 48 | + queue.put({"exception": exc}) |
| 49 | + |
| 50 | + |
| 51 | +def _run_with_nanny( |
| 52 | + __fn: Callable[_P, _T], |
| 53 | + *args: _P.args, |
| 54 | + **kwargs: _P.kwargs, |
| 55 | +) -> _T: |
| 56 | + mp_context = get_context("spawn") |
| 57 | + q = mp_context.Queue() |
| 58 | + p = mp_context.Process(target=_run_in_nanny, args=(q, __fn) + args, kwargs=kwargs) |
| 59 | + p.start() |
| 60 | + p.join() |
| 61 | + ret = q.get(timeout=0.1) |
| 62 | + if "exception" in ret: |
| 63 | + raise ret["exception"] |
| 64 | + else: |
| 65 | + return ret["value"] |
| 66 | + |
| 67 | + |
| 68 | +def _parse_mem(size: str) -> int: |
| 69 | + units = {"": 1, "K": 2**10, "M": 2**20, "G": 2**30, "T": 2**40} |
| 70 | + m = re.match(r"^([\d\.]+)\s*([kmgtKMGT]{0,1})$", str(size).strip()) |
| 71 | + assert m is not None, f"Could not parse {size}" |
| 72 | + number, unit = float(m.group(1)), m.group(2).upper() |
| 73 | + assert unit in units |
| 74 | + return int(number * units[unit]) |
| 75 | + |
| 76 | + |
| 77 | +def _handle_kill_through_weakref( |
| 78 | + executor_ref: "ReferenceType[DaskExecutor]", |
| 79 | + existing_sigint_handler: Any, |
| 80 | + signum: Optional[int], |
| 81 | + frame: Any, |
| 82 | +) -> None: |
| 83 | + executor = executor_ref() |
| 84 | + if executor is None: |
| 85 | + return |
| 86 | + executor.handle_kill(existing_sigint_handler, signum, frame) |
| 87 | + |
| 88 | + |
31 | 89 | class DaskExecutor(futures.Executor): |
| 90 | + """ |
| 91 | + The `DaskExecutor` allows running workloads on a dask cluster. |
| 92 | +
|
| 93 | + The executor can be constructed with an existing dask `Client` or |
| 94 | + from a declarative configuration. The address of the dask scheduler |
| 95 | + can be part of the configuration or supplied via the environment |
| 96 | + variable `DASK_ADDRESS`. |
| 97 | +
|
| 98 | + Resource-based scheduling is supported. By default, `mem` and |
| 99 | + `cpus-per-task` are recognized. To make use of them, the dask workers |
| 100 | + should be started with: |
| 101 | + `python -m dask worker --no-nanny --nthreads 6 tcp://... --resources "mem=1073741824 cpus=8"` |
| 102 | + """ |
| 103 | + |
32 | 104 | client: "Client" |
| 105 | + pending_futures: Set[Future] |
| 106 | + job_resources: Optional[Dict[str, Any]] |
| 107 | + is_shutting_down: bool = False |
33 | 108 |
|
34 | 109 | def __init__( |
35 | | - self, |
36 | | - client: "Client", |
| 110 | + self, client: "Client", job_resources: Optional[Dict[str, Any]] = None |
37 | 111 | ) -> None: |
38 | 112 | self.client = client |
| 113 | + self.pending_futures = set() |
| 114 | + self.job_resources = job_resources |
| 115 | + |
| 116 | + if self.job_resources is not None: |
| 117 | + # `mem` needs to be a number of bytes for dask, so we parse it here |
| 118 | + if "mem" in self.job_resources: |
| 119 | + self.job_resources["mem"] = _parse_mem(self.job_resources["mem"]) |
| 120 | + if "cpus-per-task" in self.job_resources: |
| 121 | + self.job_resources["cpus"] = int( |
| 122 | + self.job_resources.pop("cpus-per-task") |
| 123 | + ) |
| 124 | + |
| 125 | + # Clean up if a SIGINT signal is received. However, do not interfere with |
| 126 | + # the existing signal handler of the process or with the shutdown of the |
| 127 | + # main process, which sends SIGTERM signals to terminate all |
| 128 | + # child processes. |
| 129 | + existing_sigint_handler = signal.getsignal(signal.SIGINT) |
| 130 | + signal.signal( |
| 131 | + signal.SIGINT, |
| 132 | + partial(_handle_kill_through_weakref, ref(self), existing_sigint_handler), |
| 133 | + ) |
39 | 134 |
|
40 | 135 | @classmethod |
41 | 136 | def from_config( |
42 | 137 | cls, |
43 | | - job_resources: Dict[str, Any], |
| 138 | + job_resources: Dict[str, str], |
| 139 | + **_kwargs: Any, |
44 | 140 | ) -> "DaskExecutor": |
45 | 141 | from distributed import Client |
46 | 142 |
|
47 | | - return cls(Client(**job_resources)) |
| 143 | + job_resources = job_resources.copy() |
| 144 | + address = job_resources.pop("address", None) |
| 145 | + if address is None: |
| 146 | + address = os.environ.get("DASK_ADDRESS", None) |
| 147 | + |
| 148 | + client = Client(address=address) |
| 149 | + return cls(client, job_resources=job_resources) |
48 | 150 |
|
49 | 151 | @classmethod |
50 | 152 | def as_completed(cls, futures: List["Future[_T]"]) -> Iterator["Future[_T]"]: |
@@ -72,7 +174,20 @@ def submit( # type: ignore[override] |
72 | 174 | __fn, |
73 | 175 | ), |
74 | 176 | ) |
75 | | - fut = self.client.submit(partial(__fn, *args, **kwargs)) |
| 177 | + |
| 178 | + kwargs["__env"] = os.environ.copy() |
| 179 | + |
| 180 | + # We run the submitted function in a separate process so that it does |
| 181 | + # not hold the GIL for too long, because dask workers need to be able |
| 182 | + # to communicate with the scheduler regularly. |
| 183 | + __fn = partial(_run_with_nanny, __fn) |
| 184 | + |
| 185 | + fut = self.client.submit( |
| 186 | + partial(__fn, *args, **kwargs), pure=False, resources=self.job_resources |
| 187 | + ) |
| 188 | + |
| 189 | + self.pending_futures.add(fut) |
| 190 | + fut.add_done_callback(self.pending_futures.remove) |
76 | 191 |
|
77 | 192 | enrich_future_with_uncaught_warning(fut) |
78 | 193 | return fut |
@@ -125,8 +240,32 @@ def map( # type: ignore[override] |
125 | 240 | def forward_log(self, fut: "Future[_T]") -> _T: |
126 | 241 | return fut.result() |
127 | 242 |
|
| 243 | + def handle_kill( |
| 244 | + self, |
| 245 | + existing_sigint_handler: Any, |
| 246 | + signum: Optional[int], |
| 247 | + frame: Any, |
| 248 | + ) -> None: |
| 249 | + if self.is_shutting_down: |
| 250 | + return |
| 251 | + |
| 252 | + self.is_shutting_down = True |
| 253 | + |
| 254 | + self.client.cancel(list(self.pending_futures)) |
| 255 | + |
| 256 | + if ( |
| 257 | + existing_sigint_handler # pylint: disable=comparison-with-callable |
| 258 | + != signal.default_int_handler |
| 259 | + and callable(existing_sigint_handler) # Could also be signal.SIG_IGN |
| 260 | + ): |
| 261 | + existing_sigint_handler(signum, frame) |
| 262 | + |
128 | 263 | def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None: |
| 264 | + print(f"{wait=} {cancel_futures=}") |
| 265 | + traceback.print_stack() |
129 | 266 | if wait: |
130 | | - self.client.close(timeout=60 * 60 * 24) |
| 267 | + for fut in list(self.pending_futures): |
| 268 | + fut.result() |
| 269 | + self.client.close(timeout=60 * 60) # 1 hour |
131 | 270 | else: |
132 | 271 | self.client.close() |
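
For orientation (not part of the commit), here is a minimal usage sketch of the executor as changed above. The scheduler address, the resource values, and the import location of `DaskExecutor` are assumptions; the config keys (`address`, `mem`, `cpus-per-task`) mirror `from_config` and `__init__` in the diff, and the dask workers are assumed to have been started with matching `--resources` as described in the docstring.

```python
import os

# Assumption: a dask scheduler is reachable at this address and its workers
# were started with e.g. `--resources "mem=4294967296 cpus=8"`.
os.environ["DASK_ADDRESS"] = "tcp://127.0.0.1:8786"

# Hypothetical import path; adjust to wherever DaskExecutor actually lives.
# from cluster_tools import DaskExecutor

executor = DaskExecutor.from_config(
    {"mem": "2G", "cpus-per-task": "2"}  # -> resources={"mem": 2**31, "cpus": 2}
)

# The submitted function runs in a spawned child process on the worker
# (_run_with_nanny), with the caller's environment re-applied via "__env".
fut = executor.submit(sum, [1, 2, 3])
print(fut.result())  # 6

executor.shutdown(wait=True)  # waits for pending futures, then closes the client
```

As a side note on `_parse_mem`: `"2G"` parses to `2 * 2**30 = 2147483648` and `"1.5G"` to `int(1.5 * 2**30) = 1610612736`, i.e. plain byte counts, which is the same unit the worker `--resources` declaration uses.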