Skip to content

Commit 1733a0a

Browse files
client - only initialize shared resources if current_process().name == 'MainProcess'
1 parent ce3fd60 commit 1733a0a

File tree

2 files changed

+100
-16
lines changed

2 files changed

+100
-16
lines changed

src/geophires_x_client/__init__.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import atexit
2-
import os
32
import sys
43
import threading
54
from multiprocessing import Manager
6-
from pathlib import Path
5+
from multiprocessing import current_process
76

87
# noinspection PyPep8Naming
98
from geophires_x import GEOPHIRESv3 as geophires
@@ -36,24 +35,30 @@ def __init__(self, enable_caching=True, logger_name=None):
3635
self._logger = _get_logger(logger_name=logger_name)
3736
self._enable_caching = enable_caching
3837

39-
# Lazy-initialize shared resources if they haven't been already.
4038
if enable_caching and GeophiresXClient._manager is None:
39+
# Lazy-initialize shared resources if they haven't been already.
4140
self._initialize_shared_resources()
4241

4342
@classmethod
4443
def _initialize_shared_resources(cls):
4544
"""
4645
Initializes the multiprocessing Manager and shared resources in a
47-
thread-safe manner. It also registers the shutdown hook to ensure
48-
automatic cleanup on application exit.
46+
thread-safe and now process-safe manner. It also registers the
47+
shutdown hook to ensure automatic cleanup on application exit.
4948
"""
50-
with cls._init_lock:
51-
if cls._manager is None:
52-
cls._manager = Manager()
53-
cls._cache = cls._manager.dict()
54-
cls._lock = cls._manager.RLock()
55-
# Register the shutdown method to be called automatically on exit.
56-
atexit.register(cls.shutdown)
49+
# Ensure that only the top-level user process can create the manager.
50+
# A spawned child process, which re-imports this script, will have a different name
51+
# (e.g., 'Spawn-1') and will skip this entire block, preventing a recursive crash.
52+
if current_process().name == 'MainProcess':
53+
with cls._init_lock:
54+
if cls._manager is None:
55+
cls._logger = _get_logger(__name__) # Add a logger for this class method
56+
cls._logger.debug('MainProcess is creating the shared multiprocessing manager...')
57+
cls._manager = Manager()
58+
cls._cache = cls._manager.dict()
59+
cls._lock = cls._manager.RLock()
60+
# Register the shutdown method to be called automatically on exit.
61+
atexit.register(cls.shutdown)
5762

5863
@classmethod
5964
def shutdown(cls):
@@ -65,9 +70,17 @@ def shutdown(cls):
6570
"""
6671
with cls._init_lock:
6772
if cls._manager is not None:
73+
cls._logger = _get_logger(__name__)
74+
cls._logger.debug('Shutting down the shared multiprocessing manager...')
6875
cls._manager.shutdown()
6976
# De-register the hook to avoid trying to shut down twice.
70-
atexit.unregister(cls.shutdown)
77+
try:
78+
atexit.unregister(cls.shutdown)
79+
except Exception as e:
80+
# Fails in some environments (e.g. pytest), but is not critical
81+
cls._logger.debug(
82+
f'Encountered exception shutting down the shared multiprocessing manager (OK): ' f'{e!s}'
83+
)
7184
cls._manager = None
7285
cls._cache = None
7386
cls._lock = None
@@ -80,22 +93,23 @@ def get_geophires_result(self, input_params: GeophiresInputParameters) -> Geophi
8093
"""
8194
is_immutable = isinstance(input_params, ImmutableGeophiresInputParameters)
8295

83-
if not (self._enable_caching and is_immutable):
96+
if not (self._enable_caching and is_immutable and GeophiresXClient._manager is not None):
8497
return self._run_simulation(input_params)
8598

8699
cache_key = hash(input_params)
87100

88101
with GeophiresXClient._lock:
89102
if cache_key in GeophiresXClient._cache:
103+
# self._logger.debug(f'Cache hit for inputs: {input_params}')
90104
return GeophiresXClient._cache[cache_key]
91105

106+
# Cache miss
92107
result = self._run_simulation(input_params)
93108
GeophiresXClient._cache[cache_key] = result
94109
return result
95110

96111
def _run_simulation(self, input_params: GeophiresInputParameters) -> GeophiresXResult:
97112
"""Helper method to encapsulate the actual GEOPHIRES run."""
98-
stash_cwd = Path.cwd()
99113
stash_sys_argv = sys.argv
100114
sys.argv = ['', input_params.as_file_path(), input_params.get_output_file_path()]
101115

@@ -107,7 +121,6 @@ def _run_simulation(self, input_params: GeophiresInputParameters) -> GeophiresXR
107121
raise RuntimeError('GEOPHIRES exited without giving a reason') from None
108122
finally:
109123
sys.argv = stash_sys_argv
110-
os.chdir(stash_cwd)
111124

112125
self._logger.info(f'GEOPHIRES-X output file: {input_params.get_output_file_path()}')
113126
result = GeophiresXResult(input_params.get_output_file_path())
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# ruff: noqa: S603
2+
3+
import subprocess
4+
import sys
5+
import tempfile
6+
from pathlib import Path
7+
8+
from base_test_case import BaseTestCase
9+
10+
11+
class GeophiresClientImperativeInstantiationTestCase(BaseTestCase):
12+
13+
# noinspection PyMethodMayBeStatic
14+
def test_imperative_instantiation_in_subprocess(self):
15+
"""
16+
Verifies that GeophiresXClient can be instantiated at the global scope
17+
in a script without causing a multiprocessing-related RuntimeError.
18+
19+
This test directly simulates the failure condition by writing and executing
20+
a separate Python script as a subprocess. This ensures that the fix
21+
(checking for 'MainProcess') is working correctly on systems that use
22+
the 'spawn' start method for multiprocessing (like macOS and Windows).
23+
"""
24+
project_root = Path(__file__).parent.parent.resolve()
25+
26+
script_content = f"""
27+
import sys
28+
# We must add the project root to the path for the import to work.
29+
sys.path.insert(0, r'{project_root}')
30+
31+
from geophires_x_client import GeophiresXClient
32+
33+
print("Attempting to instantiate GeophiresXClient at the global scope...")
34+
35+
# This is the line that would have previously crashed with a RuntimeError.
36+
client = GeophiresXClient()
37+
38+
print("Instantiation successful.")
39+
40+
# It is critical to shut down the client to release the manager process,
41+
# otherwise it can linger and interfere with other tests in the suite.
42+
GeophiresXClient.shutdown()
43+
44+
print("Shutdown successful.")
45+
46+
# A final message to confirm the script completed without errors.
47+
print("SUCCESS")
48+
"""
49+
50+
with tempfile.TemporaryDirectory() as tmpdir:
51+
test_script_path = Path(tmpdir) / 'run_client_test.py'
52+
test_script_path.write_text(script_content)
53+
54+
# fmt:off
55+
result = subprocess.run(
56+
[sys.executable, str(test_script_path)],
57+
capture_output=True,
58+
text=True,
59+
timeout=60
60+
)
61+
# fmt:on
62+
63+
assert result.returncode == 0, (
64+
f'Subprocess failed with exit code {result.returncode}. This indicates a crash.\\n'
65+
f'--- STDOUT ---\\n{result.stdout}\\n'
66+
f'--- STDERR ---\\n{result.stderr}'
67+
)
68+
69+
assert 'SUCCESS' in result.stdout, (
70+
"Subprocess completed but did not print the final 'SUCCESS' message.\\n" f"--- STDOUT ---\\n{result.stdout}"
71+
)

0 commit comments

Comments
 (0)