Skip to content

Commit dfac749

Browse files
WIP - progress on multiprocessing safety
1 parent 955138a commit dfac749

File tree

2 files changed

+208
-43
lines changed

2 files changed

+208
-43
lines changed

src/geophires_x_client/__init__.py

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import json
22
import os
33
import sys
4+
5+
# --- MULTIPROCESSING CHANGES ---
46
from multiprocessing import Manager
7+
from multiprocessing import RLock
58
from pathlib import Path
69

7-
# noinspection PyPep8Naming
810
from geophires_x import GEOPHIRESv3 as geophires
911

1012
from .common import _get_logger
@@ -14,63 +16,74 @@
1416

1517

1618
class GeophiresXClient:
    """Client wrapper around the GEOPHIRES-X simulator with an optional,
    process-shared result cache.

    The cache, its backing ``Manager``, and the lock are class-level so that
    every instance shares them. NOTE(review): a ``multiprocessing`` lock and
    ``Manager`` are only shared by worker processes that *inherit* them (the
    'fork' start method) or receive them explicitly; independently spawned
    processes each get private copies, so cross-process caching is only
    effective under fork — confirm this matches how callers create workers.
    """

    # --- LAZY-LOADED, PROCESS-SAFE SINGLETONS ---
    # Class-level placeholders, shared across all instances. The Manager is
    # started lazily so importing this module has no side effects.
    _manager = None
    _cache = None
    # Process-safe re-entrant lock. Guards both the shared cache and the
    # GEOPHIRES core, which is not re-entrant (it mutates global state).
    _lock = RLock()

    def __init__(self, enable_caching=True, logger_name=None):
        """
        :param enable_caching: when True, results are memoized in a shared,
            Manager-backed dict keyed by ``hash(input_params)``.
        :param logger_name: name for the client's logger; defaults to this
            module's name.
        """
        if logger_name is None:
            logger_name = __name__

        self._logger = _get_logger(logger_name=logger_name)
        self._enable_caching = enable_caching

        # Only pay the cost of starting the Manager server process when
        # caching is actually enabled; non-caching clients never touch it.
        if self._enable_caching:
            self._initialize_shared_resources()

    @classmethod
    def _initialize_shared_resources(cls):
        """Start the shared Manager and cache dict exactly once.

        The class-level lock makes the check-and-create atomic, so this is
        safe to call from multiple threads or forked processes, and nothing
        runs at module import time.
        """
        with cls._lock:
            if cls._manager is None:
                # Protected: runs only when the first caching client is
                # instantiated, never on import.
                cls._manager = Manager()
                cls._cache = cls._manager.dict()

    def get_geophires_result(self, input_params: GeophiresInputParameters) -> GeophiresXResult:
        """Run GEOPHIRES for ``input_params``, returning a cached result when
        one exists.

        The class-level lock is held for the whole check-run-store sequence:
        it makes the cache check-and-set atomic AND serializes access to the
        GEOPHIRES core, which clobbers process-global state (``sys.argv``
        and the working directory).

        :raises RuntimeError: if GEOPHIRES raises any exception or exits.
        """
        with GeophiresXClient._lock:
            cache_key = hash(input_params)
            if self._enable_caching and cache_key in GeophiresXClient._cache:
                return GeophiresXClient._cache[cache_key]

            # Stash global state that geophires.main() is known to mutate.
            stash_cwd = Path.cwd()
            stash_sys_argv = sys.argv

            sys.argv = ['', input_params.as_file_path(), input_params.get_output_file_path()]
            try:
                geophires.main(enable_geophires_logging_config=False)
            except Exception as e:
                raise RuntimeError(f'GEOPHIRES encountered an exception: {e!s}') from e
            except SystemExit:
                # SystemExit is not an Exception subclass, so it needs its
                # own handler; suppress the (uninformative) context.
                raise RuntimeError('GEOPHIRES exited without giving a reason') from None
            finally:
                # Always restore global state, even on failure.
                sys.argv = stash_sys_argv
                os.chdir(stash_cwd)

            self._logger.info(f'GEOPHIRES-X output file: {input_params.get_output_file_path()}')

            result = GeophiresXResult(input_params.get_output_file_path())
            if self._enable_caching:
                GeophiresXClient._cache[cache_key] = result

            return result
6780

6881

6982
if __name__ == '__main__':
83+
# This block is safe, as it's protected from being run on import.
7084
client = GeophiresXClient()
7185
log = _get_logger()
7286

73-
# noinspection PyTypeChecker
7487
params = GeophiresInputParameters(
7588
{
7689
'Print Output to Console': 0,
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import logging
2+
import multiprocessing
3+
import sys
4+
import time
5+
import unittest
6+
from logging.handlers import QueueHandler
7+
from queue import Empty
8+
9+
from geophires_x_client import EndUseOption
10+
from geophires_x_client import GeophiresInputParameters
11+
12+
# Important: We must be able to import the client
13+
from geophires_x_client import GeophiresXClient
14+
15+
16+
# Worker entry point for each child process. It must remain a module-level
# function so that multiprocessing can pickle it.
def run_client_in_process(params_dict: dict, log_queue: multiprocessing.Queue, result_queue: multiprocessing.Queue):
    """Run one GEOPHIRES calculation in a child process.

    Log records are routed back to the parent through ``log_queue``; the
    numeric result (or any raised exception) is pushed onto ``result_queue``.
    """
    # Route this process's root logging through the shared queue so the
    # parent process can observe worker log records.
    worker_root = logging.getLogger()
    worker_root.setLevel(logging.INFO)
    worker_root.handlers = [QueueHandler(log_queue)]

    try:
        # This exercises the REAL geophires.main via the client.
        outcome = GeophiresXClient(enable_caching=True).get_geophires_result(
            GeophiresInputParameters(params_dict)
        )

        # Ship a plain float rather than the full result object to avoid
        # serialization issues across the queue.
        result_queue.put(outcome.direct_use_heat_breakeven_price_USD_per_MMBTU)
    except Exception as failure:
        # Surface the failure to the main process instead of dying silently.
        result_queue.put(failure)
40+
41+
42+
class TestMultiprocessingSafety(unittest.TestCase):
    """End-to-end check that GeophiresXClient's result cache is shared
    across forked worker processes."""

    # Class-level attributes to manage shared resources across test runs.
    _ctx = None
    _client_for_setup = None

    @classmethod
    def setUpClass(cls):
        """
        Set up the multiprocessing context and start the shared Manager
        process ONCE before any tests in this class run.
        """
        if sys.platform == 'win32':
            # Skip all tests in this class if not on a fork-supporting OS.
            raise unittest.SkipTest("The 'fork' multiprocessing context is not available on Windows.")

        cls._ctx = multiprocessing.get_context('fork')
        # Instantiating the client here creates the shared _manager and
        # _cache that all child processes forked from this test inherit.
        cls._client_for_setup = GeophiresXClient()

    @classmethod
    def tearDownClass(cls):
        """
        Shut down the shared Manager process ONCE after all tests in this
        class have finished. This is the key to preventing hanging processes.
        """
        if cls._client_for_setup and hasattr(cls._client_for_setup, '_manager'):
            if cls._client_for_setup._manager is not None:
                cls._client_for_setup._manager.shutdown()

    def setUp(self):
        """Set up a shared set of parameters for each test."""
        # Runs before each individual test method.
        self.params_dict = {
            'Print Output to Console': 0,
            'End-Use Option': EndUseOption.DIRECT_USE_HEAT.value,
            'Reservoir Model': 1,
            'Time steps per year': 1,
            # Use nanoseconds to ensure each test run gets a unique cache key
            # (a different value per run).
            'Reservoir Depth': 4 + time.time_ns() / 1e19,
            'Gradient 1': 50,
            'Maximum Temperature': 550,
        }

    def test_client_runs_real_geophires_and_caches_across_processes(self):
        """
        Tests that GeophiresXClient can run the real geophires.main in
        multiple processes and that the cache is shared between them.
        """
        log_queue = self._ctx.Queue()
        result_queue = self._ctx.Queue()
        num_processes = 4
        # Timeout should be long enough for at least one successful run.
        process_timeout_seconds = 5

        processes = [
            self._ctx.Process(target=run_client_in_process, args=(self.params_dict, log_queue, result_queue))
            for _ in range(num_processes)
        ]

        for p in processes:
            p.start()

        # --- Robust Result Collection ---
        results = []
        for i in range(num_processes):
            try:
                results.append(result_queue.get(timeout=process_timeout_seconds))
            except Empty:
                # Terminate running processes before failing to avoid
                # hanging the suite.
                for p_cleanup in processes:
                    if p_cleanup.is_alive():
                        p_cleanup.terminate()
                self.fail(f'Test timed out waiting for result #{i + 1}. A worker process likely crashed or is stuck.')

        # --- Process Cleanup ---
        # With the robust tearDownClass, a simple join is sufficient here.
        for p in processes:
            p.join(timeout=process_timeout_seconds)

        # --- Assertions ---
        # 1. Check that no process returned an exception.
        for r in results:
            self.assertNotIsInstance(r, Exception, f'A process failed with an exception: {r}')

        # 2. Check that all processes got a valid, non-None result.
        for r in results:
            self.assertIsNotNone(r)
            self.assertIsInstance(r, float)

        # 3. CRITICAL: Assert that the expensive GEOPHIRES calculation was
        # only run ONCE. Drain the log queue with blocking gets:
        # multiprocessing.Queue.empty() is documented as unreliable, and a
        # worker's feeder thread may not have flushed yet, so an
        # empty()-based loop could silently drop records and make this
        # assertion flaky.
        log_records = []
        while True:
            try:
                log_records.append(log_queue.get(timeout=1).getMessage())
            except Empty:
                break

        cache_indicator_log = 'GEOPHIRES-X output file:'
        successful_runs = sum(1 for record in log_records if cache_indicator_log in record)

        self.assertEqual(
            successful_runs,
            1,
            f'FAIL: GEOPHIRES was run {successful_runs} times instead of once, indicating the cross-process cache failed.',
        )

        print(f'\nDetected {successful_runs} non-cached GEOPHIRES run(s) for {num_processes} requests.')
149+
150+
151+
if __name__ == '__main__':
152+
unittest.main()

0 commit comments

Comments
 (0)