Commit 9d68818

Implement initial multiturn support

1 parent c91b3ef commit 9d68818

2 files changed: +42 -9 lines changed

src/guidellm/request/loader.py

Lines changed: 8 additions & 1 deletion

```diff
@@ -107,13 +107,20 @@ def __init__(
         self._preserved_iter = None
 
     def __iter__(self) -> Iterator[GenerativeRequestSession]:
+        turns = 1
+
+        data_iter = self._create_requests()
+        while requests := [i for i, _ in zip(data_iter, range(turns))]:
+            yield GenerativeRequestSession(requests)
+
+    def _create_requests(self) -> Iterator[GenerationRequest]:
         scope_create_count = 0
 
         while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None:
             scope_create_count += 1
 
             for item in dataset_iter:
-                yield GenerativeRequestSession(self._create_request(item))
+                yield self._create_request(item)
 
         self._preserved_iter = None
```
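A note on the new `__iter__` (not part of the commit itself): the `while requests := [...]` line is a chunking idiom. `zip` against `range(turns)` pulls at most `turns` items per pass from the shared request iterator, and the walrus `while` exits once the iterator is exhausted and the list comes back empty. A minimal standalone sketch of the same pattern, with hypothetical names:

```python
from collections.abc import Iterator


def chunked(items: Iterator[int], size: int) -> Iterator[list[int]]:
    # zip stops at the shorter iterable, so each pass consumes at most
    # `size` items from `items`; the loop terminates when the exhausted
    # iterator produces an empty (falsy) chunk.
    while chunk := [item for item, _ in zip(items, range(size))]:
        yield chunk


# Chunks of 3 from a stream of 7 items.
print(list(chunked(iter(range(7)), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
```

This only terminates because `items` is a single iterator whose position advances across passes; zipping against a fresh sequence each time would restart from the front and loop forever. With `turns` hardcoded to 1, every session still wraps exactly one request, so the loader's observable behavior is unchanged for now.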
src/guidellm/request/session.py

Lines changed: 34 additions & 8 deletions

```diff
@@ -1,3 +1,4 @@
+import itertools
 from abc import ABC, abstractmethod
 from typing import Generic, TypeVar
 
@@ -34,22 +35,47 @@ def complete(self) -> bool: ...
 
 
 class GenerativeRequestSession(RequestSession[GenerationRequest, ResponseSummary]):
-    def __init__(self, request: GenerationRequest) -> None:
-        self.request = request
-        self._complete = False
+    def __init__(self, prompts: list[GenerationRequest]) -> None:
+        if not prompts:
+            raise ValueError("Prompts cannot be empty")
+
+        self.prompts = prompts
+        self.responses: list[str] = []
 
     def __len__(self) -> int:
-        return 1
+        return len(self.prompts)
 
     def get_next_request(self) -> GenerationRequest:
-        return self.request
+        completed_responses = len(self.responses)
+        base_request = self.prompts[completed_responses].model_copy(deep=True)
+        base_request.content = "".join(
+            itertools.chain.from_iterable(
+                zip((x.content for x in self.prompts), self.responses + [""])
+            )
+        )
+        base_request.stats["prompt_tokens"] = sum(
+            x.stats["prompt_tokens"] for x in self.prompts[: completed_responses + 1]
+        )
+        base_request.constraints["output_tokens"] = sum(
+            x.constraints["output_tokens"]
+            for x in self.prompts[: completed_responses + 1]
+        )
+
+        return base_request
 
     def get_next_delay(self) -> float:
         return 0.0
 
-    def push_response(self, response: ResponseSummary) -> None:  # noqa: ARG002
-        self._complete = True
+    def push_response(self, response: ResponseSummary) -> None:
+        if len(self.responses) < len(self.prompts):
+            if response.response_output_tokens is not None:
+                self.prompts[len(self.responses)].constraints["output_tokens"] = (
+                    response.response_output_tokens
+                )
+            self.responses.append(response.value)
+        else:
+            raise ValueError("Response list full")
 
     @property
     def complete(self) -> bool:
-        return self._complete
+        return len(self.responses) >= len(self.prompts)
```

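To make the `get_next_request` assembly concrete, here is a small sketch of the interleaving step alone, using plain strings as hypothetical stand-ins for the `GenerationRequest.content` values and the accumulated `ResponseSummary.value` strings:

```python
import itertools

# Stand-in data: three configured turns, one response received so far.
prompt_contents = ["What is 2+2?", " Times 3?", " Minus 1?"]
responses = ["4"]

# Same expression as in get_next_request: prompts and prior responses
# alternate, and the trailing "" pairs with the prompt being asked now.
# zip truncates to len(responses) + 1 pairs, so turns that have not
# started yet (" Minus 1?") stay out of the assembled context.
content = "".join(
    itertools.chain.from_iterable(zip(prompt_contents, responses + [""]))
)
print(content)  # What is 2+2?4 Times 3?
```

The token bookkeeping follows the same "history plus current turn" slicing: `stats["prompt_tokens"]` and `constraints["output_tokens"]` are summed over `prompts[: completed_responses + 1]`, and `push_response` overwrites the per-turn `output_tokens` with the observed `response_output_tokens`, so later sums reflect what the server actually generated.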