1
1
import atexit
2
- import os
3
2
import sys
4
3
import threading
5
4
from multiprocessing import Manager
6
- from pathlib import Path
5
+ from multiprocessing import current_process
7
6
8
7
# noinspection PyPep8Naming
9
8
from geophires_x import GEOPHIRESv3 as geophires
@@ -36,24 +35,30 @@ def __init__(self, enable_caching=True, logger_name=None):
36
35
self ._logger = _get_logger (logger_name = logger_name )
37
36
self ._enable_caching = enable_caching
38
37
39
- # Lazy-initialize shared resources if they haven't been already.
40
38
if enable_caching and GeophiresXClient ._manager is None :
39
+ # Lazy-initialize shared resources if they haven't been already.
41
40
self ._initialize_shared_resources ()
42
41
43
42
@classmethod
44
43
def _initialize_shared_resources (cls ):
45
44
"""
46
45
Initializes the multiprocessing Manager and shared resources in a
47
- thread-safe manner. It also registers the shutdown hook to ensure
48
- automatic cleanup on application exit.
46
+ thread-safe and now process-safe manner. It also registers the
47
+ shutdown hook to ensure automatic cleanup on application exit.
49
48
"""
50
- with cls ._init_lock :
51
- if cls ._manager is None :
52
- cls ._manager = Manager ()
53
- cls ._cache = cls ._manager .dict ()
54
- cls ._lock = cls ._manager .RLock ()
55
- # Register the shutdown method to be called automatically on exit.
56
- atexit .register (cls .shutdown )
49
+ # Ensure that only the top-level user process can create the manager.
50
+ # A spawned child process, which re-imports this script, will have a different name
51
+ # (e.g., 'Spawn-1') and will skip this entire block, preventing a recursive crash.
52
+ if current_process ().name == 'MainProcess' :
53
+ with cls ._init_lock :
54
+ if cls ._manager is None :
55
+ cls ._logger = _get_logger (__name__ ) # Add a logger for this class method
56
+ cls ._logger .debug ('MainProcess is creating the shared multiprocessing manager...' )
57
+ cls ._manager = Manager ()
58
+ cls ._cache = cls ._manager .dict ()
59
+ cls ._lock = cls ._manager .RLock ()
60
+ # Register the shutdown method to be called automatically on exit.
61
+ atexit .register (cls .shutdown )
57
62
58
63
@classmethod
59
64
def shutdown (cls ):
@@ -65,9 +70,17 @@ def shutdown(cls):
65
70
"""
66
71
with cls ._init_lock :
67
72
if cls ._manager is not None :
73
+ cls ._logger = _get_logger (__name__ )
74
+ cls ._logger .debug ('Shutting down the shared multiprocessing manager...' )
68
75
cls ._manager .shutdown ()
69
76
# De-register the hook to avoid trying to shut down twice.
70
- atexit .unregister (cls .shutdown )
77
+ try :
78
+ atexit .unregister (cls .shutdown )
79
+ except Exception as e :
80
+ # Fails in some environments (e.g. pytest), but is not critical
81
+ cls ._logger .debug (
82
+ f'Encountered exception shutting down the shared multiprocessing manager (OK): ' f'{ e !s} '
83
+ )
71
84
cls ._manager = None
72
85
cls ._cache = None
73
86
cls ._lock = None
@@ -80,22 +93,23 @@ def get_geophires_result(self, input_params: GeophiresInputParameters) -> Geophi
80
93
"""
81
94
is_immutable = isinstance (input_params , ImmutableGeophiresInputParameters )
82
95
83
- if not (self ._enable_caching and is_immutable ):
96
+ if not (self ._enable_caching and is_immutable and GeophiresXClient . _manager is not None ):
84
97
return self ._run_simulation (input_params )
85
98
86
99
cache_key = hash (input_params )
87
100
88
101
with GeophiresXClient ._lock :
89
102
if cache_key in GeophiresXClient ._cache :
103
+ # self._logger.debug(f'Cache hit for inputs: {input_params}')
90
104
return GeophiresXClient ._cache [cache_key ]
91
105
106
+ # Cache miss
92
107
result = self ._run_simulation (input_params )
93
108
GeophiresXClient ._cache [cache_key ] = result
94
109
return result
95
110
96
111
def _run_simulation (self , input_params : GeophiresInputParameters ) -> GeophiresXResult :
97
112
"""Helper method to encapsulate the actual GEOPHIRES run."""
98
- stash_cwd = Path .cwd ()
99
113
stash_sys_argv = sys .argv
100
114
sys .argv = ['' , input_params .as_file_path (), input_params .get_output_file_path ()]
101
115
@@ -107,7 +121,6 @@ def _run_simulation(self, input_params: GeophiresInputParameters) -> GeophiresXR
107
121
raise RuntimeError ('GEOPHIRES exited without giving a reason' ) from None
108
122
finally :
109
123
sys .argv = stash_sys_argv
110
- os .chdir (stash_cwd )
111
124
112
125
self ._logger .info (f'GEOPHIRES-X output file: { input_params .get_output_file_path ()} ' )
113
126
result = GeophiresXResult (input_params .get_output_file_path ())
0 commit comments