Skip to content

Commit 2fbf052

Browse files
committed
Fixes for new refactor runs
1 parent ac07cc6 commit 2fbf052

File tree

4 files changed

+142
-29
lines changed

4 files changed

+142
-29
lines changed

src/guidellm/scheduler/worker.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,12 @@ def _task_done(task):
253253
pending_tasks.add(request_task)
254254
request_task.add_done_callback(_task_done)
255255
except (asyncio.CancelledError, Exception) as err:
256-
await self._cancel_remaining_requests(pending_tasks, all_requests_processed)
257-
await self.messaging.stop()
258-
await self.backend.process_shutdown()
256+
if self.startup_completed:
257+
await self._cancel_remaining_requests(
258+
pending_tasks, all_requests_processed
259+
)
260+
await self.messaging.stop()
261+
await self.backend.process_shutdown()
259262

260263
raise err
261264

src/guidellm/scheduler/worker_group.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __init__(
144144

145145
async def create_processes(self):
146146
"""
147-
Start the processes for the worker process group.
147+
Create and initialize worker processes for distributed request processing.
148148
149149
Sets up multiprocessing infrastructure and worker processes based on
150150
strategy constraints, backend capabilities, and system configuration.
@@ -399,11 +399,6 @@ class _WorkerGroupState(Generic[RequestT, MeasuredRequestTimingsT, ResponseT]):
399399
Handles request generation, state updates, constraint evaluation, and
400400
coordination between worker processes. Provides thread-safe state management
401401
with request lifecycle tracking and constraint-based termination logic.
402-
403-
:param start_time: Unix timestamp when processing should begin
404-
:param num_processes: Number of worker processes in the group
405-
:param constraints: Named constraints for controlling execution behavior
406-
:param shutdown_event: Multiprocessing event for coordinated shutdown
407402
"""
408403

409404
def __init__(
@@ -414,6 +409,15 @@ def __init__(
414409
constraints: dict[str, Constraint],
415410
shutdown_event: Event,
416411
):
412+
"""
413+
Initialize worker group state management.
414+
415+
:param start_time: Unix timestamp when processing should begin
416+
:param num_processes: Number of worker processes in the group
417+
:param processes: List of worker process instances
418+
:param constraints: Named constraints for controlling execution behavior
419+
:param shutdown_event: Multiprocessing event for coordinated shutdown
420+
"""
417421
self._start_time = start_time
418422
self._update_lock: threading.Lock = threading.Lock()
419423
self._state: SchedulerState = SchedulerState(
@@ -527,7 +531,7 @@ def update_callback_receive(
527531
)
528532

529533
def stop_callback_receive(
530-
self, messaging: InterProcessMessaging, pending: bool, is_empty: bool
534+
self, messaging: InterProcessMessaging, pending: bool, queue_empty: int
531535
) -> bool:
532536
"""
533537
Determine if message receiving should stop based on system state.
@@ -537,12 +541,12 @@ def stop_callback_receive(
537541
538542
:param messaging: Inter-process messaging instance
539543
:param pending: Whether operations are still pending
540-
:param is_empty: Whether receive queues are empty
544+
:param queue_empty: The number of times the queue has reported empty in a row
541545
:return: True if message receiving should stop, False otherwise
542546
"""
543547
return (
544548
not pending
545-
and is_empty # all updates pulled off
549+
and queue_empty >= InterProcessMessaging.STOP_REQUIRED_QUEUE_EMPTY
546550
and messaging.send_stopped_event.is_set() # No more requests will be added
547551
and self._shutdown_event.is_set() # processing should stop
548552
and all(

src/guidellm/utils/messaging.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,10 @@ def check_stop(pending: bool, queue_empty: int) -> bool:
400400
return (
401401
not pending
402402
and queue_empty >= self.STOP_REQUIRED_QUEUE_EMPTY
403-
and self.shutdown_event.is_set()
404-
or any(event.is_set() for event in stop_events)
403+
and (
404+
self.shutdown_event.is_set()
405+
or any(event.is_set() for event in stop_events)
406+
)
405407
)
406408

407409
return check_stop

tests/unit/scheduler/test_worker_group.py

Lines changed: 119 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class TestWorkerProcessGroup:
8787
"requests": None,
8888
"cycle_requests": ["request1", "request2", "request3"],
8989
"strategy": SynchronousStrategy(),
90-
"constraints": {"max_requests": MaxNumberConstraint(max_num=10)},
90+
"constraints": {"max_num": MaxNumberConstraint(max_num=10)},
9191
},
9292
{
9393
"requests": None,
@@ -185,33 +185,137 @@ def test_initialization(self, valid_instances):
185185
assert instance._state is None
186186
assert instance.messaging is None
187187

188+
@pytest.mark.sanity
189+
@pytest.mark.parametrize(
190+
("requests", "cycle_requests", "expected_error"),
191+
[
192+
(None, None, ValueError),
193+
([], iter([]), ValueError), # cycle_requests as Iterator
194+
(None, iter(["req1"]), ValueError), # cycle_requests as Iterator
195+
],
196+
ids=["no_requests", "cycle_as_iterator_empty", "cycle_as_iterator_data"],
197+
)
198+
def test_invalid_initialization_values(
199+
self, requests, cycle_requests, expected_error
200+
):
201+
"""Test WorkerProcessGroup with invalid initialization values."""
202+
with pytest.raises(expected_error):
203+
WorkerProcessGroup(
204+
requests=requests,
205+
cycle_requests=cycle_requests,
206+
backend=MockBackend(),
207+
strategy=SynchronousStrategy(),
208+
constraints={},
209+
)
210+
211+
@pytest.mark.sanity
212+
def test_invalid_initialization_missing(self):
213+
"""Test WorkerProcessGroup initialization without required fields."""
214+
with pytest.raises(TypeError):
215+
WorkerProcessGroup()
216+
188217
@pytest.mark.smoke
189-
# @async_timeout(5)
218+
@async_timeout(10)
190219
@pytest.mark.asyncio
191220
async def test_lifecycle(self, valid_instances: tuple[WorkerProcessGroup, dict]):
192221
"""Test the lifecycle methods of WorkerProcessGroup."""
193-
instance, _ = valid_instances
222+
instance, constructor_args = valid_instances
194223

195224
# Test create processes
196225
await instance.create_processes()
197-
# TODO: check valid process creation
226+
227+
# Check valid process creation
228+
assert instance.mp_context is not None
229+
assert instance.mp_manager is not None
230+
assert instance.processes is not None
231+
assert len(instance.processes) > 0
232+
assert all(proc.is_alive() for proc in instance.processes)
233+
assert instance.startup_barrier is not None
234+
assert instance.shutdown_event is not None
235+
assert instance.error_event is not None
236+
assert instance.requests_completed_event is not None
237+
assert instance.messaging is not None
198238

199239
# Test start
200240
start_time = time.time() + 0.1
201241
await instance.start(start_time=start_time)
202-
# TODO: check valid start behavior
242+
243+
# Check valid start behavior
244+
assert instance.messaging is not None
245+
assert instance._state is not None
246+
assert instance._state._start_time == start_time
247+
assert instance._state._state.num_processes == len(instance.processes)
248+
assert not instance.error_event.is_set()
203249

204250
# Test iter updates
205-
updates = {}
206-
async for resp, req, info, state in instance.request_updates():
207-
pass
208-
# TODO: validate correct updates based on requests, cycle_requests, and constraints
251+
updates_list = []
252+
responses_count = 0
253+
254+
async for (
255+
response,
256+
request,
257+
request_info,
258+
scheduler_state,
259+
) in instance.request_updates():
260+
updates_list.append((response, request, request_info, scheduler_state))
261+
if response is not None:
262+
responses_count += 1
263+
264+
# Validate request info structure
265+
assert hasattr(request_info, "request_id")
266+
assert hasattr(request_info, "status")
267+
valid_statuses = [
268+
"queued",
269+
"in_progress",
270+
"completed",
271+
"errored",
272+
"cancelled",
273+
]
274+
assert request_info.status in valid_statuses
275+
276+
# Validate state structure
277+
assert hasattr(scheduler_state, "created_requests")
278+
assert hasattr(scheduler_state, "processed_requests")
279+
assert hasattr(scheduler_state, "successful_requests")
280+
assert scheduler_state.created_requests >= 0
281+
assert scheduler_state.processed_requests >= 0
282+
assert scheduler_state.successful_requests >= 0
283+
284+
# Validate correctness of all updates
285+
if constructor_args.get("requests") is not None:
286+
assert len(updates_list) == 2 * len(constructor_args["requests"]), (
287+
"Should have received updates for all requests"
288+
)
289+
if constructor_args.get("constraints", {}).get("max_num") is not None:
290+
assert (
291+
len(updates_list)
292+
== 2 * constructor_args["constraints"]["max_num"].max_num
293+
), "Should not have received more updates than max_num constraint"
294+
295+
assert len(updates_list) > 0, "Should have received at least one update"
296+
297+
# Constraints should be satisfied
298+
for constraint_name, _ in constructor_args["constraints"].items():
299+
constraint_check = (
300+
"max" in constraint_name.lower()
301+
or "duration" in constraint_name.lower()
302+
)
303+
if constraint_check:
304+
assert scheduler_state.end_processing_time is not None, (
305+
f"Should have stopped processing due to {constraint_name}"
306+
)
209307

210308
# Test shutdown
211-
await instance.shutdown()
212-
print(
213-
f"\nRequests summary: created={state.created_requests}, queued={state.queued_requests}, processing={state.processing_requests}, processed={state.processed_requests}, successful={state.successful_requests}, cancelled={state.cancelled_requests}, errored={state.errored_requests}"
309+
exceptions = await instance.shutdown()
310+
311+
# Check valid shutdown behavior
312+
assert isinstance(exceptions, list), "Shutdown should return list of exceptions"
313+
assert instance.messaging is None, "Messaging should be cleared after shutdown"
314+
assert instance._state is None, "State should be cleared after shutdown"
315+
assert instance.processes is None, "Processes should be cleared after shutdown"
316+
assert instance.mp_manager is None, (
317+
"MP manager should be cleared after shutdown"
318+
)
319+
assert instance.mp_context is None, (
320+
"MP context should be cleared after shutdown"
214321
)
215-
print(resp)
216-
print(info)
217-
# TODO: check valid shutdown behavior

0 commit comments

Comments
 (0)