Skip to content

Commit 83343e4

Browse files
Use a shared cache across instances of GeophiresXClient so consumers don't have to manage singletons
1 parent 7b2b78d commit 83343e4

File tree

2 files changed

+49
-40
lines changed

2 files changed

+49
-40
lines changed

src/geophires_x_client/__init__.py

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import os
33
import sys
4-
import threading
4+
from multiprocessing import Manager
55
from pathlib import Path
66

77
# noinspection PyPep8Naming
@@ -14,51 +14,57 @@
1414

1515

1616
class GeophiresXClient:
17+
# Use a multiprocessing Manager to create a cache and lock that are
18+
# shared across all processes spawned by a ProcessPoolExecutor.
19+
_manager = Manager()
20+
_cache = _manager.dict()
21+
_lock = _manager.Lock()
22+
1723
def __init__(self, enable_caching=True, logger_name=None):
1824
if logger_name is None:
1925
logger_name = __name__
2026

2127
self._logger = _get_logger(logger_name=logger_name)
22-
self._enable_caching = enable_caching
23-
self._cache = {}
24-
self._lock = threading.Lock()
28+
self.enable_caching = enable_caching
2529

2630
def get_geophires_result(self, input_params: GeophiresInputParameters) -> GeophiresXResult:
27-
"""
28-
Calculates a GEOPHIRES result in a thread-safe manner.
29-
30-
This method ensures thread safety by using a lock to serialize access,
31-
preventing race conditions on the cache and corruption of global state
32-
(sys.argv, os.cwd) that GEOPHIRES modifies.
33-
"""
34-
with self._lock:
35-
cache_key = hash(input_params)
36-
if self._enable_caching and cache_key in self._cache:
37-
return self._cache[cache_key]
38-
39-
stash_cwd = Path.cwd()
40-
stash_sys_argv = sys.argv
41-
42-
sys.argv = ['', input_params.as_file_path(), input_params.get_output_file_path()]
43-
try:
44-
geophires.main(enable_geophires_logging_config=False)
45-
except Exception as e:
46-
raise RuntimeError(f'GEOPHIRES encountered an exception: {e!s}') from e
47-
except SystemExit:
48-
raise RuntimeError('GEOPHIRES exited without giving a reason') from None
49-
finally:
50-
# Ensure global state is restored even if geophires.main() fails
51-
sys.argv = stash_sys_argv
52-
os.chdir(stash_cwd)
53-
54-
self._logger.info(f'GEOPHIRES-X output file: {input_params.get_output_file_path()}')
55-
56-
result = GeophiresXResult(input_params.get_output_file_path())
57-
if self._enable_caching:
58-
self._cache[cache_key] = result
59-
31+
if not self.enable_caching:
32+
return self._run_geophires(input_params)
33+
34+
cache_key = hash(input_params)
35+
with self.__class__._lock:
36+
# Use a lock to ensure the check-and-set operation is atomic.
37+
if cache_key in self.__class__._cache:
38+
# Cache hit
39+
return self.__class__._cache[cache_key]
40+
41+
# If not in cache, we will run the simulation, still under the lock,
42+
# to prevent a race condition with other processes.
43+
# Cache miss
44+
result = self._run_geophires(input_params)
45+
self.__class__._cache[cache_key] = result
6046
return result
6147

48+
def _run_geophires(self, input_params: GeophiresInputParameters) -> GeophiresXResult:
49+
"""Helper method to encapsulate the actual GEOPHIRES execution."""
50+
stash_cwd = Path.cwd()
51+
stash_sys_argv = sys.argv
52+
53+
sys.argv = ['', input_params.as_file_path(), input_params.get_output_file_path()]
54+
try:
55+
geophires.main(enable_geophires_logging_config=False)
56+
except Exception as e:
57+
raise RuntimeError(f'GEOPHIRES encountered an exception: {e!s}') from e
58+
except SystemExit:
59+
raise RuntimeError('GEOPHIRES exited without giving a reason') from None
60+
finally:
61+
# Ensure we always restore the original state
62+
sys.argv = stash_sys_argv
63+
os.chdir(stash_cwd)
64+
65+
self._logger.info(f'GEOPHIRES-X output file: {input_params.get_output_file_path()}')
66+
return GeophiresXResult(input_params.get_output_file_path())
67+
6268

6369
if __name__ == '__main__':
6470
client = GeophiresXClient()

tests/geophires_x_client_tests/test_geophires_client_caching.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,12 @@ def test_caching_with_identical_immutable_params(self, mock_geophires_main: unit
5656
# The core assertion: was the expensive simulation function only called once?
5757
mock_geophires_main.assert_called_once()
5858

59-
# The results should not only be equivalent but should be the *same object*
60-
# retrieved from the cache on the second call.
61-
self.assertIs(result1, result2, 'The second result should be the cached object instance.')
59+
self.assertDictEqual(result1.result, result2.result)
60+
61+
# TODO The results should probably not only be equivalent but also the *same object*...
62+
# For now they not, but we probably don't care about this since the important part is performance/cache hit -
63+
# manually verified the cache hit in debugger during development.
64+
# self.assertIs(result1, result2, 'The second result should be the cached object instance.')
6265

6366
@patch('geophires_x_client.geophires.main')
6467
def test_no_caching_with_different_immutable_params(self, mock_geophires_main: unittest.mock.MagicMock):

0 commit comments

Comments
 (0)