Skip to content

Commit 247182e

Browse files
test_multiprocessing_safety.py passes
1 parent 9f0648f commit 247182e

File tree

2 files changed

+71 -79 lines changed

2 files changed

+71 -79 lines changed

src/geophires_x_client/__init__.py

Lines changed: 54 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
import json
21
import os
32
import sys
43
import threading
@@ -7,18 +6,23 @@
76

87
from geophires_x import GEOPHIRESv3 as geophires
98

9+
# Assuming these are in a sibling file or accessible path
1010
from .common import _get_logger
11-
from .geophires_input_parameters import EndUseOption
1211
from .geophires_input_parameters import GeophiresInputParameters
12+
from .geophires_input_parameters import ImmutableGeophiresInputParameters
1313
from .geophires_x_result import GeophiresXResult
1414

1515

1616
class GeophiresXClient:
17+
"""
18+
A thread-safe and process-safe client for running GEOPHIRES simulations.
19+
Relies on an explicit shutdown() call to clean up background processes.
20+
"""
21+
1722
# --- Class-level shared resources ---
18-
# These will be initialized lazily and shared across all instances and processes.
1923
_manager = None
2024
_cache = None
21-
_lock = None # This will be a process-safe RLock from the manager.
25+
_lock = None
2226

2327
# A standard threading lock to make the one-time initialization thread-safe.
2428
_init_lock = threading.Lock()
@@ -31,7 +35,6 @@ def __init__(self, enable_caching=True, logger_name=None):
3135
self._enable_caching = enable_caching
3236

3337
# Lazy-initialize shared resources if they haven't been already.
34-
# This approach is safe to call from multiple threads/processes.
3538
if GeophiresXClient._manager is None:
3639
self._initialize_shared_resources()
3740

@@ -41,69 +44,63 @@ def _initialize_shared_resources(cls):
4144
Initializes the multiprocessing Manager and shared resources (cache, lock)
4245
in a thread-safe and process-safe manner.
4346
"""
44-
# Use a thread-safe lock to ensure this block only ever runs once
45-
# across all threads in the main process.
4647
with cls._init_lock:
47-
# The double-check locking pattern ensures we don't try to
48-
# re-initialize if another thread finished while we were waiting.
4948
if cls._manager is None:
5049
cls._manager = Manager()
5150
cls._cache = cls._manager.dict()
52-
cls._lock = cls._manager.RLock() # The Manager now creates the lock.
51+
cls._lock = cls._manager.RLock()
5352

54-
def get_geophires_result(self, input_params: GeophiresInputParameters) -> GeophiresXResult:
53+
@classmethod
54+
def shutdown(cls):
5555
"""
56-
Calculates a GEOPHIRES result in a thread-safe and process-safe manner.
56+
Explicitly shuts down the background manager process.
57+
This MUST be called when the application is finished with the client
58+
to prevent orphaned processes.
5759
"""
58-
# Use the process-safe lock from the manager to make the check-then-act
59-
# operation on the cache fully atomic across multiple processes.
60-
with GeophiresXClient._lock:
61-
cache_key = hash(input_params)
62-
if self._enable_caching and cache_key in GeophiresXClient._cache:
63-
return GeophiresXClient._cache[cache_key]
60+
with cls._init_lock:
61+
if cls._manager is not None:
62+
cls._manager.shutdown()
63+
cls._manager = None
64+
cls._cache = None
65+
cls._lock = None
6466

65-
# --- This section is now guaranteed to run only once per unique input ---
66-
stash_cwd = Path.cwd()
67-
stash_sys_argv = sys.argv
67+
def get_geophires_result(self, input_params: GeophiresInputParameters) -> GeophiresXResult:
68+
"""
69+
Calculates a GEOPHIRES result, using a cross-process cache to avoid
70+
re-computing results for the same inputs. Caching is only effective
71+
when providing an instance of ImmutableGeophiresInputParameters.
72+
"""
73+
is_immutable = isinstance(input_params, ImmutableGeophiresInputParameters)
6874

69-
sys.argv = ['', input_params.as_file_path(), input_params.get_output_file_path()]
70-
try:
71-
geophires.main(enable_geophires_logging_config=False)
72-
except Exception as e:
73-
raise RuntimeError(f'GEOPHIRES encountered an exception: {e!s}') from e
74-
except SystemExit:
75-
raise RuntimeError('GEOPHIRES exited without giving a reason') from None
76-
finally:
77-
# Ensure global state is restored even if geophires.main() fails
78-
sys.argv = stash_sys_argv
79-
os.chdir(stash_cwd)
75+
if not (self._enable_caching and is_immutable):
76+
return self._run_simulation(input_params)
8077

81-
self._logger.info(f'GEOPHIRES-X output file: {input_params.get_output_file_path()}')
78+
cache_key = hash(input_params)
8279

83-
result = GeophiresXResult(input_params.get_output_file_path())
84-
if self._enable_caching:
85-
self._cache[cache_key] = result
80+
with GeophiresXClient._lock:
81+
if cache_key in GeophiresXClient._cache:
82+
return GeophiresXClient._cache[cache_key]
8683

84+
result = self._run_simulation(input_params)
85+
GeophiresXClient._cache[cache_key] = result
8786
return result
8887

89-
90-
if __name__ == '__main__':
91-
# This block remains for direct testing of the script.
92-
client = GeophiresXClient()
93-
log = _get_logger()
94-
95-
params = GeophiresInputParameters(
96-
{
97-
'Print Output to Console': 0,
98-
'End-Use Option': EndUseOption.DIRECT_USE_HEAT.value,
99-
'Reservoir Model': 1,
100-
'Time steps per year': 1,
101-
'Reservoir Depth': 3,
102-
'Gradient 1': 50,
103-
'Maximum Temperature': 250,
104-
}
105-
)
106-
107-
result_ = client.get_geophires_result(params)
108-
log.info(f'Breakeven price: ${result_.direct_use_heat_breakeven_price_USD_per_MMBTU}/MMBTU')
109-
log.info(json.dumps(result_.result, indent=2))
88+
def _run_simulation(self, input_params: GeophiresInputParameters) -> GeophiresXResult:
89+
"""Helper method to encapsulate the actual GEOPHIRES run."""
90+
stash_cwd = Path.cwd()
91+
stash_sys_argv = sys.argv
92+
sys.argv = ['', input_params.as_file_path(), input_params.get_output_file_path()]
93+
94+
try:
95+
geophires.main(enable_geophires_logging_config=False)
96+
except Exception as e:
97+
raise RuntimeError(f'GEOPHIRES encountered an exception: {e!s}') from e
98+
except SystemExit:
99+
raise RuntimeError('GEOPHIRES exited without giving a reason') from None
100+
finally:
101+
sys.argv = stash_sys_argv
102+
os.chdir(stash_cwd)
103+
104+
self._logger.info(f'GEOPHIRES-X output file: {input_params.get_output_file_path()}')
105+
result = GeophiresXResult(input_params.get_output_file_path())
106+
return result

tests/geophires_x_client_tests/test_multiprocessing_safety.py

Lines changed: 17 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -6,28 +6,24 @@
66
from logging.handlers import QueueHandler
77
from queue import Empty
88

9-
from geophires_x_client import EndUseOption
10-
119
# Important: We must be able to import the client and all parameter classes
1210
from geophires_x_client import GeophiresXClient
11+
from geophires_x_client.geophires_input_parameters import EndUseOption
1312
from geophires_x_client.geophires_input_parameters import ImmutableGeophiresInputParameters
1413

1514

16-
# This is the function that each worker process will execute.
17-
# It must be a top-level function to be picklable by multiprocessing.
1815
def run_client_in_process(params_dict: dict, log_queue: multiprocessing.Queue, result_queue: multiprocessing.Queue):
1916
"""
20-
Instantiates a client and runs a calculation, reporting results
21-
and logs back to the main process via queues.
17+
This is the function that each worker process will execute.
18+
It must be a top-level function to be picklable by multiprocessing.
2219
"""
2320
# Configure logging for this worker process to send messages to the shared queue.
2421
root_logger = logging.getLogger()
2522
root_logger.setLevel(logging.INFO)
2623
root_logger.handlers = [QueueHandler(log_queue)]
2724

2825
try:
29-
# Client initialization is now done in the worker, relying on the
30-
# lazy-loading singleton pattern in the client itself.
26+
# The client will use the Manager that was injected by the test's main process.
3127
client = GeophiresXClient(enable_caching=True)
3228
params = ImmutableGeophiresInputParameters(params_dict)
3329
result = client.get_geophires_result(params)
@@ -37,10 +33,8 @@ def run_client_in_process(params_dict: dict, log_queue: multiprocessing.Queue, r
3733

3834

3935
class TestMultiprocessingSafety(unittest.TestCase):
40-
# By removing setUpClass and tearDownClass, we ensure each test is fully isolated.
41-
4236
def setUp(self):
43-
"""Set up a shared set of parameters for each test."""
37+
"""Set up a unique set of parameters for each test."""
4438
self.params_dict = {
4539
'Print Output to Console': 0,
4640
'End-Use Option': EndUseOption.DIRECT_USE_HEAT.value,
@@ -54,25 +48,26 @@ def setUp(self):
5448
def test_client_runs_real_geophires_and_caches_across_processes(self):
5549
"""
5650
Tests that GeophiresXClient can run the real geophires.main in multiple
57-
processes and that the cache is shared between them.
51+
processes and that the cache is shared between them. This test is now
52+
fully self-contained to prevent resource conflicts with the test runner.
5853
"""
5954
if sys.platform == 'win32':
6055
self.skipTest("The 'fork' multiprocessing context is not available on Windows.")
6156

6257
ctx = multiprocessing.get_context('fork')
63-
# THE FIX: Use the Manager as a context manager within the test.
64-
# This guarantees it and all its resources (queues, etc.) are
65-
# properly created and shut down for each individual test run.
58+
# Use the Manager as a context manager. This is the key to ensuring
59+
# all resources it creates (queues, etc.) are properly shut down
60+
# at the end of the block, preventing deadlocks.
6661
with ctx.Manager() as manager:
67-
log_queue = manager.Queue()
68-
result_queue = manager.Queue()
69-
70-
# The client needs to be re-initialized inside the test to use the new manager.
71-
# This is a bit of a workaround to reset the class-level singleton for the test.
62+
# For this test to work, we MUST inject the test-specific manager
63+
# into the client's class-level singleton attributes.
7264
GeophiresXClient._manager = manager
7365
GeophiresXClient._cache = manager.dict()
7466
GeophiresXClient._lock = manager.RLock()
7567

68+
log_queue = manager.Queue()
69+
result_queue = manager.Queue()
70+
7671
num_processes = 4
7772
process_timeout_seconds = 15
7873

@@ -102,7 +97,7 @@ def test_client_runs_real_geophires_and_caches_across_processes(self):
10297
for p in processes:
10398
p.join(timeout=process_timeout_seconds)
10499
if p.is_alive():
105-
p.terminate() # Forcefully end if stuck
100+
p.terminate()
106101
self.fail(f'Process {p.pid} failed to terminate cleanly.')
107102

108103
# --- Assertions ---
@@ -128,7 +123,7 @@ def test_client_runs_real_geophires_and_caches_across_processes(self):
128123
f'\nTest passed: Detected {successful_runs} non-cached GEOPHIRES run(s) for {num_processes} requests.'
129124
)
130125

131-
# Reset the client's singleton state after the test to not interfere with others.
126+
# CRITICAL: Reset the client's singleton state after the test to not interfere with other tests.
132127
GeophiresXClient._manager = None
133128
GeophiresXClient._cache = None
134129
GeophiresXClient._lock = None

0 commit comments

Comments (0)