Skip to content

Commit 7b2b78d

Browse files
Make GeophiresXClient cache thread-safe
1 parent 863bbb6 commit 7b2b78d

File tree

1 file changed

+40
-29
lines changed

1 file changed

+40
-29
lines changed

src/geophires_x_client/__init__.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import json
22
import os
33
import sys
4+
import threading
45
from pathlib import Path
56

7+
# noinspection PyPep8Naming
68
from geophires_x import GEOPHIRESv3 as geophires
79

810
from .common import _get_logger
@@ -19,34 +21,43 @@ def __init__(self, enable_caching=True, logger_name=None):
1921
self._logger = _get_logger(logger_name=logger_name)
2022
self._enable_caching = enable_caching
2123
self._cache = {}
24+
self._lock = threading.Lock()
2225

2326
def get_geophires_result(self, input_params: GeophiresInputParameters) -> GeophiresXResult:
24-
cache_key = hash(input_params)
25-
if self._enable_caching and cache_key in self._cache:
26-
return self._cache[cache_key]
27-
28-
stash_cwd = Path.cwd()
29-
stash_sys_argv = sys.argv
30-
31-
sys.argv = ['', input_params.as_file_path(), input_params.get_output_file_path()]
32-
try:
33-
geophires.main(enable_geophires_logging_config=False)
34-
except Exception as e:
35-
raise RuntimeError(f'GEOPHIRES encountered an exception: {e!s}') from e
36-
except SystemExit:
37-
raise RuntimeError('GEOPHIRES exited without giving a reason') from None
38-
39-
# Undo Geophires internal global settings changes
40-
sys.argv = stash_sys_argv
41-
os.chdir(stash_cwd)
42-
43-
self._logger.info(f'GEOPHIRES-X output file: {input_params.get_output_file_path()}')
44-
45-
result = GeophiresXResult(input_params.get_output_file_path())
46-
if self._enable_caching:
47-
self._cache[cache_key] = result
48-
49-
return result
27+
"""
28+
Calculates a GEOPHIRES result in a thread-safe manner.
29+
30+
This method ensures thread safety by using a lock to serialize access,
31+
preventing race conditions on the cache and corruption of global state
32+
(sys.argv, os.cwd) that GEOPHIRES modifies.
33+
"""
34+
with self._lock:
35+
cache_key = hash(input_params)
36+
if self._enable_caching and cache_key in self._cache:
37+
return self._cache[cache_key]
38+
39+
stash_cwd = Path.cwd()
40+
stash_sys_argv = sys.argv
41+
42+
sys.argv = ['', input_params.as_file_path(), input_params.get_output_file_path()]
43+
try:
44+
geophires.main(enable_geophires_logging_config=False)
45+
except Exception as e:
46+
raise RuntimeError(f'GEOPHIRES encountered an exception: {e!s}') from e
47+
except SystemExit:
48+
raise RuntimeError('GEOPHIRES exited without giving a reason') from None
49+
finally:
50+
# Ensure global state is restored even if geophires.main() fails
51+
sys.argv = stash_sys_argv
52+
os.chdir(stash_cwd)
53+
54+
self._logger.info(f'GEOPHIRES-X output file: {input_params.get_output_file_path()}')
55+
56+
result = GeophiresXResult(input_params.get_output_file_path())
57+
if self._enable_caching:
58+
self._cache[cache_key] = result
59+
60+
return result
5061

5162

5263
if __name__ == '__main__':
@@ -66,6 +77,6 @@ def get_geophires_result(self, input_params: GeophiresInputParameters) -> Geophi
6677
}
6778
)
6879

69-
result = client.get_geophires_result(params)
70-
log.info(f'Breakeven price: ${result.direct_use_heat_breakeven_price_USD_per_MMBTU}/MMBTU')
71-
log.info(json.dumps(result.result, indent=2))
80+
result_ = client.get_geophires_result(params)
81+
log.info(f'Breakeven price: ${result_.direct_use_heat_breakeven_price_USD_per_MMBTU}/MMBTU')
82+
log.info(json.dumps(result_.result, indent=2))

0 commit comments

Comments
 (0)