11from collections .abc import Iterator
22from dataclasses import dataclass , field
33from itertools import starmap
4- from queue import Queue , ShutDown
4+ from queue import Queue
55from threading import Condition , Thread
66from types import TracebackType
77from typing import TYPE_CHECKING , Generic
@@ -199,11 +199,20 @@ def _run(self) -> None:
199199 self ._cv .notify_all ()
200200
201201
202+ class _Shutdown :
203+ """A sentinel to signal worker threads to exit.
204+
205+ In python 3.13+, we can use queue.ShutDown directly.
206+ """
207+
208+ pass
209+
210+
202211def _work (task_queue : Queue [ThreadedTask ]) -> None : # type: ignore[type-arg]
203212 while True :
204- try :
205- task = task_queue . get ()
206- except ShutDown :
213+ task = task_queue . get ()
214+ if isinstance ( task , _Shutdown ):
215+ task_queue . task_done ()
207216 break
208217 try :
209218 task ._run ()
@@ -215,7 +224,7 @@ def _work(task_queue: Queue[ThreadedTask]) -> None: # type: ignore[type-arg]
215224
216225
217226class SingleThreadedBackend :
218- _task_queue : Queue [ThreadedTask ] | None # type: ignore[type-arg]
227+ _task_queue : Queue [ThreadedTask | _Shutdown ] | None # type: ignore[type-arg]
219228 _thread : Thread | None
220229
221230 def __init__ (self ) -> None :
@@ -230,7 +239,7 @@ def __enter__(self) -> "RunningSingleThreadedBackend":
230239 args = (self ._task_queue ,),
231240 )
232241 self ._thread .start ()
233- return RunningSingleThreadedBackend (self . _task_queue )
242+ return RunningSingleThreadedBackend (self )
234243
235244 def __exit__ (
236245 self ,
@@ -239,32 +248,32 @@ def __exit__(
239248 traceback : TracebackType | None ,
240249 ) -> None :
241250 assert self ._thread and self ._task_queue
242- self ._task_queue .shutdown ( )
251+ self ._task_queue .put ( _Shutdown () )
243252 self ._thread .join ()
244253 self ._task_queue = None
245254 self ._thread = None
246255
247256
248257class RunningSingleThreadedBackend :
249- queue : Queue [ ThreadedTask ] # type: ignore[type-arg]
258+ _backend : SingleThreadedBackend
250259
251- def __init__ (self , queue : Queue [ ThreadedTask ] ): # type: ignore[type-arg]
252- self .queue = queue
260+ def __init__ (self , backend : SingleThreadedBackend ): # type: ignore[type-arg]
261+ self ._backend = backend
253262
254263 def compute (
255264 self ,
256265 item : Computable [InputT , ResultT ],
257266 / ,
258267 error_policy : ErrorPolicy = ErrorPolicy (),
259268 ) -> ThreadedTask [InputT , ResultT ]:
260- if self .queue . is_shutdown :
269+ if self ._backend . _task_queue is None :
261270 raise RuntimeError (
262271 "Cannot compute on a backend that has been exited from its context manager"
263272 )
264273 if hasattr (item , "__next__" ):
265274 raise TypeError ("Computable items must be iterables, not iterators" )
266275 task = ThreadedTask (item , error_policy )
267- self .queue .put (task )
276+ self ._backend . _task_queue .put (task )
268277 return task
269278
270279 def wait_all (self , progress : bool = False ) -> None :
@@ -278,7 +287,8 @@ def wait_all(self, progress: bool = False) -> None:
278287 if progress :
279288 raise NotImplementedError ("Progress bars are not yet implemented" )
280289 else :
281- self .queue .join ()
290+ if self ._backend ._task_queue :
291+ self ._backend ._task_queue .join ()
282292
283293
284294if TYPE_CHECKING :
0 commit comments