Skip to content

Commit 301ee84

Browse files
Use ImmutableGeophiresInputParameters in TestMultiprocessingSafety
1 parent dfac749 commit 301ee84

File tree

2 files changed

+32
-24
lines changed

2 files changed

+32
-24
lines changed

src/geophires_x_client/__init__.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import json
22
import os
33
import sys
4-
5-
# --- MULTIPROCESSING CHANGES ---
4+
import threading
65
from multiprocessing import Manager
7-
from multiprocessing import RLock
86
from pathlib import Path
97

108
from geophires_x import GEOPHIRESv3 as geophires
@@ -16,11 +14,14 @@
1614

1715

1816
class GeophiresXClient:
19-
# --- LAZY-LOADED, PROCESS-SAFE SINGLETONS ---
20-
# Define class-level placeholders. These will be shared across all instances.
17+
# --- Class-level shared resources ---
18+
# These will be initialized lazily and shared across all instances and processes.
2119
_manager = None
2220
_cache = None
23-
_lock = RLock() # Use a process-safe re-entrant lock
21+
_lock = None # This will be a process-safe RLock from the manager.
22+
23+
# A standard threading lock to make the one-time initialization thread-safe.
24+
_init_lock = threading.Lock()
2425

2526
def __init__(self, enable_caching=True, logger_name=None):
2627
if logger_name is None:
@@ -29,33 +30,39 @@ def __init__(self, enable_caching=True, logger_name=None):
2930
self._logger = _get_logger(logger_name=logger_name)
3031
self._enable_caching = enable_caching
3132

32-
# This method will safely initialize the shared manager and cache
33-
# only when the first client instance is created.
34-
self._initialize_shared_resources()
33+
# Lazy-initialize shared resources if they haven't been already.
34+
# This approach is safe to call from multiple threads/processes.
35+
if GeophiresXClient._manager is None:
36+
self._initialize_shared_resources()
3537

3638
@classmethod
3739
def _initialize_shared_resources(cls):
3840
"""
39-
Initializes the multiprocessing Manager and shared cache dictionary.
40-
This method is designed to be called safely by multiple processes,
41-
ensuring the manager is only started once.
41+
Initializes the multiprocessing Manager and shared resources (cache, lock)
42+
in a thread-safe and process-safe manner.
4243
"""
43-
with cls._lock:
44+
# Use a thread-safe lock to ensure this block only ever runs once
45+
# across all threads in the main process.
46+
with cls._init_lock:
47+
# The double-check locking pattern ensures we don't try to
48+
# re-initialize if another thread finished while we were waiting.
4449
if cls._manager is None:
45-
# This code is now protected. It won't run on module import.
46-
# It runs only when the first GeophiresXClient is instantiated.
4750
cls._manager = Manager()
4851
cls._cache = cls._manager.dict()
52+
cls._lock = cls._manager.RLock() # The Manager now creates the lock.
4953

5054
def get_geophires_result(self, input_params: GeophiresInputParameters) -> GeophiresXResult:
51-
# Use the class-level lock to protect access to the shared cache
52-
# and the non-reentrant GEOPHIRES core.
55+
"""
56+
Calculates a GEOPHIRES result in a thread-safe and process-safe manner.
57+
"""
58+
# Use the process-safe lock from the manager to make the check-then-act
59+
# operation on the cache fully atomic across multiple processes.
5360
with GeophiresXClient._lock:
5461
cache_key = hash(input_params)
5562
if self._enable_caching and cache_key in GeophiresXClient._cache:
5663
return GeophiresXClient._cache[cache_key]
5764

58-
# ... (The rest of your logic remains the same)
65+
# --- This section is now guaranteed to run only once per unique input ---
5966
stash_cwd = Path.cwd()
6067
stash_sys_argv = sys.argv
6168

@@ -67,20 +74,21 @@ def get_geophires_result(self, input_params: GeophiresInputParameters) -> Geophi
6774
except SystemExit:
6875
raise RuntimeError('GEOPHIRES exited without giving a reason') from None
6976
finally:
77+
# Ensure global state is restored even if geophires.main() fails
7078
sys.argv = stash_sys_argv
7179
os.chdir(stash_cwd)
7280

7381
self._logger.info(f'GEOPHIRES-X output file: {input_params.get_output_file_path()}')
7482

7583
result = GeophiresXResult(input_params.get_output_file_path())
7684
if self._enable_caching:
77-
GeophiresXClient._cache[cache_key] = result
85+
self._cache[cache_key] = result
7886

7987
return result
8088

8189

8290
if __name__ == '__main__':
83-
# This block is safe, as it's protected from being run on import.
91+
# This block remains for direct testing of the script.
8492
client = GeophiresXClient()
8593
log = _get_logger()
8694

tests/geophires_x_client_tests/test_multiprocessing_safety.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from queue import Empty
88

99
from geophires_x_client import EndUseOption
10-
from geophires_x_client import GeophiresInputParameters
1110

1211
# Important: We must be able to import the client
1312
from geophires_x_client import GeophiresXClient
13+
from geophires_x_client.geophires_input_parameters import ImmutableGeophiresInputParameters
1414

1515

1616
# This is the function that each worker process will execute.
@@ -27,7 +27,7 @@ def run_client_in_process(params_dict: dict, log_queue: multiprocessing.Queue, r
2727

2828
try:
2929
client = GeophiresXClient(enable_caching=True)
30-
params = GeophiresInputParameters(params_dict)
30+
params = ImmutableGeophiresInputParameters(params_dict)
3131

3232
# This now calls the REAL geophires.main via the client.
3333
result = client.get_geophires_result(params)
@@ -90,7 +90,7 @@ def test_client_runs_real_geophires_and_caches_across_processes(self):
9090
"""
9191
log_queue = self._ctx.Queue()
9292
result_queue = self._ctx.Queue()
93-
num_processes = 4
93+
num_processes = 8
9494
# Timeout should be long enough for at least one successful run.
9595
process_timeout_seconds = 5
9696

@@ -140,8 +140,8 @@ def test_client_runs_real_geophires_and_caches_across_processes(self):
140140
successful_runs = sum(1 for record in log_records if cache_indicator_log in record)
141141

142142
self.assertEqual(
143-
successful_runs,
144143
1,
144+
successful_runs,
145145
f'FAIL: GEOPHIRES was run {successful_runs} times instead of once, indicating the cross-process cache failed.',
146146
)
147147

0 commit comments

Comments
 (0)