Skip to content

Commit 10fc630

Browse files
authored
non-fork start_method as default (#662)
* non-fork start_method as default * tests * fixes * cleanup * fixes * new slurm image * env var * debugging * remove prints * fixes? * tests * changelog * revert * unix-paths for allowed_hosts * lint * topfschlagen * fix block-network * pr feedback * ignore types * fix types
1 parent e267f18 commit 10fc630

File tree

4 files changed

+47
-38
lines changed

4 files changed

+47
-38
lines changed

cluster_tools/Changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ For upgrade instructions, please check the respective *Breaking Changes* section
1010
[Commits](https://github.com/scalableminds/webknossos-libs/compare/v0.9.14...HEAD)
1111

1212
### Breaking Changes
13+
- The `multiprocessing` executor now uses `spawn` as the default start method. `fork` and `forkserver` can still be used by supplying a `start_method` argument (e.g. `cluster_tools.get_executor("multiprocessing", start_method="forkserver")`) or by setting the `MULTIPROCESSING_DEFAULT_START_METHOD` environment variable. [#662](https://github.com/scalableminds/webknossos-libs/pull/662)
1314

1415
### Added
1516

cluster_tools/cluster_tools/__init__.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def get_existent_kwargs_subset(whitelist, kwargs):
2525
return new_kwargs
2626

2727

28-
PROCESS_POOL_KWARGS_WHITELIST = ["max_workers", "mp_context", "initializer", "initargs"]
28+
PROCESS_POOL_KWARGS_WHITELIST = ["max_workers", "initializer", "initargs"]
2929

3030

3131
class WrappedProcessPoolExecutor(ProcessPoolExecutor):
@@ -37,31 +37,28 @@ class WrappedProcessPoolExecutor(ProcessPoolExecutor):
3737
"""
3838

3939
def __init__(self, **kwargs):
40-
new_kwargs = get_existent_kwargs_subset(PROCESS_POOL_KWARGS_WHITELIST, kwargs)
40+
assert (not "start_method" in kwargs or kwargs["start_method"] is None) or (
41+
not "mp_context" in kwargs
42+
), "Cannot use both `start_method` and `mp_context` kwargs."
4143

42-
self.did_overwrite_start_method = False
43-
if kwargs.get("start_method", None) is not None:
44-
self.did_overwrite_start_method = True
45-
self.old_start_method = multiprocessing.get_start_method()
46-
start_method = kwargs["start_method"]
47-
logging.info(
48-
f"Overwriting start_method to {start_method}. Previous value: {self.old_start_method}"
49-
)
50-
multiprocessing.set_start_method(start_method, force=True)
44+
new_kwargs = get_existent_kwargs_subset(PROCESS_POOL_KWARGS_WHITELIST, kwargs)
5145

52-
ProcessPoolExecutor.__init__(self, **new_kwargs)
46+
mp_context = None
5347

54-
def shutdown(self, *args, **kwargs):
48+
if "mp_context" in kwargs:
49+
mp_context = kwargs["mp_context"]
50+
elif "start_method" in kwargs and kwargs["start_method"] is not None:
51+
mp_context = multiprocessing.get_context(kwargs["start_method"])
52+
elif "MULTIPROCESSING_DEFAULT_START_METHOD" in os.environ:
53+
mp_context = multiprocessing.get_context(
54+
os.environ["MULTIPROCESSING_DEFAULT_START_METHOD"]
55+
)
56+
else:
57+
mp_context = multiprocessing.get_context("spawn")
5558

56-
super().shutdown(*args, **kwargs)
59+
new_kwargs["mp_context"] = mp_context
5760

58-
if self.did_overwrite_start_method:
59-
logging.info(
60-
f"Restoring start_method to original value: {self.old_start_method}."
61-
)
62-
multiprocessing.set_start_method(self.old_start_method, force=True)
63-
self.old_start_method = None
64-
self.did_overwrite_start_method = False
61+
ProcessPoolExecutor.__init__(self, **new_kwargs)
6562

6663
def submit(self, *args, **kwargs):
6764

@@ -88,7 +85,7 @@ def submit(self, *args, **kwargs):
8885
# where wrapper_fn_1 is called, which eventually calls wrapper_fn_2, which eventually calls actual_fn.
8986
call_stack = []
9087

91-
if multiprocessing.get_start_method() != "fork":
88+
if self._mp_context.get_start_method() != "fork":
9289
# If a start_method other than "fork" is used, logging needs to be set up again,
9390
# because the programming context is not inherited in those cases.
9491
multiprocessing_logging_setup_fn = get_multiprocessing_logging_setup_fn()

cluster_tools/tests/test_multiprocessing.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ def expect_fork():
1414
return True
1515

1616

17+
def expect_forkserver():
18+
assert mp.get_start_method() == "forkserver"
19+
return True
20+
21+
1722
def expect_spawn():
1823
assert mp.get_start_method() == "spawn"
1924
return True
@@ -22,22 +27,29 @@ def expect_spawn():
2227
def test_map_with_spawn():
2328
with cluster_tools.get_executor("multiprocessing", max_workers=5) as executor:
2429
assert executor.submit(
25-
expect_fork
26-
).result(), "Multiprocessing should use fork by default"
30+
expect_spawn
31+
).result(), "Multiprocessing should use `spawn` by default"
2732

2833
with cluster_tools.get_executor(
2934
"multiprocessing", max_workers=5, start_method=None
3035
) as executor:
3136
assert executor.submit(
32-
expect_fork
33-
).result(), "Multiprocessing should use fork if start_method is None"
37+
expect_spawn
38+
).result(), "Multiprocessing should use `spawn` if start_method is None"
3439

3540
with cluster_tools.get_executor(
36-
"multiprocessing", max_workers=5, start_method="spawn"
41+
"multiprocessing", max_workers=5, start_method="forkserver"
3742
) as executor:
3843
assert executor.submit(
39-
expect_spawn
40-
).result(), "Multiprocessing should use spawn if requested"
44+
expect_forkserver
45+
).result(), "Multiprocessing should use `forkserver` if requested"
46+
47+
with cluster_tools.get_executor(
48+
"multiprocessing", max_workers=5, start_method="fork"
49+
) as executor:
50+
assert executor.submit(
51+
expect_fork
52+
).result(), "Multiprocessing should use `fork` if requested"
4153

4254

4355
def accept_high_mem(data):

webknossos/tests/conftest.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import json
2-
import multiprocessing
32
import re
43
import warnings
54
from io import BytesIO
@@ -318,11 +317,11 @@ def pytest_collection_modifyitems(items: List[pytest.Item]) -> None:
318317
if item.get_closest_marker("vcr") is None:
319318
item.add_marker("vcr")
320319

321-
if (
322-
item.get_closest_marker("block_network") is None
323-
and multiprocessing.get_start_method() != "fork"
324-
):
325-
# To allow for UNIX socket communication necessary for spawn multiprocessing
326-
# addresses starting with `/` are allowed
327-
marker = pytest.mark.block_network(allowed_hosts=["/.*"])
328-
item.add_marker(marker)
320+
# To allow for UNIX socket communication necessary for spawn multiprocessing
321+
# addresses starting with `/` are allowed
322+
marker = item.get_closest_marker("block_network")
323+
if marker is None:
324+
new_marker = pytest.mark.block_network(allowed_hosts=["/.*"])
325+
item.add_marker(new_marker)
326+
else:
327+
marker.kwargs["allowed_hosts"].append("/.*")

0 commit comments

Comments
 (0)