
Commit e99df50

Make slurm job canceling more robust (#1317)
* Make slurm job canceling more robust, by killing all remaining jobs after a short wait time
* [Debug] Log times for scancel handling
* Fix signal handling for cluster executors which stopped working if multiple executors were instantiated in the same process before.
* First stop the file wait thread before canceling the jobs to avoid checking the job state after canceling which will log lots of errors for jobs that were canceled before running, because they don't turn up in the slurm accounting
* Cleanup
* Properly mock (and restore) env variables
* Update changelog
* Add f-string
* Exempt completing jobs when checking whether cancellation worked fast enough
* Test whether monkeypatching works
* add time logs to find out where job cancellation time is spent
* Fix setting sigterm_wait_in_s to 0 during tests
* Improve slurm cancellation test to assert that original sigint handler was called
* Format
* Fix typing
* Fix nonlocal variable access
* Add test for signal handling regression when multiple executors are instantiated
* Garbage collect after first executor ran to provoke regression
* Apply some PR feedback
* Delete executor1 in test to provoke bug
* Add pytest-timeout to avoid hanging tests and wait for futures in test
* Cleanup and fix hanging tests
* Add comment
* Remove pytest-timeout dependency again
* Restore uv.lock
* When shutting down cluster executor and wait if False, treat as if executor was killed
* Decrease SIGTERM_WAIT_IN_S for new test and assert that shutdown hooks are cleaned up
* Format
* Remove dask executor from cluster tools
* Linting
* Update changelog
* Actually deregister shutdown hook and use with statements to ensure executor shutdown
* Also update webknossos uv.lock
* Add pytest-timeout
* Unify the two variables tracking executor shutdown
* Fix kubernetes dependency
* Assert that no jobs run before the tests
* Add debug logging
* Fix signal handling test
* Revert "Update changelog" (reverts commit 31c472a)
* Revert "Remove dask executor from cluster tools" (reverts commit 9866368)
* Revert "Linting" (reverts commit 4ee81c0)
1 parent a887e42 commit e99df50
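
At its core, the commit replaces the per-instance SIGINT handler with a single class-level handler that runs a list of registered shutdown hooks and then chains to whatever handler was installed before it. Below is a minimal, self-contained sketch of that pattern; the `ShutdownHookRegistry` name and the demo hook are illustrative and not part of the diff.

```python
import signal
from collections.abc import Callable
from functools import partial
from types import FrameType


class ShutdownHookRegistry:
    """Toy illustration of the commit's approach; not the actual ClusterExecutor API."""

    # Class-level state: hooks are shared by all instances and the signal
    # handler is installed only once per process.
    _hooks: list[Callable[[], None]] = []
    _installed: bool = False

    @classmethod
    def register(cls, hook: Callable[[], None]) -> None:
        cls._hooks.append(hook)
        cls._ensure_handler_installed()

    @classmethod
    def _ensure_handler_installed(cls) -> None:
        if cls._installed:
            return
        # Remember whatever handler was active before and chain to it afterwards,
        # so the host application still shuts down the way it expects.
        previous = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGINT, partial(cls._handle, previous))
        cls._installed = True

    @classmethod
    def _handle(cls, previous: object, signum: int, frame: FrameType | None) -> None:
        # Run all registered cleanup hooks first.
        for hook in cls._hooks:
            hook()
        # Only chain to a real callable, not to the SIG_DFL/SIG_IGN markers.
        if callable(previous) and previous not in (signal.SIG_DFL, signal.SIG_IGN):
            previous(signum, frame)


if __name__ == "__main__":
    ShutdownHookRegistry.register(lambda: print("cancel remaining cluster jobs"))
    try:
        # Deliver SIGINT to ourselves: the hook runs first, then the previous
        # (default) handler raises KeyboardInterrupt.
        signal.raise_signal(signal.SIGINT)
    except KeyboardInterrupt:
        print("previous SIGINT handler was chained as expected")
```

Keeping the hook list and the installed flag on the class rather than on each instance is what prevents a second executor from clobbering the first handler, which is the regression described in the changelog below.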

File tree: 9 files changed, +615 −459 lines changed


cluster_tools/Changelog.md

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@ For upgrade instructions, please check the respective *Breaking Changes* section
 ### Changed

 ### Fixed
+- Fixed that sometimes not all slurm jobs were canceled when an executor was killed. [#1317](https://github.com/scalableminds/webknossos-libs/pull/1317)
+- Fixed that when multiple cluster executors were instantiated in the same process, the original SIGINT handler sometimes was no longer called, leading to the main application not shutting down correctly after a SIGINT signal. [#1317](https://github.com/scalableminds/webknossos-libs/pull/1317)


 ## [2.4.4](https://github.com/scalableminds/webknossos-libs/releases/tag/v2.4.4) - 2025-07-14

cluster_tools/cluster_tools/schedulers/cluster_executor.py

Lines changed: 92 additions & 36 deletions
@@ -9,13 +9,13 @@
 from concurrent import futures
 from concurrent.futures import Future
 from functools import partial
+from types import FrameType, TracebackType
 from typing import (
     Any,
     Literal,
     TypeVar,
     cast,
 )
-from weakref import ReferenceType, ref

 from typing_extensions import ParamSpec

@@ -38,18 +38,6 @@
 _S = TypeVar("_S")


-def _handle_kill_through_weakref(
-    executor_ref: "ReferenceType[ClusterExecutor]",
-    existing_sigint_handler: Any,
-    signum: int | None,
-    frame: Any,
-) -> None:
-    executor = executor_ref()
-    if executor is None:
-        return
-    executor.handle_kill(existing_sigint_handler, signum, frame)
-
-
 def join_messages(strings: list[str]) -> str:
     return " ".join(x.strip() for x in strings if x.strip())

@@ -79,6 +67,9 @@ class RemoteTimeLimitException(RemoteResourceLimitException):
 class ClusterExecutor(futures.Executor):
     """Futures executor for executing jobs on a cluster."""

+    _shutdown_hooks: list[Callable[[], None]] = []
+    _installed_signal_handler: bool = False
+
     def __init__(
         self,
         debug: bool = False,
@@ -103,7 +94,6 @@ def __init__(
         self.job_resources = job_resources
         self.additional_setup_lines = additional_setup_lines or []
         self.job_name = job_name
-        self.was_requested_to_shutdown = False
         self.cfut_dir = (
             cfut_dir if cfut_dir is not None else os.getenv("CFUT_DIR", ".cfut")
         )
@@ -130,15 +120,7 @@ def __init__(

         os.makedirs(self.cfut_dir, exist_ok=True)

-        # Clean up if a SIGINT signal is received. However, do not interfere with the
-        # existing signal handler of the process or the
-        # shutdown of the main process which sends SIGTERM signals to terminate all
-        # child processes.
-        existing_sigint_handler = signal.getsignal(signal.SIGINT)
-        signal.signal(
-            signal.SIGINT,
-            partial(_handle_kill_through_weakref, ref(self), existing_sigint_handler),
-        )
+        self._register_shutdown_hook(self.handle_kill)

         self.metadata = {}
         assert not ("logging_config" in kwargs and "logging_setup_fn" in kwargs), (
@@ -158,26 +140,81 @@ def as_completed(cls, futs: list[Future[_T]]) -> Iterator[Future[_T]]:
     def executor_key(cls) -> str:
         pass

-    def handle_kill(
-        self, existing_sigint_handler: Any, signum: int | None, frame: Any
+    @classmethod
+    def _ensure_signal_handlers_are_installed(cls) -> None:
+        # Only overwrite the signal handler once
+        if cls._installed_signal_handler:
+            return
+
+        # Clean up if a SIGINT or SIGTERM signal is received. However, do not
+        # interfere with the existing signal handler of the process and execute
+        # it afterwards.
+        existing_sigint_handler = signal.getsignal(signal.SIGINT)
+        signal.signal(
+            signal.SIGINT,
+            partial(cls._handle_shutdown, existing_sigint_handler),
+        )
+        existing_sigterm_handler = signal.getsignal(signal.SIGTERM)
+        signal.signal(
+            signal.SIGTERM,
+            partial(cls._handle_shutdown, existing_sigterm_handler),
+        )
+
+        cls._installed_signal_handler = True
+
+    @classmethod
+    def _register_shutdown_hook(cls, hook: Callable[[], None]) -> None:
+        cls._shutdown_hooks.append(hook)
+        cls._ensure_signal_handlers_are_installed()
+
+    @classmethod
+    def _deregister_shutdown_hook(cls, hook: Callable[[], None]) -> None:
+        if hook in cls._shutdown_hooks:
+            cls._shutdown_hooks.remove(hook)
+        else:
+            logging.warning(
+                "Cannot deregister executors shutdown hook since it's not registered."
+            )
+
+    @classmethod
+    def _handle_shutdown(
+        cls,
+        existing_signal_handler: Callable[[int, FrameType | None], None] | int | None,
+        signum: int,
+        frame: Any,
     ) -> None:
+        logging.critical(
+            f"[{cls.__name__}] Caught signal {signal.Signals(signum).name}, running shutdown hooks"
+        )
+        try:
+            for hook in cls._shutdown_hooks:
+                hook()
+        except Exception as e:
+            print(f"Error during shutdown: {e}")
+
+        if (
+            callable(existing_signal_handler)
+            and existing_signal_handler
+            not in (
+                signal.SIG_DFL,  # For completeness sake (since it's not callable anyways). The system's default signal handler
+                signal.SIG_IGN,  # For completeness sake (since it's not callable anyways). The instruction to ignore a signal
+                signal.default_int_handler,  # Python's default SIGINT handler
+            )
+        ):
+            existing_signal_handler(signum, frame)
+
+    def handle_kill(self) -> None:
         if self.is_shutting_down:
             return

         self.is_shutting_down = True

-        self.inner_handle_kill(signum, frame)
         self.wait_thread.stop()
+        self.inner_handle_kill()
         self.clean_up()

-        if (
-            existing_sigint_handler != signal.default_int_handler
-            and callable(existing_sigint_handler)  # Could also be signal.SIG_IGN
-        ):
-            existing_sigint_handler(signum, frame)
-
     @abstractmethod
-    def inner_handle_kill(self, _signum: Any, _frame: Any) -> None:
+    def inner_handle_kill(self) -> None:
         pass

     @abstractmethod
@@ -363,9 +400,9 @@ def _completion(self, jobid: str, failed_early: bool) -> None:
         self._maybe_mark_logs_for_cleanup(jobid)

     def ensure_not_shutdown(self) -> None:
-        if self.was_requested_to_shutdown:
+        if self.is_shutting_down:
             raise RuntimeError(
-                "submit() was invoked on a ClusterExecutor instance even though shutdown() was executed for that instance."
+                "submit() was invoked on a ClusterExecutor instance even though shutdown() or handle_kill() was executed for that instance."
             )

     def create_enriched_future(self) -> Future:
@@ -591,17 +628,35 @@ def register_jobs(
             should_keep_output,
         )

+    # Overwrite the context manager __exit__ as it doesn't forward the information whether an exception was thrown or not otherwise
+    # which may lead to a deadlock if an exception is thrown within a cluster executor with statement, because self.jobs_empty_cond.wait()
+    # never succeeds.
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        _exc_val: BaseException | None,
+        _exc_tb: TracebackType | None,
+    ) -> Literal[False]:
+        # Don't wait if an exception was thrown
+        self.shutdown(wait=exc_type is None)
+        return False
+
     def shutdown(self, wait: bool = True, cancel_futures: bool = True) -> None:
         """Close the pool."""
+        if self.is_shutting_down:
+            return
+
+        self.is_shutting_down = True
         if not cancel_futures:
             logging.warning(
                 "The provided cancel_futures argument is ignored by ClusterExecutor."
             )
-        self.was_requested_to_shutdown = True
         if wait:
             with self.jobs_lock:
                 if self.jobs and self.wait_thread.is_alive():
                     self.jobs_empty_cond.wait()
+        else:
+            self.inner_handle_kill()

         self.wait_thread.stop()
         self.wait_thread.join()
@@ -617,6 +672,7 @@ def clean_up(self) -> None:
                     f"Could not delete file during clean up. Path: {file_to_clean_up} Exception: {exc}. Continuing..."
                 )
         self.files_to_clean_up = []
+        self._deregister_shutdown_hook(self.handle_kill)

     def map(  # type: ignore[override]
         self,
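
The `__exit__` override added above matters because the stock `futures.Executor.__exit__` always calls `shutdown(wait=True)`; if an exception escapes the `with` block while jobs are still registered, waiting on `jobs_empty_cond` can block forever. A toy sketch of the intended behavior, using a stand-in class rather than the real `ClusterExecutor`:

```python
from types import TracebackType
from typing import Literal


class GracefulExecutor:
    """Toy stand-in illustrating the __exit__/shutdown interplay; not the real executor."""

    def __init__(self) -> None:
        self.is_shutting_down = False

    def __enter__(self) -> "GracefulExecutor":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        _exc_val: BaseException | None,
        _exc_tb: TracebackType | None,
    ) -> Literal[False]:
        # If the with-block raised, do not wait for outstanding jobs;
        # cancel them instead, mirroring shutdown(wait=False).
        self.shutdown(wait=exc_type is None)
        return False  # never swallow the exception

    def shutdown(self, wait: bool = True) -> None:
        if self.is_shutting_down:
            return
        self.is_shutting_down = True
        if wait:
            print("waiting for remaining jobs to finish")
        else:
            print("canceling remaining jobs")


try:
    with GracefulExecutor():
        raise RuntimeError("job submission failed")
except RuntimeError:
    pass  # prints "canceling remaining jobs" instead of blocking
```

Returning `False` keeps the original exception propagating; only the waiting behavior changes.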

cluster_tools/cluster_tools/schedulers/kube.py

Lines changed: 1 addition & 5 deletions
@@ -113,11 +113,7 @@ def get_job_id_string(cls) -> str:
             return job_id
         return cls.get_jobid_with_index(job_id, job_index)

-    def inner_handle_kill(
-        self,
-        *args: Any,  # noqa: ARG002 Unused method argument: `args`
-        **kwargs: Any,  # noqa: ARG002 Unused method argument: `kwargs`
-    ) -> None:
+    def inner_handle_kill(self) -> None:
         job_ids = ",".join(str(job_id) for job_id in self.jobs.keys())

         print(

cluster_tools/cluster_tools/schedulers/pbs.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 import os
 import re
 from concurrent.futures import Future
-from typing import Any, Literal
+from typing import Literal

 from cluster_tools._utils.call import call, chcall
 from cluster_tools._utils.string_ import random_string

@@ -56,7 +56,7 @@ def format_log_file_name(job_id_with_index: str, suffix: str = ".stdout") -> str
     def get_job_id_string(cls) -> str:
         return cls.get_current_job_id()

-    def inner_handle_kill(self, *args: Any, **kwargs: Any) -> None:  # noqa: ARG002 Unused method argument: `args`, kwargs
+    def inner_handle_kill(self) -> None:
         scheduled_job_ids: list[int | str] = list(self.jobs.keys())

         if len(scheduled_job_ids):

cluster_tools/cluster_tools/schedulers/slurm.py

Lines changed: 7 additions & 2 deletions
@@ -244,7 +244,7 @@ def submit_text(cls, job: str, cfut_dir: str) -> str:

         return str(int(job_id))  # int() ensures coherent parsing

-    def inner_handle_kill(self, *args: Any, **kwargs: Any) -> None:  # noqa ARG002 Unused method argument: `args`, kwargs
+    def inner_handle_kill(self) -> None:
         for submit_thread in self.submit_threads:
             submit_thread.stop()

@@ -260,10 +260,15 @@ def inner_handle_kill(self, *args: Any, **kwargs: Any) -> None:  # noqa ARG002 U
         # but can be canceled together using the job_id.
         unique_job_ids = set(map(lambda x: str(x).split("_")[0], scheduled_job_ids))
         job_id_string = " ".join(unique_job_ids)
+        # Allow to speed up the shutdown, for example, when running voxelytics locally
+        sigterm_wait_in_s_env = float(os.environ.get("SIGTERM_WAIT_IN_S", 5))
         # Send SIGINT signal to running jobs instead of terminating the jobs right away. This way, the jobs can
         # react to the signal, safely shutdown and signal (cancel) jobs they possibly scheduled, recursively.
+        # After a short waiting time kill all jobs that are still running (due to race conditions or because they
+        # didn't react to the SIGINT signal for some reason).
         _, stderr, _ = call(
-            f"scancel --state=PENDING {job_id_string}; scancel -s SIGINT --state=RUNNING {job_id_string}; scancel --state=SUSPENDED {job_id_string}"
+            f"scancel --state=PENDING {job_id_string}; scancel -s SIGINT --state=RUNNING {job_id_string};"
+            + f"scancel --state=SUSPENDED {job_id_string}; sleep {sigterm_wait_in_s_env}; scancel {job_id_string}"
         )

         maybe_error_or_warning = (
cluster_tools/pyproject.toml

Lines changed: 3 additions & 2 deletions
@@ -17,15 +17,16 @@ Changelog = "https://github.com/scalableminds/webknossos-libs/blob/master/cluste


 [project.optional-dependencies]
-kubernetes = ["distributed ~=2023.9.1"]
-dask = ["kubernetes ~=27.2.0"]
+dask = ["distributed ~=2023.9.1"]
+kubernetes = ["kubernetes ~=27.2.0"]
 all=["cluster_tools[kubernetes]", "cluster_tools[dask]"]

 [tool.uv]
 dev-dependencies = [
     "icecream ~=2.1.1",
     "mypy ~=1.15.0",
     "pytest ~=8.3.3",
+    "pytest-timeout>=2.4.0",
     "ruff ~=0.11.0",
 ]
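
With the swapped extras corrected, `cluster_tools[kubernetes]` now pulls in the `kubernetes` client and `cluster_tools[dask]` pulls in `distributed`, so each extra matches its name. The new `pytest-timeout` dev dependency supports the commit's goal of letting hanging tests fail with a timeout instead of blocking the test run.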
