Commit a4761b7

Implement item type

1 parent 9d68818 commit a4761b7

File tree

3 files changed: +111 -50 lines

src/guidellm/preprocess/item.py
src/guidellm/request/loader.py
src/guidellm/request/session.py

src/guidellm/preprocess/item.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+from collections.abc import Sequence
+from typing import Generic, Optional, TypeVar, Union
+
+from pydantic import Field
+
+from guidellm.objects.pydantic import StandardBaseModel
+
+PromptT = TypeVar("PromptT")
+
+
+class Item(StandardBaseModel, Generic[PromptT]):
+    """
+    Represents a single item in a dataset, containing a prompt and its associated metadata.
+    """
+
+    value: PromptT = Field(
+        description="The prompt text or data for the item.",
+        examples=[
+            "What is the capital of France?",
+            "Explain quantum computing in simple terms.",
+        ],
+    )
+    prompt_tokens: Optional[int] = Field(
+        default=None, gt=0, description="Number of tokens in the prompt"
+    )
+    output_tokens: Optional[int] = Field(
+        default=None, gt=0, description="Number of tokens in the output"
+    )
+
+
+class ItemList(Sequence[Item[PromptT]]):
+    """
+    Represents a list of items, each containing a prompt and its metadata.
+    """
+
+    def __init__(self, *items: Item[PromptT], shared_prefix: Optional[PromptT] = None):
+        self.shared_prefix: Optional[PromptT] = shared_prefix
+        self._items: list[Item[PromptT]] = list(items)
+
+    def __getitem__(self, key) -> Union[Item[PromptT], Sequence[Item[PromptT]]]:
+        return self._items[key]
+
+    def __len__(self) -> int:
+        return len(self._items)
+
+    @classmethod
+    def from_lists(
+        cls,
+        prompts: list[PromptT],
+        prompts_tokens: list[Optional[int]],
+        outputs_tokens: list[Optional[int]],
+    ) -> "ItemList":
+        return cls(
+            *[
+                Item(value=prompt, prompt_tokens=in_t, output_tokens=out_t)
+                for prompt, in_t, out_t in zip(
+                    prompts, prompts_tokens, outputs_tokens, strict=True
+                )
+            ]
+        )

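The new module gives the preprocess layer a typed unit of work: Item is a pydantic model holding one prompt plus optional token counts, and ItemList is a Sequence wrapper with an optional shared prefix. A minimal usage sketch (values are hypothetical, not from the commit):

    from guidellm.preprocess.item import Item, ItemList

    # Construct items directly...
    items = ItemList(
        Item(value="What is the capital of France?", prompt_tokens=8, output_tokens=16),
        Item(value="Explain quantum computing in simple terms.", prompt_tokens=9),
    )

    # ...or from parallel columns, as the request loader now does.
    same = ItemList.from_lists(
        ["What is the capital of France?", "Explain quantum computing in simple terms."],
        [8, 9],      # prompt token counts; None entries are allowed
        [16, None],  # output token counts; None entries are allowed
    )

    assert len(items) == 2 and same[0].prompt_tokens == 8
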
src/guidellm/request/loader.py

Lines changed: 11 additions & 27 deletions
@@ -11,10 +11,9 @@
 from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
 from transformers import PreTrainedTokenizerBase  # type: ignore[import]

-from guidellm.config import settings
 from guidellm.dataset import ColumnInputTypes, load_dataset
 from guidellm.objects import StandardBaseModel
-from guidellm.request.request import GenerationRequest
+from guidellm.preprocess.item import ItemList
 from guidellm.request.session import GenerativeRequestSession

 __all__ = [
@@ -107,20 +106,13 @@ def __init__(
         self._preserved_iter = None

     def __iter__(self) -> Iterator[GenerativeRequestSession]:
-        turns = 1
-
-        data_iter = self._create_requests()
-        while requests := [i for i, _ in zip(data_iter, range(turns))]:
-            yield GenerativeRequestSession(requests)
-
-    def _create_requests(self) -> Iterator[GenerationRequest]:
         scope_create_count = 0

         while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None:
             scope_create_count += 1

             for item in dataset_iter:
-                yield self._create_request(item)
+                yield GenerativeRequestSession(self._create_items(item))

         self._preserved_iter = None

@@ -268,25 +260,17 @@ def _get_dataset_iter(

         return dataset_iter

-    def _create_request(self, item: dict[str, Any]) -> GenerationRequest:
-        prompt_tokens = (
-            item[self.column_mappings["prompt_tokens_count_column"]]
+    def _create_items(self, item: dict[str, Any]) -> ItemList:
+        prompts = list(item[self.column_mappings["prompt_column"]])
+        prompt_tokens: list[Optional[int]] = (
+            list(item[self.column_mappings["prompt_tokens_count_column"]])
             if "prompt_tokens_count_column" in self.column_mappings
-            else None
+            else [None]
         )
-        output_tokens = (
-            item[self.column_mappings["output_tokens_count_column"]]
+        output_tokens: list[Optional[int]] = (
+            list(item[self.column_mappings["output_tokens_count_column"]])
             if "output_tokens_count_column" in self.column_mappings
-            else None
+            else [None]
         )

-        return GenerationRequest(
-            request_type=settings.preferred_route,
-            content=item[self.column_mappings["prompt_column"]],
-            stats=(
-                {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
-            ),
-            constraints=(
-                {"output_tokens": output_tokens} if output_tokens is not None else {}
-            ),
-        )
+        return ItemList.from_lists(prompts, prompt_tokens, output_tokens)

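With this change a dataset row can describe a whole conversation: _create_items normalizes the mapped columns to lists and defers to ItemList.from_lists, and __iter__ wraps each result in a GenerativeRequestSession. A rough illustration with a hypothetical row (the resolved column names are assumptions, not from the commit):

    # Suppose column_mappings resolves "prompt_column" to "prompt", etc.
    row = {
        "prompt": ["Hello!", "Tell me more."],
        "prompt_tokens_count": [3, 5],
        "output_tokens_count": [32, 64],
    }

    # _create_items then effectively reduces to:
    items = ItemList.from_lists(
        list(row["prompt"]),
        list(row["prompt_tokens_count"]),
        list(row["output_tokens_count"]),
    )
    # __iter__ yields GenerativeRequestSession(items)

Note that from_lists zips its arguments strictly, so the [None] fallback for a missing count column only lines up with single-prompt rows; multi-prompt rows need explicit per-turn counts.
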
src/guidellm/request/session.py

Lines changed: 40 additions & 23 deletions
@@ -1,15 +1,16 @@
 import itertools
 from abc import ABC, abstractmethod
-from typing import Generic, TypeVar
+from collections.abc import Sequence
+from typing import Generic

 from guidellm.backend.response import ResponseSummary
+from guidellm.config import settings
+from guidellm.preprocess.item import Item, ItemList
 from guidellm.request.request import GenerationRequest
+from guidellm.request.types import RequestT, ResponseT

 __all__ = ["GenerativeRequestSession", "RequestSession"]

-RequestT = TypeVar("RequestT")
-ResponseT = TypeVar("ResponseT")
-

 class RequestSession(ABC, Generic[RequestT, ResponseT]):
     """
@@ -35,44 +36,60 @@ def complete(self) -> bool: ...


 class GenerativeRequestSession(RequestSession[GenerationRequest, ResponseSummary]):
-    def __init__(self, prompts: list[GenerationRequest]) -> None:
-        if not prompts:
+    def __init__(self, items: ItemList) -> None:
+        if len(items) < 1:
             raise ValueError("Prompts cannot be empty")

-        self.prompts = prompts
-        self.responses: list[str] = []
+        self.prompts: Sequence[Item] = items
+        self.responses: list[Item] = []

     def __len__(self) -> int:
         return len(self.prompts)

     def get_next_request(self) -> GenerationRequest:
         completed_responses = len(self.responses)
-        base_request = self.prompts[completed_responses].model_copy(deep=True)
-        base_request.content = "".join(
+
+        # FIXME: Can only handle string requests
+        content = "".join(
             itertools.chain.from_iterable(
-                zip((x.content for x in self.prompts), self.responses + [""])
+                (x.value, y.value)
+                for x, y in zip(self.prompts, self.responses + [Item(value="")])
             )
         )
-        base_request.stats["prompt_tokens"] = sum(
-            x.stats["prompt_tokens"] for x in self.prompts[: completed_responses + 1]
+
+        prev_prompt_tokens = sum(
+            (x.prompt_tokens or 0) + (x.output_tokens or 0) for x in self.responses
         )
-        base_request.constraints["output_tokens"] = sum(
-            x.constraints["output_tokens"]
-            for x in self.prompts[: completed_responses + 1]
+        prompt_tokens = (
+            self.prompts[completed_responses].prompt_tokens or 0
+        ) + prev_prompt_tokens
+
+        output_tokens = self.prompts[completed_responses].output_tokens
+
+        return GenerationRequest(
+            request_type=settings.preferred_route,
+            content=content,
+            stats=(
+                {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
+            ),
+            constraints=(
+                {"output_tokens": output_tokens} if output_tokens is not None else {}
+            ),
         )

-        return base_request
-
     def get_next_delay(self) -> float:
         return 0.0

     def push_response(self, response: ResponseSummary) -> None:
         if len(self.responses) < len(self.prompts):
-            if response.response_output_tokens is not None:
-                self.prompts[len(self.responses)].constraints["output_tokens"] = (
-                    response.response_output_tokens
-                )
-            self.responses.append(response.value)
+            resp = Item(
+                value=response.value,
+                prompt_tokens=response.response_prompt_tokens
+                or response.request_prompt_tokens,
+                output_tokens=response.response_output_tokens
+                or response.request_output_tokens,
+            )
+            self.responses.append(resp)
         else:
             raise ValueError("Response list full")

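Taken together, the session now owns request construction: each call to get_next_request concatenates all prior prompts and responses into the next request's content and accumulates prompt-token counts, while each response is recorded back as an Item. A sketch of the intended multi-turn flow (backend.resolve is a hypothetical stand-in for whatever resolves a GenerationRequest into a ResponseSummary):

    session = GenerativeRequestSession(items)  # items: ItemList of conversation turns
    for _ in range(len(session)):
        request = session.get_next_request()   # prior turns + responses, concatenated
        summary = backend.resolve(request)     # hypothetical backend call
        session.push_response(summary)         # stored as an Item with token counts
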