@@ -25,7 +25,7 @@ def get_existent_kwargs_subset(whitelist, kwargs):
2525 return new_kwargs
2626
2727
28- PROCESS_POOL_KWARGS_WHITELIST = ["max_workers" , "mp_context" , " initializer" , "initargs" ]
28+ PROCESS_POOL_KWARGS_WHITELIST = ["max_workers" , "initializer" , "initargs" ]
2929
3030
3131class WrappedProcessPoolExecutor (ProcessPoolExecutor ):
@@ -37,31 +37,28 @@ class WrappedProcessPoolExecutor(ProcessPoolExecutor):
3737 """
3838
3939 def __init__ (self , ** kwargs ):
40- new_kwargs = get_existent_kwargs_subset (PROCESS_POOL_KWARGS_WHITELIST , kwargs )
40+ assert (not "start_method" in kwargs or kwargs ["start_method" ] is None ) or (
41+ not "mp_context" in kwargs
42+ ), "Cannot use both `start_method` and `mp_context` kwargs."
4143
42- self .did_overwrite_start_method = False
43- if kwargs .get ("start_method" , None ) is not None :
44- self .did_overwrite_start_method = True
45- self .old_start_method = multiprocessing .get_start_method ()
46- start_method = kwargs ["start_method" ]
47- logging .info (
48- f"Overwriting start_method to { start_method } . Previous value: { self .old_start_method } "
49- )
50- multiprocessing .set_start_method (start_method , force = True )
44+ new_kwargs = get_existent_kwargs_subset (PROCESS_POOL_KWARGS_WHITELIST , kwargs )
5145
52- ProcessPoolExecutor . __init__ ( self , ** new_kwargs )
46+ mp_context = None
5347
54- def shutdown (self , * args , ** kwargs ):
48+ if "mp_context" in kwargs :
49+ mp_context = kwargs ["mp_context" ]
50+ elif "start_method" in kwargs and kwargs ["start_method" ] is not None :
51+ mp_context = multiprocessing .get_context (kwargs ["start_method" ])
52+ elif "MULTIPROCESSING_DEFAULT_START_METHOD" in os .environ :
53+ mp_context = multiprocessing .get_context (
54+ os .environ ["MULTIPROCESSING_DEFAULT_START_METHOD" ]
55+ )
56+ else :
57+ mp_context = multiprocessing .get_context ("spawn" )
5558
56- super (). shutdown ( * args , ** kwargs )
59+ new_kwargs [ "mp_context" ] = mp_context
5760
58- if self .did_overwrite_start_method :
59- logging .info (
60- f"Restoring start_method to original value: { self .old_start_method } ."
61- )
62- multiprocessing .set_start_method (self .old_start_method , force = True )
63- self .old_start_method = None
64- self .did_overwrite_start_method = False
61+ ProcessPoolExecutor .__init__ (self , ** new_kwargs )
6562
6663 def submit (self , * args , ** kwargs ):
6764
@@ -88,7 +85,7 @@ def submit(self, *args, **kwargs):
8885 # where wrapper_fn_1 is called, which eventually calls wrapper_fn_2, which eventually calls actual_fn.
8986 call_stack = []
9087
91- if multiprocessing .get_start_method () != "fork" :
88+ if self . _mp_context .get_start_method () != "fork" :
9289 # If a start_method other than the default "fork" is used, logging needs to be re-setup,
9390 # because the programming context is not inherited in those cases.
9491 multiprocessing_logging_setup_fn = get_multiprocessing_logging_setup_fn ()
0 commit comments