Skip to content

Commit 02bfdd1

Browse files
committed
Fix hand in item creation code
1 parent 3013678 commit 02bfdd1

File tree

2 files changed

+17
-25
lines changed

2 files changed

+17
-25
lines changed

src/guidellm/preprocess/item.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,3 @@ def __getitem__(self, key) -> Union[Item[PromptT], Sequence[Item[PromptT]]]:
4242

4343
def __len__(self) -> int:
4444
return len(self._items)
45-
46-
@classmethod
47-
def from_lists(
48-
cls,
49-
prompts: list[PromptT],
50-
prompts_tokens: list[Optional[int]],
51-
outputs_tokens: list[Optional[int]],
52-
) -> "ItemList":
53-
return cls(
54-
*[
55-
Item(value=prompt, output_tokens=in_t, prompt_tokens=out_t)
56-
for prompt, in_t, out_t in zip(
57-
prompts, prompts_tokens, outputs_tokens, strict=True
58-
)
59-
]
60-
)

src/guidellm/request/loader.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from guidellm.dataset import ColumnInputTypes, load_dataset
1515
from guidellm.objects import StandardBaseModel
16-
from guidellm.preprocess.item import ItemList
16+
from guidellm.preprocess.item import Item, ItemList
1717
from guidellm.request.session import GenerativeRequestSession
1818

1919
__all__ = [
@@ -261,16 +261,24 @@ def _get_dataset_iter(
261261
return dataset_iter
262262

263263
def _create_items(self, item: dict[str, Any]) -> ItemList:
264-
prompts = list(item[self.column_mappings["prompt_column"]])
265-
prompt_tokens: list[Optional[int]] = (
266-
list(item[self.column_mappings["prompt_tokens_count_column"]])
264+
prompts = item[self.column_mappings["prompt_column"]]
265+
prompt_tokens = (
266+
item[self.column_mappings["prompt_tokens_count_column"]]
267267
if "prompt_tokens_count_column" in self.column_mappings
268-
else [None]
268+
else None
269269
)
270-
output_tokens: list[Optional[int]] = (
271-
list(item[self.column_mappings["output_tokens_count_column"]])
270+
output_tokens = (
271+
item[self.column_mappings["output_tokens_count_column"]]
272272
if "output_tokens_count_column" in self.column_mappings
273-
else [None]
273+
else None
274274
)
275275

276-
return ItemList.from_lists(prompts, prompt_tokens, output_tokens)
276+
items = (
277+
Item(value=prompt, output_tokens=out_t, prompt_tokens=in_t)
278+
for prompt, in_t, out_t in zip(
279+
prompts if isinstance(prompts, list) else [prompts],
280+
prompt_tokens if isinstance(prompt_tokens, list) else [prompt_tokens],
281+
output_tokens if isinstance(output_tokens, list) else [output_tokens],
282+
)
283+
)
284+
return ItemList(*items)

0 commit comments

Comments
 (0)