Skip to content

Commit fc1d6df

Browse files
committed
Pass MP context to InterProcessMessaging
1 parent 2fbf052 commit fc1d6df

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

src/guidellm/scheduler/worker_group.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ async def create_processes(self):
193193

194194
if settings.mp_messaging_object == "queue":
195195
self.messaging = InterProcessMessagingQueue(
196+
mp_context=self.mp_context,
196197
serialization=settings.mp_serialization,
197198
encoding=settings.mp_encoding,
198199
max_pending_size=max_pending_size,
@@ -202,6 +203,7 @@ async def create_processes(self):
202203
elif settings.mp_messaging_object == "manager_queue":
203204
self.messaging = InterProcessMessagingManagerQueue(
204205
manager=self.mp_manager,
206+
mp_context=self.mp_context,
205207
serialization=settings.mp_serialization,
206208
encoding=settings.mp_encoding,
207209
max_pending_size=max_pending_size,
@@ -211,6 +213,7 @@ async def create_processes(self):
211213
elif settings.mp_messaging_object == "pipe":
212214
self.messaging = InterProcessMessagingPipe(
213215
num_workers=num_processes,
216+
mp_context=self.mp_context,
214217
serialization=settings.mp_serialization,
215218
encoding=settings.mp_encoding,
216219
max_pending_size=max_pending_size,

src/guidellm/utils/messaging.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from abc import ABC, abstractmethod
1919
from collections.abc import Iterable
2020
from multiprocessing.connection import Connection
21-
from multiprocessing.connection import Pipe as ProcessingPipe
2221
from multiprocessing.context import BaseContext
22+
from multiprocessing.managers import SyncManager
2323
from multiprocessing.synchronize import Event as ProcessingEvent
2424
from threading import Event as ThreadingEvent
2525
from typing import Any, Callable, Generic, Protocol, TypeVar
@@ -94,6 +94,7 @@ class InterProcessMessaging(Generic[SendMessageT, ReceiveMessageT], ABC):
9494

9595
def __init__(
9696
self,
97+
mp_context: BaseContext | None = None,
9798
serialization: SerializationTypesAlias = "dict",
9899
encoding: EncodingTypesAlias | list[EncodingTypesAlias] = None,
99100
max_pending_size: int | None = None,
@@ -116,6 +117,7 @@ def __init__(
116117
:param worker_index: Index identifying this worker in the process group
117118
"""
118119
self.worker_index: int | None = worker_index
120+
self.mp_context = mp_context or multiprocessing.get_context()
119121
self.serialization = serialization
120122
self.encoding = encoding
121123
self.max_pending_size = max_pending_size
@@ -433,6 +435,7 @@ class InterProcessMessagingQueue(InterProcessMessaging[SendMessageT, ReceiveMess
433435

434436
def __init__(
435437
self,
438+
mp_context: BaseContext | None = None,
436439
serialization: SerializationTypesAlias = "dict",
437440
encoding: EncodingTypesAlias = None,
438441
max_pending_size: int | None = None,
@@ -457,8 +460,10 @@ def __init__(
457460
:param worker_index: Index identifying this worker in the process group
458461
:param pending_queue: Multiprocessing queue for sending messages
459462
:param done_queue: Multiprocessing queue for receiving completed messages
463+
:param context: Multiprocessing context for creating queues
460464
"""
461465
super().__init__(
466+
mp_context=mp_context,
462467
serialization=serialization,
463468
encoding=encoding,
464469
max_pending_size=max_pending_size,
@@ -468,10 +473,10 @@ def __init__(
468473
poll_interval=poll_interval,
469474
worker_index=worker_index,
470475
)
471-
self.pending_queue = pending_queue or multiprocessing.Queue(
476+
self.pending_queue = pending_queue or self.mp_context.Queue(
472477
maxsize=max_pending_size or 0
473478
)
474-
self.done_queue = done_queue or multiprocessing.Queue(
479+
self.done_queue = done_queue or self.mp_context.Queue(
475480
maxsize=max_done_size or 0
476481
)
477482

@@ -485,6 +490,7 @@ def create_worker_copy(
485490
:return: Configured queue messaging instance for the specified worker
486491
"""
487492
copy_args = {
493+
"mp_context": self.mp_context,
488494
"serialization": self.serialization,
489495
"encoding": self.encoding,
490496
"max_pending_size": self.max_pending_size,
@@ -657,7 +663,8 @@ class InterProcessMessagingManagerQueue(
657663

658664
def __init__(
659665
self,
660-
manager: BaseContext,
666+
manager: SyncManager,
667+
mp_context: BaseContext | None = None,
661668
serialization: SerializationTypesAlias = "dict",
662669
encoding: EncodingTypesAlias = None,
663670
max_pending_size: int | None = None,
@@ -686,6 +693,7 @@ def __init__(
686693
messages
687694
"""
688695
super().__init__(
696+
mp_context=mp_context,
689697
serialization=serialization,
690698
encoding=encoding,
691699
max_pending_size=max_pending_size,
@@ -694,8 +702,8 @@ def __init__(
694702
max_buffer_receive_size=max_buffer_receive_size,
695703
poll_interval=poll_interval,
696704
worker_index=worker_index,
697-
pending_queue=pending_queue or manager.Queue(maxsize=max_pending_size or 0),
698-
done_queue=done_queue or manager.Queue(maxsize=max_done_size or 0),
705+
pending_queue=pending_queue or manager.Queue(maxsize=max_pending_size or 0), # type: ignore [assignment]
706+
done_queue=done_queue or manager.Queue(maxsize=max_done_size or 0), # type: ignore [assignment]
699707
)
700708

701709
def create_worker_copy(
@@ -709,6 +717,7 @@ def create_worker_copy(
709717
"""
710718
copy_args = {
711719
"manager": None,
720+
"mp_context": self.mp_context,
712721
"serialization": self.serialization,
713722
"encoding": self.encoding,
714723
"max_pending_size": self.max_pending_size,
@@ -759,6 +768,7 @@ class InterProcessMessagingPipe(InterProcessMessaging[SendMessageT, ReceiveMessa
759768
def __init__(
760769
self,
761770
num_workers: int,
771+
mp_context: BaseContext | None = None,
762772
serialization: SerializationTypesAlias = "dict",
763773
encoding: EncodingTypesAlias = None,
764774
max_pending_size: int | None = None,
@@ -784,6 +794,7 @@ def __init__(
784794
:param pipe: Existing pipe connection for worker-specific instances
785795
"""
786796
super().__init__(
797+
mp_context=mp_context,
787798
serialization=serialization,
788799
encoding=encoding,
789800
max_pending_size=max_pending_size,
@@ -797,7 +808,7 @@ def __init__(
797808

798809
if pipe is None:
799810
self.pipes: list[tuple[Connection, Connection]] = [
800-
ProcessingPipe(duplex=True) for _ in range(num_workers)
811+
self.mp_context.Pipe(duplex=True) for _ in range(num_workers)
801812
]
802813
else:
803814
self.pipes: list[tuple[Connection, Connection]] = [pipe]
@@ -813,6 +824,7 @@ def create_worker_copy(
813824
"""
814825
copy_args = {
815826
"num_workers": self.num_workers,
827+
"mp_context": self.mp_context,
816828
"serialization": self.serialization,
817829
"encoding": self.encoding,
818830
"max_pending_size": self.max_pending_size,

tests/unit/utils/test_messaging.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,10 @@ class TestInterProcessMessagingQueue:
187187
def valid_instances(self, multiprocessing_contexts, request):
188188
"""Fixture providing test data for InterProcessMessagingQueue."""
189189
constructor_args = request.param
190-
instance = InterProcessMessagingQueue(**constructor_args, poll_interval=0.01)
191190
manager, context = multiprocessing_contexts
191+
instance = InterProcessMessagingQueue(
192+
**constructor_args, poll_interval=0.01, mp_context=context
193+
)
192194

193195
return instance, constructor_args, manager, context
194196

0 commit comments

Comments
 (0)