Commit 7efb7b1

Add basic multiturn loader support
Signed-off-by: Samuel Monson <[email protected]>
1 parent: 1fa186f

src/guidellm/request/loader.py

Lines changed: 29 additions & 18 deletions
@@ -105,14 +105,14 @@ def __init__(
         self.preserve_iter_state = iter_type == "infinite"  # ensure no caching requests
         self._preserved_iter = None
 
-    def __iter__(self) -> Iterator[GenerationRequest]:
+    def __iter__(self) -> Iterator[list[GenerationRequest]]:
         scope_create_count = 0
 
         while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None:
             scope_create_count += 1
 
             for item in dataset_iter:
-                yield self._create_request(item)
+                yield self._create_requests(item)
 
         self._preserved_iter = None
 
@@ -260,25 +260,36 @@ def _get_dataset_iter(
 
         return dataset_iter
 
-    def _create_request(self, item: dict[str, Any]) -> GenerationRequest:
-        prompt_tokens = (
-            item[self.column_mappings["prompt_tokens_count_column"]]
+    def _create_requests(self, item: dict[str, Any]) -> list[GenerationRequest]:
+        prompts = list(item[self.column_mappings["prompt_column"]])
+        prompts_tokens: list[Optional[int]] = (
+            list(item[self.column_mappings["prompt_tokens_count_column"]])
             if "prompt_tokens_count_column" in self.column_mappings
-            else None
+            else [None] * len(prompts)
         )
-        output_tokens = (
-            item[self.column_mappings["output_tokens_count_column"]]
+        outputs_tokens: list[Optional[int]] = (
+            list(item[self.column_mappings["output_tokens_count_column"]])
             if "output_tokens_count_column" in self.column_mappings
-            else None
+            else [None] * len(prompts)
         )
 
-        return GenerationRequest(
-            request_type=settings.preferred_route,
-            content=item[self.column_mappings["prompt_column"]],
-            stats=(
+        if len(prompts) != len(prompts_tokens) != len(outputs_tokens):
+            raise ValueError(
+                "Mismatched lengths between prompts and token counts. "
+                f"Prompts: {len(prompts)}, Prompt Tokens: {len(prompts_tokens)}, "
+                f"Output Tokens: {len(outputs_tokens)}"
+            )
+
+        return [
+            GenerationRequest(
+                request_type=settings.preferred_route,
+                content=prompt,
+                stats=(
                 {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
-            ),
-            constraints=(
-                {"output_tokens": output_tokens} if output_tokens is not None else {}
-            ),
-        )
+                ),
+                constraints=(
+                    {"output_tokens": output_tokens} if output_tokens is not None else {}
+                ),
+            )
+            for prompt, prompt_tokens, output_tokens in zip(prompts, prompts_tokens, outputs_tokens)
+        ]
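For context, here is a minimal standalone sketch of the expansion this commit introduces, assuming a hypothetical multiturn dataset row whose prompt column holds one entry per conversation turn. The simplified GenerationRequest stand-in, the literal "text_completions" request type (standing in for settings.preferred_route), and the plain column names are illustrative assumptions, not part of the commit:

# Standalone sketch; GenerationRequest here is a simplified stand-in for
# guidellm's class, not the real implementation.
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class GenerationRequest:
    request_type: str
    content: str
    stats: dict[str, int] = field(default_factory=dict)
    constraints: dict[str, int] = field(default_factory=dict)


def create_requests(item: dict[str, Any]) -> list[GenerationRequest]:
    # Mirror the committed logic: every column now yields a list with one
    # entry per turn, padded with None when a count column is missing.
    prompts = list(item["prompt"])
    prompt_counts: list[Optional[int]] = (
        list(item["prompt_tokens_count"])
        if "prompt_tokens_count" in item
        else [None] * len(prompts)
    )
    output_counts: list[Optional[int]] = (
        list(item["output_tokens_count"])
        if "output_tokens_count" in item
        else [None] * len(prompts)
    )
    return [
        GenerationRequest(
            request_type="text_completions",  # assumed stand-in value
            content=prompt,
            stats={"prompt_tokens": pt} if pt is not None else {},
            constraints={"output_tokens": ot} if ot is not None else {},
        )
        for prompt, pt, ot in zip(prompts, prompt_counts, output_counts)
    ]


# A two-turn row expands into a single batch of two requests:
row = {"prompt": ["Hi!", "Tell me more."], "output_tokens_count": [16, 64]}
for request in create_requests(row):
    print(request.content, request.constraints)

Note that, per the new Iterator[list[GenerationRequest]] signature, __iter__ now yields each such list as one item, so consumers of the loader receive a whole multiturn conversation per dataset row rather than individual requests.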
