|
13 | 13 |
|
14 | 14 | from guidellm.dataset import ColumnInputTypes, load_dataset
|
15 | 15 | from guidellm.objects import StandardBaseModel
|
16 |
| -from guidellm.preprocess.item import ItemList |
| 16 | +from guidellm.preprocess.item import Item, ItemList |
17 | 17 | from guidellm.request.session import GenerativeRequestSession
|
18 | 18 |
|
19 | 19 | __all__ = [
|
@@ -261,16 +261,24 @@ def _get_dataset_iter(
|
261 | 261 | return dataset_iter
|
262 | 262 |
|
263 | 263 | def _create_items(self, item: dict[str, Any]) -> ItemList:
|
264 |
| - prompts = list(item[self.column_mappings["prompt_column"]]) |
265 |
| - prompt_tokens: list[Optional[int]] = ( |
266 |
| - list(item[self.column_mappings["prompt_tokens_count_column"]]) |
| 264 | + prompts = item[self.column_mappings["prompt_column"]] |
| 265 | + prompt_tokens = ( |
| 266 | + item[self.column_mappings["prompt_tokens_count_column"]] |
267 | 267 | if "prompt_tokens_count_column" in self.column_mappings
|
268 |
| - else [None] |
| 268 | + else None |
269 | 269 | )
|
270 |
| - output_tokens: list[Optional[int]] = ( |
271 |
| - list(item[self.column_mappings["output_tokens_count_column"]]) |
| 270 | + output_tokens = ( |
| 271 | + item[self.column_mappings["output_tokens_count_column"]] |
272 | 272 | if "output_tokens_count_column" in self.column_mappings
|
273 |
| - else [None] |
| 273 | + else None |
274 | 274 | )
|
275 | 275 |
|
276 |
| - return ItemList.from_lists(prompts, prompt_tokens, output_tokens) |
| 276 | + items = ( |
| 277 | + Item(value=prompt, output_tokens=out_t, prompt_tokens=in_t) |
| 278 | + for prompt, in_t, out_t in zip( |
| 279 | + prompts if isinstance(prompts, list) else [prompts], |
| 280 | + prompt_tokens if isinstance(prompt_tokens, list) else [prompt_tokens], |
| 281 | + output_tokens if isinstance(output_tokens, list) else [output_tokens], |
| 282 | + ) |
| 283 | + ) |
| 284 | + return ItemList(*items) |
0 commit comments