Skip to content

Commit 9f0648f

Browse files
test resource management to not interfere with other tests
1 parent 301ee84 commit 9f0648f

File tree

1 file changed

+80
-94
lines changed

1 file changed

+80
-94
lines changed

tests/geophires_x_client_tests/test_multiprocessing_safety.py

Lines changed: 80 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from geophires_x_client import EndUseOption
1010

11-
# Important: We must be able to import the client
11+
# Important: We must be able to import the client and all parameter classes
1212
from geophires_x_client import GeophiresXClient
1313
from geophires_x_client.geophires_input_parameters import ImmutableGeophiresInputParameters
1414

@@ -26,58 +26,26 @@ def run_client_in_process(params_dict: dict, log_queue: multiprocessing.Queue, r
2626
root_logger.handlers = [QueueHandler(log_queue)]
2727

2828
try:
29+
# Client initialization is now done in the worker, relying on the
30+
# lazy-loading singleton pattern in the client itself.
2931
client = GeophiresXClient(enable_caching=True)
3032
params = ImmutableGeophiresInputParameters(params_dict)
31-
32-
# This now calls the REAL geophires.main via the client.
3333
result = client.get_geophires_result(params)
34-
35-
# Put the primitive result into the queue to avoid serialization issues.
3634
result_queue.put(result.direct_use_heat_breakeven_price_USD_per_MMBTU)
3735
except Exception as e:
38-
# Report any exceptions back to the main process.
3936
result_queue.put(e)
4037

4138

4239
class TestMultiprocessingSafety(unittest.TestCase):
43-
# Class-level attributes to manage shared resources across test runs.
44-
_ctx = None
45-
_client_for_setup = None
46-
47-
@classmethod
48-
def setUpClass(cls):
49-
"""
50-
Set up the multiprocessing context and start the shared Manager
51-
process ONCE before any tests in this class run.
52-
"""
53-
if sys.platform == 'win32':
54-
# Skip all tests in this class if not on a fork-supporting OS.
55-
raise unittest.SkipTest("The 'fork' multiprocessing context is not available on Windows.")
56-
57-
cls._ctx = multiprocessing.get_context('fork')
58-
# Instantiating the client here creates the shared _manager and _cache
59-
# that all child processes forked from this test will inherit.
60-
cls._client_for_setup = GeophiresXClient()
61-
62-
@classmethod
63-
def tearDownClass(cls):
64-
"""
65-
Shut down the shared Manager process ONCE after all tests in this
66-
class have finished. This is the key to preventing hanging processes.
67-
"""
68-
if cls._client_for_setup and hasattr(cls._client_for_setup, '_manager'):
69-
if cls._client_for_setup._manager is not None:
70-
cls._client_for_setup._manager.shutdown()
40+
# By removing setUpClass and tearDownClass, we ensure each test is fully isolated.
7141

7242
def setUp(self):
7343
"""Set up a shared set of parameters for each test."""
74-
# This setup runs before each individual test method.
7544
self.params_dict = {
7645
'Print Output to Console': 0,
7746
'End-Use Option': EndUseOption.DIRECT_USE_HEAT.value,
7847
'Reservoir Model': 1,
7948
'Time steps per year': 1,
80-
# Use nanoseconds to ensure each test run gets a unique cache key (Use a different value per run)
8149
'Reservoir Depth': 4 + time.time_ns() / 1e19,
8250
'Gradient 1': 50,
8351
'Maximum Temperature': 550,
@@ -88,64 +56,82 @@ def test_client_runs_real_geophires_and_caches_across_processes(self):
8856
Tests that GeophiresXClient can run the real geophires.main in multiple
8957
processes and that the cache is shared between them.
9058
"""
91-
log_queue = self._ctx.Queue()
92-
result_queue = self._ctx.Queue()
93-
num_processes = 8
94-
# Timeout should be long enough for at least one successful run.
95-
process_timeout_seconds = 5
96-
97-
processes = [
98-
self._ctx.Process(target=run_client_in_process, args=(self.params_dict, log_queue, result_queue))
99-
for _ in range(num_processes)
100-
]
101-
102-
for p in processes:
103-
p.start()
104-
105-
# --- Robust Result Collection ---
106-
results = []
107-
for i in range(num_processes):
108-
try:
109-
result = result_queue.get(timeout=process_timeout_seconds)
110-
results.append(result)
111-
except Empty:
112-
# Terminate running processes before failing to avoid hanging the suite
113-
for p_cleanup in processes:
114-
if p_cleanup.is_alive():
115-
p_cleanup.terminate()
116-
self.fail(f'Test timed out waiting for result #{i + 1}. A worker process likely crashed or is stuck.')
117-
118-
# --- Process Cleanup ---
119-
# With the robust tearDownClass, a simple join is sufficient here.
120-
for p in processes:
121-
p.join(timeout=process_timeout_seconds)
122-
123-
# --- Assertions ---
124-
# 1. Check that no process returned an exception.
125-
for r in results:
126-
self.assertNotIsInstance(r, Exception, f'A process failed with an exception: {r}')
127-
128-
# 2. Check that all processes got a valid, non-None result.
129-
for r in results:
130-
self.assertIsNotNone(r)
131-
self.assertIsInstance(r, float)
132-
133-
# 3. CRITICAL: Assert that the expensive GEOPHIRES calculation was only run ONCE.
134-
# This assertion is expected to fail until the caching bug in the client is fixed.
135-
log_records = []
136-
while not log_queue.empty():
137-
log_records.append(log_queue.get().getMessage())
138-
139-
cache_indicator_log = 'GEOPHIRES-X output file:'
140-
successful_runs = sum(1 for record in log_records if cache_indicator_log in record)
141-
142-
self.assertEqual(
143-
1,
144-
successful_runs,
145-
f'FAIL: GEOPHIRES was run {successful_runs} times instead of once, indicating the cross-process cache failed.',
146-
)
147-
148-
print(f'\nDetected {successful_runs} non-cached GEOPHIRES run(s) for {num_processes} requests.')
59+
if sys.platform == 'win32':
60+
self.skipTest("The 'fork' multiprocessing context is not available on Windows.")
61+
62+
ctx = multiprocessing.get_context('fork')
63+
# THE FIX: Use the Manager as a context manager within the test.
64+
# This guarantees it and all its resources (queues, etc.) are
65+
# properly created and shut down for each individual test run.
66+
with ctx.Manager() as manager:
67+
log_queue = manager.Queue()
68+
result_queue = manager.Queue()
69+
70+
# The client needs to be re-initialized inside the test to use the new manager.
71+
# This is a bit of a workaround to reset the class-level singleton for the test.
72+
GeophiresXClient._manager = manager
73+
GeophiresXClient._cache = manager.dict()
74+
GeophiresXClient._lock = manager.RLock()
75+
76+
num_processes = 4
77+
process_timeout_seconds = 15
78+
79+
processes = [
80+
ctx.Process(target=run_client_in_process, args=(self.params_dict, log_queue, result_queue))
81+
for _ in range(num_processes)
82+
]
83+
84+
for p in processes:
85+
p.start()
86+
87+
# --- Robust Result Collection ---
88+
results = []
89+
for i in range(num_processes):
90+
try:
91+
result = result_queue.get(timeout=process_timeout_seconds)
92+
results.append(result)
93+
except Empty:
94+
for p_cleanup in processes:
95+
if p_cleanup.is_alive():
96+
p_cleanup.terminate()
97+
self.fail(
98+
f'Test timed out waiting for result #{i + 1}. A worker process likely crashed or is stuck.'
99+
)
100+
101+
# --- Process Cleanup ---
102+
for p in processes:
103+
p.join(timeout=process_timeout_seconds)
104+
if p.is_alive():
105+
p.terminate() # Forcefully end if stuck
106+
self.fail(f'Process {p.pid} failed to terminate cleanly.')
107+
108+
# --- Assertions ---
109+
for r in results:
110+
self.assertNotIsInstance(r, Exception, f'A process failed with an exception: {r}')
111+
self.assertIsNotNone(r)
112+
self.assertIsInstance(r, float)
113+
114+
log_records = []
115+
while not log_queue.empty():
116+
log_records.append(log_queue.get().getMessage())
117+
118+
cache_indicator_log = 'GEOPHIRES-X output file:'
119+
successful_runs = sum(1 for record in log_records if cache_indicator_log in record)
120+
121+
self.assertEqual(
122+
successful_runs,
123+
1,
124+
f'FAIL: GEOPHIRES was run {successful_runs} times instead of once, indicating the cache failed.',
125+
)
126+
127+
print(
128+
f'\nTest passed: Detected {successful_runs} non-cached GEOPHIRES run(s) for {num_processes} requests.'
129+
)
130+
131+
# Reset the client's singleton state after the test to not interfere with others.
132+
GeophiresXClient._manager = None
133+
GeophiresXClient._cache = None
134+
GeophiresXClient._lock = None
149135

150136

151137
if __name__ == '__main__':

0 commit comments

Comments
 (0)