18
18
from abc import ABC , abstractmethod
19
19
from collections .abc import Iterable
20
20
from multiprocessing .connection import Connection
21
- from multiprocessing .connection import Pipe as ProcessingPipe
22
21
from multiprocessing .context import BaseContext
22
+ from multiprocessing .managers import SyncManager
23
23
from multiprocessing .synchronize import Event as ProcessingEvent
24
24
from threading import Event as ThreadingEvent
25
25
from typing import Any , Callable , Generic , Protocol , TypeVar
@@ -94,6 +94,7 @@ class InterProcessMessaging(Generic[SendMessageT, ReceiveMessageT], ABC):
94
94
95
95
def __init__ (
96
96
self ,
97
+ mp_context : BaseContext | None = None ,
97
98
serialization : SerializationTypesAlias = "dict" ,
98
99
encoding : EncodingTypesAlias | list [EncodingTypesAlias ] = None ,
99
100
max_pending_size : int | None = None ,
@@ -116,6 +117,7 @@ def __init__(
116
117
:param worker_index: Index identifying this worker in the process group
117
118
"""
118
119
self .worker_index : int | None = worker_index
120
+ self .mp_context = mp_context or multiprocessing .get_context ()
119
121
self .serialization = serialization
120
122
self .encoding = encoding
121
123
self .max_pending_size = max_pending_size
@@ -433,6 +435,7 @@ class InterProcessMessagingQueue(InterProcessMessaging[SendMessageT, ReceiveMess
433
435
434
436
def __init__ (
435
437
self ,
438
+ mp_context : BaseContext | None = None ,
436
439
serialization : SerializationTypesAlias = "dict" ,
437
440
encoding : EncodingTypesAlias = None ,
438
441
max_pending_size : int | None = None ,
@@ -457,8 +460,10 @@ def __init__(
457
460
:param worker_index: Index identifying this worker in the process group
458
461
:param pending_queue: Multiprocessing queue for sending messages
459
462
:param done_queue: Multiprocessing queue for receiving completed messages
463
+ :param context: Multiprocessing context for creating queues
460
464
"""
461
465
super ().__init__ (
466
+ mp_context = mp_context ,
462
467
serialization = serialization ,
463
468
encoding = encoding ,
464
469
max_pending_size = max_pending_size ,
@@ -468,10 +473,10 @@ def __init__(
468
473
poll_interval = poll_interval ,
469
474
worker_index = worker_index ,
470
475
)
471
- self .pending_queue = pending_queue or multiprocessing .Queue (
476
+ self .pending_queue = pending_queue or self . mp_context .Queue (
472
477
maxsize = max_pending_size or 0
473
478
)
474
- self .done_queue = done_queue or multiprocessing .Queue (
479
+ self .done_queue = done_queue or self . mp_context .Queue (
475
480
maxsize = max_done_size or 0
476
481
)
477
482
@@ -485,6 +490,7 @@ def create_worker_copy(
485
490
:return: Configured queue messaging instance for the specified worker
486
491
"""
487
492
copy_args = {
493
+ "mp_context" : self .mp_context ,
488
494
"serialization" : self .serialization ,
489
495
"encoding" : self .encoding ,
490
496
"max_pending_size" : self .max_pending_size ,
@@ -657,7 +663,8 @@ class InterProcessMessagingManagerQueue(
657
663
658
664
def __init__ (
659
665
self ,
660
- manager : BaseContext ,
666
+ manager : SyncManager ,
667
+ mp_context : BaseContext | None = None ,
661
668
serialization : SerializationTypesAlias = "dict" ,
662
669
encoding : EncodingTypesAlias = None ,
663
670
max_pending_size : int | None = None ,
@@ -686,6 +693,7 @@ def __init__(
686
693
messages
687
694
"""
688
695
super ().__init__ (
696
+ mp_context = mp_context ,
689
697
serialization = serialization ,
690
698
encoding = encoding ,
691
699
max_pending_size = max_pending_size ,
@@ -694,8 +702,8 @@ def __init__(
694
702
max_buffer_receive_size = max_buffer_receive_size ,
695
703
poll_interval = poll_interval ,
696
704
worker_index = worker_index ,
697
- pending_queue = pending_queue or manager .Queue (maxsize = max_pending_size or 0 ),
698
- done_queue = done_queue or manager .Queue (maxsize = max_done_size or 0 ),
705
+ pending_queue = pending_queue or manager .Queue (maxsize = max_pending_size or 0 ), # type: ignore [assignment]
706
+ done_queue = done_queue or manager .Queue (maxsize = max_done_size or 0 ), # type: ignore [assignment]
699
707
)
700
708
701
709
def create_worker_copy (
@@ -709,6 +717,7 @@ def create_worker_copy(
709
717
"""
710
718
copy_args = {
711
719
"manager" : None ,
720
+ "mp_context" : self .mp_context ,
712
721
"serialization" : self .serialization ,
713
722
"encoding" : self .encoding ,
714
723
"max_pending_size" : self .max_pending_size ,
@@ -759,6 +768,7 @@ class InterProcessMessagingPipe(InterProcessMessaging[SendMessageT, ReceiveMessa
759
768
def __init__ (
760
769
self ,
761
770
num_workers : int ,
771
+ mp_context : BaseContext | None = None ,
762
772
serialization : SerializationTypesAlias = "dict" ,
763
773
encoding : EncodingTypesAlias = None ,
764
774
max_pending_size : int | None = None ,
@@ -784,6 +794,7 @@ def __init__(
784
794
:param pipe: Existing pipe connection for worker-specific instances
785
795
"""
786
796
super ().__init__ (
797
+ mp_context = mp_context ,
787
798
serialization = serialization ,
788
799
encoding = encoding ,
789
800
max_pending_size = max_pending_size ,
@@ -797,7 +808,7 @@ def __init__(
797
808
798
809
if pipe is None :
799
810
self .pipes : list [tuple [Connection , Connection ]] = [
800
- ProcessingPipe (duplex = True ) for _ in range (num_workers )
811
+ self . mp_context . Pipe (duplex = True ) for _ in range (num_workers )
801
812
]
802
813
else :
803
814
self .pipes : list [tuple [Connection , Connection ]] = [pipe ]
@@ -813,6 +824,7 @@ def create_worker_copy(
813
824
"""
814
825
copy_args = {
815
826
"num_workers" : self .num_workers ,
827
+ "mp_context" : self .mp_context ,
816
828
"serialization" : self .serialization ,
817
829
"encoding" : self .encoding ,
818
830
"max_pending_size" : self .max_pending_size ,
0 commit comments