
Commit ff305a3

Fix most type errors in scheduler package

Signed-off-by: Jared O'Connell <[email protected]>

1 parent dd219f1 commit ff305a3

File tree

6 files changed: +80 -53 lines changed

src/guidellm/scheduler/constraints.py
Lines changed: 1 addition & 3 deletions

@@ -1005,9 +1005,7 @@ def info(self) -> dict[str, Any]:
         return self.model_dump()
 
     def __call__(
-        self,
-        state: SchedulerState,
-        request_info: RequestInfo,  # noqa: ARG002
+        self, state: SchedulerState, request: RequestInfo
     ) -> SchedulerUpdateAction:
         create_exceeded = state.created_requests >= self.num_requests
         processed_exceeded = state.processed_requests >= self.num_requests
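
The reshaped `__call__` leaves the constraint's decision logic unchanged. A minimal runnable sketch of that logic, with `SchedulerState` reduced to a dataclass and a plain dict standing in for `SchedulerUpdateAction` (both stand-ins, not the package's real types):

    from dataclasses import dataclass

    @dataclass
    class SchedulerState:
        created_requests: int = 0
        processed_requests: int = 0

    @dataclass
    class MaxRequestsConstraint:
        num_requests: int

        def __call__(self, state: SchedulerState, request: object) -> dict[str, bool]:
            # Same comparisons as the diff above; the real method returns a
            # SchedulerUpdateAction rather than a dict.
            create_exceeded = state.created_requests >= self.num_requests
            processed_exceeded = state.processed_requests >= self.num_requests
            return {"stop_queueing": create_exceeded, "stop_processing": processed_exceeded}

    constraint = MaxRequestsConstraint(num_requests=10)
    print(constraint(SchedulerState(created_requests=10, processed_requests=7), request=None))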

src/guidellm/scheduler/environments.py
Lines changed: 2 additions & 2 deletions

@@ -84,7 +84,7 @@ async def sync_run_start(self) -> float:
     async def update_run_iteration(
         self,
         response: ResponseT | None,
-        request: RequestT,
+        request: RequestT | MultiTurnRequestT[RequestT],
         request_info: RequestInfo,
         state: SchedulerState,
     ):
@@ -201,7 +201,7 @@ async def sync_run_start(self) -> float:
     async def update_run_iteration(
         self,
         response: ResponseT | None,
-        request: RequestT,
+        request: RequestT | MultiTurnRequestT[RequestT],
         request_info: RequestInfo,
         state: SchedulerState,
     ):
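
Both hunks widen `request` to also accept multi-turn requests. A runnable sketch of how such a union is narrowed at runtime, assuming (as worker.py below does) that a multi-turn request arrives as a list or tuple of turns:

    def describe(request: object) -> str:
        # Python >= 3.10: isinstance accepts the union syntax used in worker.py
        if isinstance(request, list | tuple):
            return f"multi-turn request with {len(request)} turns"
        return "single request"

    print(describe("prompt"))              # single request
    print(describe(["turn 1", "turn 2"]))  # multi-turn request with 2 turns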

src/guidellm/scheduler/scheduler.py
Lines changed: 1 addition & 1 deletion

@@ -69,7 +69,7 @@ async def run(
     ) -> AsyncIterator[
         tuple[
             ResponseT | None,
-            RequestT,
+            RequestT | MultiTurnRequestT[RequestT],
             RequestInfo,
             SchedulerState,
         ]
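
For callers, only the second element of the yielded tuple changes. A toy consumer of the same 4-tuple shape (stand-in types; `run` here is a hypothetical async generator, not the scheduler's actual signature):

    import asyncio
    from collections.abc import AsyncIterator

    async def run() -> AsyncIterator[tuple[str | None, str, dict, dict]]:
        # Mirrors the (response, request, request_info, state) shape from the diff
        yield None, "req-1", {"status": "pending"}, {"created_requests": 1}
        yield "ok", "req-1", {"status": "completed"}, {"processed_requests": 1}

    async def main() -> None:
        async for response, request, info, state in run():
            print(response, request, info["status"], state)

    asyncio.run(main())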

src/guidellm/scheduler/strategies.py
Lines changed: 25 additions & 4 deletions

@@ -70,8 +70,8 @@ def __pydantic_schema_base_type__(cls) -> type[SchedulingStrategy]:
         description="Number of worker processes to use for this strategy",
         ge=0,
     )
-    max_concurrency: int = Field(
-        default=0,
+    max_concurrency: int | None = Field(
+        default=None,
         description="Maximum number of concurrent requests to allow",
         ge=0,
     )
@@ -122,8 +122,8 @@ def init_processes_timings(
         self.startup_duration = startup_duration
 
         self._processes_request_index = Value("i", 0)
-        self._processes_lock = Lock()
         self._processes_start_time = Value("d", -1.0)
+        self._processes_lock = Lock()
 
     def init_processes_start(self, start_time: float):
         """
@@ -137,6 +137,8 @@ def init_processes_start(self, start_time: float):
             "SchedulingStrategy init_processes_start called before "
             "init_processes_timings"
         )
+        if self._processes_start_time is None:
+            raise RuntimeError("_processes_lock is not None but _processes_start_time is None")
 
         with self._processes_lock:
             self._processes_start_time.value = start_time
@@ -153,6 +155,8 @@ async def get_processes_start_time(self) -> float:
             "SchedulingStrategy get_processes_start_time called before "
             "init_processes_timings"
         )
+        if self._processes_start_time is None:
+            raise RuntimeError("_processes_lock is not None but _processes_start_time is None")
 
         while self._cached_processes_start_time is None:
             with self._processes_lock:
@@ -175,6 +179,8 @@ def next_request_index(self) -> int:
             "SchedulingStrategy next_request_index called before "
             "init_processes_timings"
         )
+        if self._processes_request_index is None:
+            raise RuntimeError("_processes_lock is not None but _processes_request_index is None")
 
         with self._processes_lock:
             self._processes_request_index.value += 1
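
These guards exist because the multiprocessing primitives are created lazily in `init_processes_timings`, so the attributes are Optional as far as the type checker is concerned. A self-contained sketch of the pattern (illustrative class, not the module's API):

    from multiprocessing import Lock, Value

    class Timings:
        def __init__(self) -> None:
            # Created lazily, so both attributes start as None
            self._lock = None
            self._start_time = None

        def init(self) -> None:
            self._start_time = Value("d", -1.0)
            self._lock = Lock()

        def set_start(self, start: float) -> None:
            # Explicit None checks narrow the Optional types for mypy
            # and fail loudly if init() was skipped
            if self._lock is None or self._start_time is None:
                raise RuntimeError("init() must be called before set_start()")
            with self._lock:
                self._start_time.value = start

    t = Timings()
    t.init()
    t.set_start(123.0)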
@@ -369,7 +375,8 @@ async def next_request_time(self, offset: int) -> float:
         start_time = await self.get_processes_start_time()
 
         if (
-            self.startup_duration > 0
+            self.max_concurrency is not None
+            and self.startup_duration > 0
             and (time.time() - start_time) < self.startup_duration
             and (current_index := self.next_request_index()) <= self.max_concurrency
         ):
@@ -477,6 +484,8 @@ def init_processes_timings(
         :param startup_duration: Duration in seconds for request startup ramping
         """
         super().init_processes_timings(worker_count, max_concurrency, startup_duration)
+        if self._processes_lock is None:
+            raise RuntimeError("_processes_lock is None in init_processes_timings")
         with self._processes_lock:
             self._offset = Value("d", -1.0)
 
@@ -487,6 +496,12 @@ def init_processes_start(self, start_time: float):
         :param start_time: Unix timestamp when request processing should begin
         """
         ThroughputStrategy.init_processes_start(self, start_time)
+
+        if self._processes_lock is None:
+            raise RuntimeError("_processes_lock is None in init_processes_start")
+        if self._offset is None:
+            raise RuntimeError("_offset is None in init_processes_start; was "
+                "init_processes_timings not called?")
         with self._processes_lock:
             self._offset.value = start_time
 
@@ -505,6 +520,12 @@ async def next_request_time(self, offset: int) -> float:
 
         next_delay = self._random.expovariate(self.rate)
 
+        if self._processes_lock is None:
+            raise RuntimeError("_processes_lock is None in next_request_time; was "
+                "init_processes_timings not called?")
+        if self._offset is None:
+            raise RuntimeError("_offset is None in next_request_time; was "
+                "init_processes_timings not called?")
         with self._processes_lock:
             self._offset.value += next_delay
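
The last hunk guards the Poisson-style schedule in which each arrival advances a shared offset by an exponentially distributed delay. The rule itself, sketched with a local float instead of the shared `Value` (illustrative, not the strategy's API):

    import random

    rate = 5.0  # requests per second
    rng = random.Random(42)

    offset = 0.0
    arrivals = []
    for _ in range(5):
        # Exponential inter-arrival delays produce a Poisson arrival process
        offset += rng.expovariate(rate)
        arrivals.append(round(offset, 3))

    print(arrivals)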

src/guidellm/scheduler/worker.py
Lines changed: 20 additions & 14 deletions

@@ -23,11 +23,9 @@
         bool, "Flag indicating uvloop availability for event loop optimization"
     ] = True
 except ImportError:
-    uvloop = None
+    uvloop = None  # type: ignore[assignment]  # Optional dependency
 
-    HAS_UVLOOP: Annotated[
-        bool, "Flag indicating uvloop availability for event loop optimization"
-    ] = False
+    HAS_UVLOOP = False
 
 
 from guidellm.scheduler.schemas import (
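
For context, a typical consumer of this optional-import pattern: the flag selects uvloop's event loop policy when the package is installed and falls back to stdlib asyncio otherwise (a generic sketch, not code from this commit):

    import asyncio

    try:
        import uvloop
        HAS_UVLOOP = True
    except ImportError:
        uvloop = None  # type: ignore[assignment]  # Optional dependency
        HAS_UVLOOP = False

    async def main() -> None:
        print("uvloop enabled:", HAS_UVLOOP)

    if HAS_UVLOOP:
        # Swap in uvloop's faster libuv-based event loop
        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
    asyncio.run(main())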
@@ -84,6 +82,10 @@ def __init__(
                 RequestT | MultiTurnRequestT[RequestT],
                 RequestInfo,
             ],
+            tuple[
+                RequestT | MultiTurnRequestT[RequestT],
+                RequestInfo,
+            ],
         ],
         backend: BackendInterface[RequestT, ResponseT],
         strategy: SchedulingStrategy,
@@ -201,8 +203,11 @@ async def run_async(self):
 
     async def _stop_monitor(
         self,
-    ) -> Literal["error_event", "shutdown_event"]:
-        """Monitor shutdown and error events for worker termination."""
+    ) -> None:
+        """
+        Monitor shutdown and error events for worker termination.
+        :raises RuntimeError if the work process received an error signal.
+        """
         exit_key = await wait_for_sync_objects(
             {
                 "error_event": self.error_event,
@@ -322,7 +327,7 @@ async def _cancel_requests_loop(self):
         """Cancel all remaining queued requests until worker process terminates."""
         while True:
             try:
-                request: RequestT
+                request: RequestT | MultiTurnRequestT[RequestT]
                 request_info: RequestInfo
                 request, request_info = await self.messaging.get(
                     timeout=self.messaging.poll_interval
@@ -345,22 +350,23 @@ async def _process_next_request(self, target_start: float):
         :param target_start: Unix timestamp when request should begin processing
         """
         request: RequestT | MultiTurnRequestT[RequestT] | None = None
-        request_info: RequestInfo | None = None
+        request_info: RequestInfo | None
         response: ResponseT | None = None
 
         try:
             # Pull request from the queue, update state, and send "pending" update
             request, request_info = await self.messaging.get()
-            request_info.timings.dequeued = time.time()
-            request_info.scheduler_node_id = self.messaging.worker_index or -1
-            request_info.timings.targeted_start = target_start
-            self._send_update("pending", response, request, request_info)
-
+            dequeued_time = time.time()  # Ensure accurate dequeue timing
             if request is None or request_info is None:
                 raise RuntimeError("Received invalid request or request info")
             if isinstance(request, list | tuple):
                 raise NotImplementedError("Multi-turn requests are not yet supported")
 
+            request_info.timings.dequeued = dequeued_time
+            request_info.scheduler_node_id = self.messaging.worker_index or -1
+            request_info.timings.targeted_start = target_start
+            self._send_update("pending", response, request, request_info)
+
             # Schedule the request
             current_time = time.time()
             request_info.timings.scheduled_at = current_time
@@ -372,7 +378,7 @@ async def _process_next_request(self, target_start: float):
             # Process the request with the backend
             request_info.timings.resolve_start = time.time()
             self._send_update("in_progress", response, request, request_info)
-            async for resp, info in self.backend.resolve(request, request_info, None):
+            async for resp, info in await self.backend.resolve(request, request_info, None):
                 response = resp
                 request_info = info
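
The added `await` in the last hunk reflects an asyncio subtlety: an `async def` containing `yield` is an async generator and can be iterated directly, while an `async def` that returns an async iterator must be awaited first. A runnable sketch (stand-in names, not the backend's API):

    import asyncio
    from collections.abc import AsyncIterator

    async def stream() -> AsyncIterator[int]:
        # Async generator: `async for x in stream()` works without await
        yield 1
        yield 2

    async def resolve() -> AsyncIterator[int]:
        # Coroutine that returns an async iterator: must be awaited before iterating
        return stream()

    async def main() -> None:
        async for value in await resolve():  # same shape as the fixed call above
            print(value)

    asyncio.run(main())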

src/guidellm/scheduler/worker_group.py
Lines changed: 31 additions & 29 deletions

@@ -84,7 +84,7 @@ def __init__(
         backend: BackendInterface[RequestT, ResponseT],
         strategy: SchedulingStrategy,
         startup_duration: float,
-        **constraints: dict[str, Constraint],
+        **constraints: Constraint,
     ):
         """
         Initialize a worker process group for distributed request processing.
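
The fix above hinges on how `**kwargs` annotations work: the annotation describes each keyword value, not the collected dict, so the old form declared every constraint to be a `dict[str, Constraint]`. A two-line demonstration of the rule:

    def configure(**constraints: int) -> dict[str, int]:
        # `constraints` is a dict[str, int]; the annotation applies to each value
        return constraints

    print(configure(max_requests=100, max_seconds=60))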
@@ -478,9 +478,9 @@ def __init__(
             num_processes=len(processes),
             start_time=start_time,
         )
-        self._queued_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
-        self._pending_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
-        self._processing_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
+        self._queued_request_ids: set[str] = set()
+        self._pending_request_ids: set[str] = set()
+        self._processing_request_ids: set[str] = set()
 
     def requests_generator(
         self, requests: Iterable[RequestT | MultiTurnRequestT[RequestT]]
@@ -517,11 +517,13 @@ def requests_generator(
             )
             state_update = self._locked_update(request_info)
             request_info.timings.queued = time.time()
+            if self.messaging.buffer_receive_queue is None:
+                raise RuntimeError("buffer receive queue is None")
             self.messaging.buffer_receive_queue.sync_put(
                 (None, request, request_info, state_update.state)
             )
 
-            yield (request, request_info)
+            yield request, request_info
 
             if state_update.stop_queueing:
                 self.stop_send_requests_event.set()
@@ -530,7 +532,7 @@
         # Reached the end, inject a RequestsExhaustedConstraint to record
         self._locked_update(
             info=None,
-            requests_exhausted={
+            add_constraints={
                 "requests_exhausted": RequestsExhaustedConstraint(
                     num_requests=count
                 )
@@ -610,10 +612,10 @@ def received_callback(
     def _locked_update(
         self,
         info: RequestInfo | None = None,
-        **add_constraints: dict[str, Constraint],
+        add_constraints: dict[str, Constraint] | None = None,
    ) -> _StateUpdate:
         with self._update_lock:
-            if add_constraints:
+            if add_constraints is not None:
                 self.constraints.update(add_constraints)
 
             if info is not None:
@@ -631,34 +633,34 @@
 
     def _update_state_request_counts(self, info: RequestInfo):
         if info.status == "queued":
-            self._queued_requests.add(info.request_id)
-            self._state.queued_requests = len(self._queued_requests)
+            self._queued_request_ids.add(info.request_id)
+            self._state.queued_requests = len(self._queued_request_ids)
             self._state.created_requests += 1
         elif info.status == "pending":
-            self._queued_requests.remove(info.request_id)
-            self._state.queued_requests = len(self._queued_requests)
-            self._pending_requests.add(info.request_id)
-            self._state.pending_requests = len(self._pending_requests)
+            self._queued_request_ids.remove(info.request_id)
+            self._state.queued_requests = len(self._queued_request_ids)
+            self._pending_request_ids.add(info.request_id)
+            self._state.pending_requests = len(self._pending_request_ids)
         elif info.status == "in_progress":
-            self._pending_requests.remove(info.request_id)
-            self._state.pending_requests = len(self._pending_requests)
-            self._processing_requests.add(info.request_id)
-            self._state.processing_requests = len(self._processing_requests)
+            self._pending_request_ids.remove(info.request_id)
+            self._state.pending_requests = len(self._pending_request_ids)
+            self._processing_request_ids.add(info.request_id)
+            self._state.processing_requests = len(self._processing_request_ids)
         elif info.status == "completed":
-            self._processing_requests.remove(info.request_id)
-            self._state.processing_requests = len(self._processing_requests)
+            self._processing_request_ids.remove(info.request_id)
+            self._state.processing_requests = len(self._processing_request_ids)
             self._state.processed_requests += 1
             self._state.successful_requests += 1
         elif info.status in ("errored", "cancelled"):
-            if info.request_id in self._queued_requests:
-                self._queued_requests.remove(info.request_id)
-                self._state.queued_requests = len(self._queued_requests)
-            elif info.request_id in self._pending_requests:
-                self._pending_requests.remove(info.request_id)
-                self._state.pending_requests = len(self._pending_requests)
-            elif info.request_id in self._processing_requests:
-                self._processing_requests.remove(info.request_id)
-                self._state.processing_requests = len(self._processing_requests)
+            if info.request_id in self._queued_request_ids:
+                self._queued_request_ids.remove(info.request_id)
+                self._state.queued_requests = len(self._queued_request_ids)
+            elif info.request_id in self._pending_request_ids:
+                self._pending_request_ids.remove(info.request_id)
+                self._state.pending_requests = len(self._pending_request_ids)
+            elif info.request_id in self._processing_request_ids:
+                self._processing_request_ids.remove(info.request_id)
+                self._state.processing_requests = len(self._processing_request_ids)
 
             self._state.processed_requests += 1
             self._state.errored_requests += 1 if info.status == "errored" else 0
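
The final hunk tracks plain request-ID strings instead of request objects, which keeps the bookkeeping sets hashable and cheap regardless of the request type. A condensed sketch of the same transition logic (illustrative, collapsing the pending stage):

    queued: set[str] = set()
    processing: set[str] = set()

    def on_status(request_id: str, status: str) -> None:
        # Move the ID between state sets; gauge counts are just len() of each set
        if status == "queued":
            queued.add(request_id)
        elif status == "in_progress":
            queued.discard(request_id)
            processing.add(request_id)
        elif status in ("completed", "errored", "cancelled"):
            processing.discard(request_id)

    on_status("req-1", "queued")
    on_status("req-1", "in_progress")
    on_status("req-1", "completed")
    print(len(queued), len(processing))  # 0 0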
