99import sys
1010import threading
1111import traceback
12+ import typing
1213from concurrent .futures import Future , ProcessPoolExecutor
1314from concurrent .futures .process import BrokenProcessPool
14- from typing import Any , BinaryIO , Callable , Dict , Tuple , TypeVar
15+ from enum import Enum
16+ from typing import Any , BinaryIO , Callable , Dict , Optional , Tuple , TypeVar
1517from typing_extensions import Never , ParamSpec
1618
1719# _thread_safe_fork is needed because the subprocesses in the pool can read
@@ -88,14 +90,39 @@ def __init__(self, details: str) -> None:
8890 super ().__init__ (f"An exception occurred in a subprocess:\n \n { details } " )
8991
9092
class SubprocPickler:
    """
    Serialization strategy for the data exchanged with the subprocess.

    The default implementation is a thin wrapper around the standard
    ``pickle`` module; a caller can supply a subclass that overrides
    ``dumps``/``loads`` to customize how jobs and results are encoded.

    NOTE(review): SubprocPool forwards only this object's module and class
    name to the subprocess on its command line, so a custom subclass
    presumably has to be importable there and constructible without
    arguments -- confirm against the subprocess entry point.
    """

    def dumps(self, obj: object) -> bytes:
        """Serialize ``obj`` to bytes using pickle's highest protocol."""
        return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)

    def loads(self, data: bytes) -> object:
        """Rebuild and return the object encoded in ``data``."""
        payload = pickle.loads(data)
        return payload
104+
105+
class SubprocKind(Enum):
    """
    Selects how the pool's worker processes are started.

    Each member's value is the start-method name: it is forwarded to the
    subprocess through the ``--kind=`` command-line argument and passed to
    ``multiprocessing.get_context`` when the worker pool is created.
    """

    # Workers are created with the "fork" start method.
    FORK = "fork"
    # Workers are created with the "spawn" start method.
    SPAWN = "spawn"
109+
110+
91111class SubprocPool :
92112 """
93113 Mimic a concurrent.futures.ProcessPoolExecutor, but wrap it in
94114 a subprocess.Popen() to try to avoid issues with forking/spawning
95115 """
96116
97- def __init__ (self , nprocs : int ) -> None :
117+ def __init__ (
118+ self ,
119+ nprocs : int ,
120+ pickler : Optional [SubprocPickler ] = None ,
121+ kind : SubprocKind = SubprocKind .FORK ,
122+ ) -> None :
98123 entry = os .path .join (os .path .dirname (__file__ ), "__main__.py" )
124+ self .pickler = pickler or SubprocPickler ()
125+ self .kind = kind
99126
100127 subproc_read_fd , write_fd = os .pipe ()
101128 read_fd , subproc_write_fd = os .pipe ()
@@ -105,6 +132,8 @@ def __init__(self, nprocs: int) -> None:
105132 cmd = [
106133 sys .executable ,
107134 entry ,
135+ f"--pickler={ self .pickler .__class__ .__module__ } .{ self .pickler .__class__ .__name__ } " ,
136+ f"--kind={ self .kind .value } " ,
108137 f"--workers={ nprocs } " ,
109138 f"--parent={ os .getpid ()} " ,
110139 f"--read-fd={ str (subproc_read_fd )} " ,
@@ -143,7 +172,7 @@ def submit(
143172 ) -> Future [_T ]:
144173 if args or kwargs :
145174 job_fn = functools .partial (job_fn , * args , ** kwargs )
146- job_data = pickle . dumps (job_fn , pickle . HIGHEST_PROTOCOL )
175+ job_data = self . pickler . dumps (job_fn )
147176 future : Future [_T ]
148177 with self .futures_lock :
149178 job_id = next (self .job_id_count )
@@ -156,31 +185,48 @@ def submit(
156185 return future
157186
158187 def _read_thread (self ) -> None :
159- try :
160- while True :
188+ while True :
189+ data = b""
190+ try :
161191 job_id , data = _recv_msg (self .read_pipe )
162- if job_id < 0 :
163- if self .running :
164- log .warning ("SubprocPool unclean exit" )
165- self .read_pipe .close ()
192+ except Exception as e :
193+ # Something went wrong during the read. There's no way we have a
194+ # valid job_id.
195+ log .exception ("failure in subproc_pool._recv_msg" )
196+ job_id = - 1
197+
198+ if job_id < 0 :
199+ # read_pipe returned None or got exception
200+ if self .running :
201+ log .warning ("SubprocPool unclean exit" )
202+ self .running = False
203+ self .read_pipe .close ()
204+ # Cancel all the pending futures.
205+ self .shutdown ()
206+ return
207+
208+ try :
209+ result = self .pickler .loads (data )
210+ except Exception as e :
211+ # Something went wrong unpickling. We have a job_id so just
212+ # notify that particular future and continue on.
213+ log .exception ("unpickle failure in SubprocPool._read_thread" )
214+ result = e
215+
216+ with self .futures_lock :
217+ if not self .running :
166218 return
167- result = pickle .loads (data )
168- with self .futures_lock :
169- if not self .running :
170- return
171- if isinstance (result , _SubprocExceptionInfo ):
172- # An exception occurred in the submitted job
173- self .pending_futures [job_id ].set_exception (
174- SubprocException (result .details )
175- )
176- elif isinstance (result , Exception ):
177- # An exception occurred in some of our subprocess machinery.
178- self .pending_futures [job_id ].set_exception (result )
179- else :
180- self .pending_futures [job_id ].set_result (result )
181- del self .pending_futures [job_id ]
182- except Exception :
183- log .exception ("failure in SubprocPool._read_thread" )
219+ if isinstance (result , _SubprocExceptionInfo ):
220+ # An exception occurred in the submitted job
221+ self .pending_futures [job_id ].set_exception (
222+ SubprocException (result .details )
223+ )
224+ elif isinstance (result , Exception ):
225+ # An exception occurred in some of our subprocess machinery.
226+ self .pending_futures [job_id ].set_exception (result )
227+ else :
228+ self .pending_futures [job_id ].set_result (result )
229+ del self .pending_futures [job_id ]
184230
185231 def shutdown (self ) -> None :
186232 try :
@@ -204,7 +250,16 @@ def shutdown(self) -> None:
204250class SubprocMain :
205251 """Communicates with a SubprocPool in the parent process, called by __main__.py"""
206252
207- def __init__ (self , nprocs : int , read_pipe : BinaryIO , write_pipe : BinaryIO ) -> None :
253+ def __init__ (
254+ self ,
255+ pickler : SubprocPickler ,
256+ kind : SubprocKind ,
257+ nprocs : int ,
258+ read_pipe : BinaryIO ,
259+ write_pipe : BinaryIO ,
260+ ) -> None :
261+ self .pickler = pickler
262+ self .kind = kind
208263 self .read_pipe = read_pipe
209264 self .write_pipe = write_pipe
210265 self .write_lock = threading .Lock ()
@@ -215,7 +270,7 @@ def __init__(self, nprocs: int, read_pipe: BinaryIO, write_pipe: BinaryIO) -> No
215270 def _new_pool (self , nprocs : int , warm : bool ) -> ProcessPoolExecutor :
216271 pool = ProcessPoolExecutor (
217272 nprocs ,
218- mp_context = multiprocessing .get_context ("fork" ),
273+ mp_context = multiprocessing .get_context (self . kind . value ),
219274 initializer = functools .partial (_async_compile_initializer , os .getpid ()),
220275 )
221276 multiprocessing .util .Finalize (None , pool .shutdown , exitpriority = sys .maxsize )
@@ -253,7 +308,9 @@ def submit(self, job_id: int, data: bytes) -> None:
253308 self .pool = self ._new_pool (self .nprocs , False )
254309
255310 def _submit_inner (self , job_id : int , data : bytes ) -> None :
256- future = self .pool .submit (functools .partial (SubprocMain .do_job , data ))
311+ future = self .pool .submit (
312+ functools .partial (SubprocMain .do_job , self .pickler , data )
313+ )
257314
258315 def callback (_ : Future [Any ]) -> None :
259316 if not self .running :
@@ -262,7 +319,7 @@ def callback(_: Future[Any]) -> None:
262319 result = future .result ()
263320 except Exception as e :
264321 log .exception ("Error in subprocess" )
265- result = pickle . dumps (e , pickle . HIGHEST_PROTOCOL )
322+ result = self . pickler . dumps (e )
266323 assert isinstance (result , bytes )
267324 with self .write_lock :
268325 if self .running :
@@ -272,14 +329,15 @@ def callback(_: Future[Any]) -> None:
272329 future .add_done_callback (callback )
273330
274331 @staticmethod
275- def do_job (data : bytes ) -> bytes :
332+ def do_job (pickler : SubprocPickler , data : bytes ) -> bytes :
276333 # do the pickle/unpickle in the sub-subproc
277- job = pickle .loads (data )
334+ job = typing .cast (Callable [[], object ], pickler .loads (data ))
335+
278336 try :
279337 result = job ()
280338 except Exception :
281339 result = _SubprocExceptionInfo (traceback .format_exc ())
282- return pickle .dumps (result , pickle . HIGHEST_PROTOCOL )
340+ return pickler .dumps (result )
283341
284342
285343def _warm_process_pool (pool : ProcessPoolExecutor , n : int ) -> None :
0 commit comments