Skip to content

Commit 3ac4df6

Browse files
committed
Implement worker support for multi-turn requests
Signed-off-by: Samuel Monson <[email protected]>
1 parent 220377e commit 3ac4df6

File tree

5 files changed

+124
-49
lines changed

5 files changed

+124
-49
lines changed

src/guidellm/request/loader.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __init__(
105105
self.preserve_iter_state = iter_type == "infinite" # ensure no caching requests
106106
self._preserved_iter = None
107107

108-
def __iter__(self) -> Iterator[list[GenerationRequest]]:
108+
def __iter__(self) -> Iterator[list[tuple[GenerationRequest, float]]]:
109109
scope_create_count = 0
110110

111111
while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None:
@@ -260,7 +260,9 @@ def _get_dataset_iter(
260260

261261
return dataset_iter
262262

263-
def _create_requests(self, item: dict[str, Any]) -> list[GenerationRequest]:
263+
def _create_requests(
264+
self, item: dict[str, Any]
265+
) -> list[tuple[GenerationRequest, float]]:
264266
prompts = list(item[self.column_mappings["prompt_column"]])
265267
prompts_tokens: list[Optional[int]] = (
266268
list(item[self.column_mappings["prompt_tokens_count_column"]])
@@ -281,15 +283,24 @@ def _create_requests(self, item: dict[str, Any]) -> list[GenerationRequest]:
281283
)
282284

283285
return [
284-
GenerationRequest(
285-
request_type=settings.preferred_route,
286-
content=prompt,
287-
stats=(
288-
{"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
289-
),
290-
constraints=(
291-
{"output_tokens": output_tokens} if output_tokens is not None else {}
286+
(
287+
GenerationRequest(
288+
request_type=settings.preferred_route,
289+
content=prompt,
290+
stats=(
291+
{"prompt_tokens": prompt_tokens}
292+
if prompt_tokens is not None
293+
else {}
294+
),
295+
constraints=(
296+
{"output_tokens": output_tokens}
297+
if output_tokens is not None
298+
else {}
299+
),
292300
),
301+
0.0, # TODO: delay
302+
)
303+
for prompt, prompt_tokens, output_tokens in zip(
304+
prompts, prompts_tokens, outputs_tokens
293305
)
294-
for prompt, prompt_tokens, output_tokens in zip(prompts, prompts_tokens, outputs_tokens)
295306
]

src/guidellm/scheduler/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,20 @@
1515
from .objects import (
1616
BackendInterface,
1717
BackendT,
18+
HistoryT,
1819
MeasuredRequestTimings,
1920
MultiTurnRequestT,
21+
MultiTurnT,
2022
RequestSchedulerTimings,
2123
RequestT,
2224
ResponseT,
25+
ScheduledRequestAugmentation,
2326
ScheduledRequestInfo,
2427
SchedulerMessagingPydanticRegistry,
2528
SchedulerState,
2629
SchedulerUpdateAction,
2730
SchedulerUpdateActionProgress,
31+
TurnT,
2832
)
2933
from .scheduler import Scheduler
3034
from .strategies import (
@@ -56,6 +60,7 @@
5660
"ConstraintInitializer",
5761
"ConstraintsInitializerFactory",
5862
"Environment",
63+
"HistoryT",
5964
"LastCompletionRequestTimings",
6065
"MaxDurationConstraint",
6166
"MaxErrorRateConstraint",
@@ -64,13 +69,15 @@
6469
"MaxNumberConstraint",
6570
"MeasuredRequestTimings",
6671
"MultiTurnRequestT",
72+
"MultiTurnT",
6773
"NoDelayRequestTimings",
6874
"NonDistributedEnvironment",
6975
"PoissonRateRequestTimings",
7076
"PydanticConstraintInitializer",
7177
"RequestSchedulerTimings",
7278
"RequestT",
7379
"ResponseT",
80+
"ScheduledRequestAugmentation",
7481
"ScheduledRequestInfo",
7582
"ScheduledRequestTimings",
7683
"Scheduler",
@@ -84,6 +91,7 @@
8491
"StrategyType",
8592
"SynchronousStrategy",
8693
"ThroughputStrategy",
94+
"TurnT",
8795
"UnserializableConstraintInitializer",
8896
"WorkerProcess",
8997
"WorkerProcessGroup",

src/guidellm/scheduler/objects.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
Literal,
2020
Protocol,
2121
TypeVar,
22-
Union,
2322
)
2423

2524
from pydantic import Field, computed_field
@@ -35,34 +34,50 @@
3534
__all__ = [
3635
"BackendInterface",
3736
"BackendT",
37+
"HistoryT",
3838
"MeasuredRequestTimings",
3939
"MultiTurnRequestT",
40+
"MultiTurnT",
4041
"RequestSchedulerTimings",
4142
"RequestT",
4243
"ResponseT",
44+
"ScheduledRequestAugmentation",
4345
"ScheduledRequestInfo",
4446
"SchedulerMessagingPydanticRegistry",
4547
"SchedulerState",
4648
"SchedulerUpdateAction",
4749
"SchedulerUpdateActionProgress",
50+
"TurnT",
4851
]
4952

5053
RequestT = TypeVar("RequestT")
5154
"""Generic request object type for scheduler processing."""
5255

56+
# TODO: Remove
57+
MultiTurnRequestT = RequestT
58+
5359
ResponseT = TypeVar("ResponseT")
5460
"""Generic response object type returned by backend processing."""
5561

56-
MultiTurnRequestT = TypeAliasType(
57-
"MultiTurnRequestT",
58-
Union[
59-
list[Union[RequestT, tuple[RequestT, float]]],
60-
tuple[Union[RequestT, tuple[RequestT, float]]],
61-
],
62+
TurnT = TypeAliasType(
63+
"TurnT",
64+
tuple[RequestT, "ScheduledRequestAugmentation", "ScheduledRequestInfo"],
65+
type_params=(RequestT,),
66+
)
67+
68+
MultiTurnT = TypeAliasType(
69+
"MultiTurnT",
70+
list[TurnT[RequestT]],
6271
type_params=(RequestT,),
6372
)
6473
"""Multi-turn request structure supporting conversation history with optional delays."""
6574

75+
HistoryT = TypeAliasType(
76+
"HistoryT",
77+
list[tuple[RequestT, ResponseT]],
78+
type_params=(RequestT, ResponseT),
79+
)
80+
6681

6782
class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]):
6883
"""
@@ -71,6 +86,21 @@ class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]):
7186
"""
7287

7388

89+
@SchedulerMessagingPydanticRegistry.register()
90+
class ScheduledRequestAugmentation(StandardBaseModel):
91+
"""
92+
Adjustments to scheduler logic for a paired request.
93+
"""
94+
95+
post_requeue_delay: float = Field(
96+
description=(
97+
"Delay in seconds to wait after a request to "
98+
"queue the next request in the conversation."
99+
),
100+
default=0.0,
101+
)
102+
103+
74104
@SchedulerMessagingPydanticRegistry.register()
75105
class RequestSchedulerTimings(StandardBaseModel):
76106
"""

src/guidellm/scheduler/worker.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,12 @@
3131

3232
from guidellm.scheduler.objects import (
3333
BackendInterface,
34+
HistoryT,
3435
MultiTurnRequestT,
36+
MultiTurnT,
3537
RequestT,
3638
ResponseT,
39+
ScheduledRequestAugmentation,
3740
ScheduledRequestInfo,
3841
SchedulerMessagingPydanticRegistry,
3942
)
@@ -118,6 +121,9 @@ def __init__(
118121
self.startup_completed = False
119122
self.backend_started = False
120123
self.messaging_started = False
124+
self.turns_queue: list[
125+
tuple[HistoryT[RequestT, ResponseT], MultiTurnT[RequestT]]
126+
] = []
121127

122128
def run(self):
123129
"""
@@ -302,16 +308,19 @@ async def _cancel_requests_loop(self):
302308
self._send_update("cancelled", None, request, request_info)
303309

304310
async def _process_next_request(self):
305-
request: RequestT | MultiTurnRequestT[RequestT] | None = None
311+
request: RequestT | None = None
306312
request_info: ScheduledRequestInfo | None = None
307313
response: ResponseT | None = None
314+
aug: ScheduledRequestAugmentation | None = None
308315

309316
try:
310317
# Pull request from the queue
311-
request, request_info = await self.messaging.get()
312-
313-
if isinstance(request, (list, tuple)):
314-
raise NotImplementedError("Multi-turn requests are not yet supported")
318+
history, conversation = (
319+
self.turns_queue.pop(0)
320+
if self.turns_queue
321+
else ([], await self.messaging.get())
322+
)
323+
request, aug, request_info = conversation.pop(0)
315324

316325
# Calculate targeted start and set pending state for request
317326
request_info.scheduler_node_id = self.messaging.worker_index
@@ -341,6 +350,12 @@ async def _process_next_request(self):
341350
request_info.scheduler_timings.resolve_end = time.time()
342351
self._send_update("completed", response, request, request_info)
343352

353+
# If multi-turn, queue up next turn(s)
354+
# TODO: Move to callback and support delay
355+
if conversation: # more turns to process
356+
history.append((request, response))
357+
self.turns_queue.append((history, conversation))
358+
344359
response = request = request_info = None
345360
except asyncio.CancelledError:
346361
# Handle cancellation

src/guidellm/scheduler/worker_group.py

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
from guidellm.scheduler.objects import (
2727
BackendInterface,
2828
MultiTurnRequestT,
29+
MultiTurnT,
2930
RequestT,
3031
ResponseT,
32+
ScheduledRequestAugmentation,
3133
ScheduledRequestInfo,
3234
SchedulerMessagingPydanticRegistry,
3335
SchedulerState,
@@ -471,9 +473,9 @@ def __init__(
471473

472474
def requests_generator(
473475
self,
474-
requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
475-
cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
476-
) -> Generator[tuple[RequestT | MultiTurnRequestT[RequestT],], None, None]:
476+
requests: Iterable[Iterable[tuple[RequestT, float]]] | None,
477+
cycle_requests: Iterable[Iterable[tuple[RequestT, float]]] | None,
478+
) -> Generator[MultiTurnT[RequestT], None, None]:
477479
"""
478480
Generate request-info pairs for worker processing with constraint evaluation.
479481
@@ -494,31 +496,40 @@ def _iter():
494496
while True:
495497
yield from cycle_requests
496498

497-
count = 0
498-
request_info: ScheduledRequestInfo = None
499+
count: int = 0
500+
stop_queueing: bool = False
501+
502+
def _turn_iter(requests_chain: Iterable[tuple[RequestT, float]]):
503+
nonlocal count, stop_queueing
504+
for request, delay in requests_chain:
505+
count += 1
506+
507+
if hasattr(request, "request_id"):
508+
request_id = request.request_id
509+
elif hasattr(request, "id"):
510+
request_id = request.id
511+
else:
512+
request_id = str(uuid.uuid4())
513+
request_augmentation = ScheduledRequestAugmentation(
514+
post_requeue_delay=delay
515+
)
516+
request_info: ScheduledRequestInfo = ScheduledRequestInfo(
517+
request_id=request_id,
518+
status="queued",
519+
scheduler_process_id=0,
520+
scheduler_start_time=self.start_time,
521+
)
522+
state_update = self._locked_update(request_info)
523+
yield (request, request_augmentation, request_info)
524+
525+
if state_update.stop_queueing:
526+
stop_queueing = True
527+
return
528+
499529
for request_chain in _iter():
500-
if isinstance(request_chain, (list, tuple)):
501-
request = request_chain[0]
502-
else:
503-
request = request_chain
504-
count += 1
505-
506-
if hasattr(request, "request_id"):
507-
request_id = request.request_id
508-
elif hasattr(request, "id"):
509-
request_id = request.id
510-
else:
511-
request_id = str(uuid.uuid4())
512-
request_info: ScheduledRequestInfo = ScheduledRequestInfo(
513-
request_id=request_id,
514-
status="queued",
515-
scheduler_process_id=0,
516-
scheduler_start_time=self.start_time,
517-
)
518-
state_update = self._locked_update(request_info)
519-
yield (request, request_info)
530+
yield list(_turn_iter(request_chain))
520531

521-
if state_update.stop_queueing:
532+
if stop_queueing:
522533
self.stop_send_requests_event.set()
523534
return
524535

0 commit comments

Comments
 (0)