@@ -105,14 +105,14 @@ def __init__(
         self.preserve_iter_state = iter_type == "infinite"  # ensure no caching requests
         self._preserved_iter = None
 
-    def __iter__(self) -> Iterator[GenerationRequest]:
+    def __iter__(self) -> Iterator[list[GenerationRequest]]:
         scope_create_count = 0
 
         while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None:
             scope_create_count += 1
 
             for item in dataset_iter:
-                yield self._create_request(item)
+                yield self._create_requests(item)
 
         self._preserved_iter = None
 
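Read alongside the `__iter__` signature change above, here is a minimal consumer-side sketch of what the new yield shape implies. `FakeGenerationRequest` is a hypothetical stand-in mirroring only the fields used in this diff, not the library's real `GenerationRequest` class:

```python
from dataclasses import dataclass, field
from typing import Any, Iterator


@dataclass
class FakeGenerationRequest:
    # Hypothetical stand-in, not the library's GenerationRequest.
    content: str
    stats: dict[str, Any] = field(default_factory=dict)
    constraints: dict[str, Any] = field(default_factory=dict)


def consume(batches: Iterator[list[FakeGenerationRequest]]) -> None:
    # Each iteration now yields the full list of requests built from one
    # dataset row, so callers loop over the inner list instead of handling
    # a single request per row.
    for requests in batches:
        for request in requests:
            print(request.content, request.stats, request.constraints)
```

Any caller that previously treated each yielded value as a single request needs a similar inner loop (or a flattening step) after this change.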
@@ -260,25 +260,36 @@ def _get_dataset_iter(
 
         return dataset_iter
 
-    def _create_request(self, item: dict[str, Any]) -> GenerationRequest:
-        prompt_tokens = (
-            item[self.column_mappings["prompt_tokens_count_column"]]
+    def _create_requests(self, item: dict[str, Any]) -> list[GenerationRequest]:
+        prompts = list(item[self.column_mappings["prompt_column"]])
+        prompts_tokens: list[Optional[int]] = (
+            list(item[self.column_mappings["prompt_tokens_count_column"]])
             if "prompt_tokens_count_column" in self.column_mappings
-            else None
+            else [None] * len(prompts)
         )
-        output_tokens = (
-            item[self.column_mappings["output_tokens_count_column"]]
+        outputs_tokens: list[Optional[int]] = (
+            list(item[self.column_mappings["output_tokens_count_column"]])
             if "output_tokens_count_column" in self.column_mappings
-            else None
+            else [None] * len(prompts)
         )
 
-        return GenerationRequest(
-            request_type=settings.preferred_route,
-            content=item[self.column_mappings["prompt_column"]],
-            stats=(
+        if not (len(prompts) == len(prompts_tokens) == len(outputs_tokens)):
+            raise ValueError(
+                "Mismatched lengths between prompts and token counts. "
+                f"Prompts: {len(prompts)}, Prompt Tokens: {len(prompts_tokens)}, "
+                f"Output Tokens: {len(outputs_tokens)}"
+            )
+
+        return [
+            GenerationRequest(
+                request_type=settings.preferred_route,
+                content=prompt,
+                stats=(
                 {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
-            ),
-            constraints=(
-                {"output_tokens": output_tokens} if output_tokens is not None else {}
-            ),
-        )
+                ),
+                constraints=(
+                    {"output_tokens": output_tokens} if output_tokens is not None else {}
+                ),
+            )
+            for prompt, prompt_tokens, output_tokens in zip(prompts, prompts_tokens, outputs_tokens)
+        ]
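For context on the second hunk, a self-contained sketch of the fan-out behavior, under the assumption that the prompt column (and any token-count columns) hold per-row lists of equal length; the dict payloads and column names here are illustrative, not the project's API:

```python
from typing import Any, Optional


def fan_out(item: dict[str, Any]) -> list[dict[str, Any]]:
    """Mimic _create_requests: one dataset row with N prompts becomes N request payloads."""
    prompts = list(item["prompt"])
    prompt_tokens: list[Optional[int]] = list(item.get("prompt_tokens", [None] * len(prompts)))
    output_tokens: list[Optional[int]] = list(item.get("output_tokens", [None] * len(prompts)))

    # Reject mismatched columns up front; zip() would otherwise truncate silently.
    if not (len(prompts) == len(prompt_tokens) == len(output_tokens)):
        raise ValueError("Mismatched lengths between prompts and token counts.")

    return [
        {
            "content": prompt,
            "stats": {"prompt_tokens": p} if p is not None else {},
            "constraints": {"output_tokens": o} if o is not None else {},
        }
        for prompt, p, o in zip(prompts, prompt_tokens, output_tokens)
    ]


# Example: one row carrying two prompts yields two request payloads.
row = {"prompt": ["Hello", "World"], "prompt_tokens": [2, 3], "output_tokens": [16, 16]}
assert len(fan_out(row)) == 2
```

The length guard mirrors the `ValueError` added in the diff: per-item token counts are kept aligned with their prompts before the requests are built.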