Skip to content

Commit 0c6f679

Browse files
committed
Implement tests and fixes for scheduler to finish off the scheduler package for now
1 parent 0c0a8f2 commit 0c6f679

File tree

12 files changed

+438
-68
lines changed

12 files changed

+438
-68
lines changed

src/guidellm/scheduler/environment.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@
2929
SchedulerState,
3030
)
3131
from guidellm.scheduler.strategy import SchedulingStrategy
32+
from guidellm.utils import InfoMixin
3233

3334
__all__ = ["Environment", "NonDistributedEnvironment"]
3435

3536

36-
class Environment(ABC, Generic[RequestT, ResponseT]):
37+
class Environment(ABC, Generic[RequestT, ResponseT], InfoMixin):
3738
"""
3839
Abstract base for scheduler execution environments.
3940
@@ -86,6 +87,7 @@ async def update_run_iteration(
8687
response: ResponseT | None,
8788
request: RequestT,
8889
request_info: ScheduledRequestInfo[MeasuredRequestTimingsT],
90+
state: SchedulerState,
8991
):
9092
"""
9193
Update environment state with completed request iteration.
@@ -101,7 +103,7 @@ async def update_run_iteration(
101103
...
102104

103105
@abstractmethod
104-
async def sync_run_error(self, err: Exception):
106+
async def sync_run_error(self, err: list[Exception] | Exception):
105107
"""
106108
Handle and propagate errors across all nodes.
107109
@@ -144,13 +146,11 @@ class NonDistributedEnvironment(Environment):
144146
distributed coordination. Implements the Environment interface with minimal
145147
synchronization overhead for local testing, development, and single-machine
146148
benchmarking.
147-
148-
:ivar run_err: Exception that occurred during execution, if any.
149149
"""
150150

151151
def __init__(self):
152152
"""Initialize with no stored errors."""
153-
self.run_err: Exception = None
153+
self.run_errors: list[Exception] = []
154154

155155
async def sync_run_params(
156156
self,
@@ -181,6 +181,7 @@ async def update_run_iteration(
181181
response: ResponseT | None,
182182
request: RequestT,
183183
request_info: ScheduledRequestInfo[MeasuredRequestTimingsT],
184+
state: SchedulerState,
184185
):
185186
"""
186187
No-op for single-node execution.
@@ -196,7 +197,8 @@ async def sync_run_error(self, err: Exception):
196197
197198
:param err: The exception that occurred during execution.
198199
"""
199-
self.run_err = err
200+
err = [err] if not isinstance(err, list) else err
201+
self.run_errors.extend(err)
200202

201203
async def sync_run_end(
202204
self,
@@ -214,8 +216,13 @@ async def sync_run_end(
214216
:return: Empty iterator since there are no remote nodes.
215217
:raises Exception: Any error stored during execution via sync_run_error.
216218
"""
217-
if self.run_err:
218-
raise self.run_err
219-
# Return empty async iterator for non-distributed environment
219+
if self.run_errors:
220+
if len(self.run_errors) == 1:
221+
raise self.run_errors[0]
222+
else:
223+
raise RuntimeError(
224+
f"Errors occurred during execution: {self.run_errors}"
225+
)
226+
220227
return
221-
yield
228+
yield # needed to force generator compilation

src/guidellm/scheduler/objects.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from __future__ import annotations
2323

2424
from abc import ABC, abstractmethod
25-
from collections.abc import AsyncIterator, Iterable
25+
from collections.abc import AsyncIterator
2626
from typing import (
2727
Any,
2828
Generic,
@@ -54,7 +54,10 @@
5454
RequestT = TypeVar("RequestT")
5555
MultiTurnRequestT = TypeAliasType(
5656
"MultiTurnRequestT",
57-
Iterable[Union[RequestT, tuple[RequestT, float]]],
57+
Union[
58+
list[Union[RequestT, tuple[RequestT, float]]],
59+
tuple[Union[RequestT, tuple[RequestT, float]]],
60+
],
5861
type_params=(RequestT,),
5962
)
6063
ResponseT = TypeVar("ResponseT")

src/guidellm/scheduler/scheduler.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,7 @@ async def run(
109109
"""
110110
with self.thread_lock:
111111
worker_group: (
112-
WorkerProcessGroup[
113-
BackendT, RequestT, MeasuredRequestTimingsT, ResponseT
114-
]
115-
| None
112+
WorkerProcessGroup[RequestT, MeasuredRequestTimingsT, ResponseT] | None
116113
) = None
117114

118115
# Any issues during the run will raise an error (local or remote),
@@ -131,7 +128,7 @@ async def run(
131128

132129
# Setup the worker group, sync start with the environment
133130
worker_group = WorkerProcessGroup[
134-
BackendT, RequestT, MeasuredRequestTimingsT, ResponseT
131+
RequestT, MeasuredRequestTimingsT, ResponseT
135132
](
136133
backend=backend,
137134
requests=local_requests,
@@ -154,13 +151,13 @@ async def run(
154151
)
155152
yield response, request, request_info, state
156153
except Exception as err: # noqa: BLE001
157-
env.sync_run_error(err)
154+
await env.sync_run_error(err)
158155
finally:
159156
# Ensure all worker processes are cleaned up for error or completion
160157
if worker_group is not None:
161158
err = await worker_group.shutdown()
162159
if err is not None:
163-
env.sync_run_error(err)
160+
await env.sync_run_error(err)
164161

165162
# Ensure any errors are raised and all responses
166163
# are yielded for aggregation on the primary node

src/guidellm/scheduler/worker.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import asyncio
1414
import time
15-
from collections.abc import Generator, Iterable
15+
from collections.abc import Generator
1616
from multiprocessing import Queue
1717
from multiprocessing.synchronize import Barrier as ProcessingBarrier
1818
from multiprocessing.synchronize import Event as ProcessingEvent
@@ -112,6 +112,7 @@ def __init__(
112112
]
113113
] = None
114114
self.requests_canceled: ThreadingEvent = None
115+
self.pull_requests_stopped: ThreadingEvent = None
115116
self.pull_task: asyncio.Task = None
116117
self.push_task: asyncio.Task = None
117118

@@ -243,6 +244,7 @@ async def _initialize_requests_processing(self):
243244
)
244245
self.pending_updates_queue = culsans.Queue()
245246
self.requests_canceled = ThreadingEvent()
247+
self.pull_requests_stopped = ThreadingEvent()
246248

247249
# Start background tasks for queue management
248250
self.pull_task = asyncio.create_task(
@@ -351,7 +353,7 @@ async def _process_next_request(self):
351353
request_info=request_info,
352354
)
353355

354-
if isinstance(request, Iterable) and not isinstance(request, (str, bytes)):
356+
if isinstance(request, (list, tuple)):
355357
raise NotImplementedError("Multi-turn requests are not yet supported")
356358

357359
# Calculate when to start processing request
@@ -373,9 +375,8 @@ async def _process_next_request(self):
373375
request=request,
374376
request_info=request_info,
375377
)
376-
async for resp, info in self.backend.resolve(request, request_info, None):
378+
async for resp in self.backend.resolve(request, request_info, None):
377379
response = resp
378-
request_info = info
379380

380381
# Complete
381382
request_info.scheduler_timings.resolve_end = time.time()
@@ -460,7 +461,6 @@ async def _handle_request_update(
460461

461462
async def _cancel_pending_requests(self):
462463
while True:
463-
# All requests will be on the queue by now, loop until we can't get anymore
464464
try:
465465
request, request_info = await asyncio.wait_for(
466466
self.pending_requests_queue.async_get(), timeout=self.poll_intervals
@@ -474,7 +474,9 @@ async def _cancel_pending_requests(self):
474474
request_info=request_info,
475475
)
476476
except (culsans.QueueEmpty, asyncio.TimeoutError):
477-
break
477+
if self.pull_requests_stopped.is_set():
478+
# No more requests will be put on the Queue
479+
break
478480

479481
def _pull_requests_generator(self) -> Generator:
480482
last_check = time.time()
@@ -491,14 +493,16 @@ def _pull_requests_generator(self) -> Generator:
491493
pass # No update available, continue polling
492494
except culsans.QueueShutDown:
493495
break
494-
except Exception: # noqa: BLE001
496+
except Exception: # noqa: BLE001, S110
495497
pass
496498

497499
if time.time() - last_check > self.poll_intervals:
498500
# Yield to allow cancel/error/stop checks in wrapper
499501
last_check = time.time()
500502
yield None
501503

504+
self.pull_requests_stopped.set()
505+
502506
def _push_updates_generator(self) -> Generator:
503507
last_check = time.time()
504508

@@ -514,7 +518,7 @@ def _push_updates_generator(self) -> Generator:
514518
pass # No update available, continue polling
515519
except culsans.QueueShutDown:
516520
break
517-
except Exception: # noqa: BLE001
521+
except Exception: # noqa: BLE001, S110
518522
pass
519523

520524
if time.time() - last_check > self.poll_intervals:

src/guidellm/scheduler/worker_group.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def _update_state(
367367
"completed, errored, cancelled."
368368
)
369369

370+
state.end_time = time.time() # Always update for last time update received
370371
actions = {
371372
name: const(state, info) for name, const in self.constraints.items()
372373
}
@@ -465,11 +466,6 @@ def _populate_requests_create_iterator(
465466
else self.requests
466467
)
467468

468-
if self.infinite_requests is not False and isinstance(self.requests, Iterable):
469-
# Out of requests and infinite set to True or set to default
470-
# Create new iterator out of the Iterable
471-
return iter(self.requests)
472-
473469
if self.infinite_requests is True and isinstance(self.requests, Iterator):
474470
# Out of requests and infinite set to True, but request_iter is Iterator
475471
# Cannot create new, raise RuntimeError
@@ -478,6 +474,11 @@ def _populate_requests_create_iterator(
478474
"infinite_requests is set to True"
479475
)
480476

477+
if self.infinite_requests is not False and isinstance(self.requests, Iterable):
478+
# Out of requests and infinite set to True or set to default
479+
# Create new iterator out of the Iterable
480+
return iter(self.requests)
481+
481482
# Either infinite is False for Iterable or Iterator
482483
# or infinite is None (default) for Iterator
483484
# So, return None to stop

0 commit comments

Comments
 (0)