2727from requests import Response
2828
2929
class ModelSelectionStrategy(Enum):
    """How a model name is picked for each generated request.

    Used when more than one model name is supplied, to decide which
    name is attached to a given dataset row.
    """

    # Walk the supplied names in order, wrapping around by row index.
    ROUND_ROBIN = auto()
    # Pick one of the supplied names at random for every row.
    RANDOM = auto()
33+
34+
3035class PromptSource (Enum ):
3136 SYNTHETIC = auto ()
3237 DATASET = auto ()
@@ -78,7 +83,8 @@ def create_llm_inputs(
7883 input_type : PromptSource ,
7984 output_format : OutputFormat ,
8085 dataset_name : str = "" ,
81- model_name : str = "" ,
86+ model_name : list = [],
87+ model_selection_strategy : ModelSelectionStrategy = ModelSelectionStrategy .ROUND_ROBIN ,
8288 input_filename : Optional [Path ] = Path ("" ),
8389 starting_index : int = DEFAULT_STARTING_INDEX ,
8490 length : int = DEFAULT_LENGTH ,
@@ -194,6 +200,7 @@ def create_llm_inputs(
194200 output_tokens_stddev ,
195201 output_tokens_deterministic ,
196202 model_name ,
203+ model_selection_strategy ,
197204 )
198205 cls ._write_json_to_file (json_in_pa_format , output_dir )
199206
@@ -354,7 +361,8 @@ def _convert_generic_json_to_output_format(
354361 output_tokens_mean : int ,
355362 output_tokens_stddev : int ,
356363 output_tokens_deterministic : bool ,
357- model_name : str = "" ,
364+ model_name : list = [],
365+ model_selection_strategy : ModelSelectionStrategy = ModelSelectionStrategy .ROUND_ROBIN ,
358366 ) -> Dict :
359367 if output_format == OutputFormat .OPENAI_CHAT_COMPLETIONS :
360368 output_json = cls ._convert_generic_json_to_openai_chat_completions_format (
@@ -366,6 +374,7 @@ def _convert_generic_json_to_output_format(
366374 output_tokens_stddev ,
367375 output_tokens_deterministic ,
368376 model_name ,
377+ model_selection_strategy ,
369378 )
370379 elif output_format == OutputFormat .OPENAI_COMPLETIONS :
371380 output_json = cls ._convert_generic_json_to_openai_completions_format (
@@ -377,6 +386,7 @@ def _convert_generic_json_to_output_format(
377386 output_tokens_stddev ,
378387 output_tokens_deterministic ,
379388 model_name ,
389+ model_selection_strategy ,
380390 )
381391 elif output_format == OutputFormat .VLLM :
382392 output_json = cls ._convert_generic_json_to_vllm_format (
@@ -388,6 +398,7 @@ def _convert_generic_json_to_output_format(
388398 output_tokens_stddev ,
389399 output_tokens_deterministic ,
390400 model_name ,
401+ model_selection_strategy ,
391402 )
392403 elif output_format == OutputFormat .TENSORRTLLM :
393404 output_json = cls ._convert_generic_json_to_trtllm_format (
@@ -399,6 +410,7 @@ def _convert_generic_json_to_output_format(
399410 output_tokens_stddev ,
400411 output_tokens_deterministic ,
401412 model_name ,
413+ model_selection_strategy ,
402414 )
403415 else :
404416 raise GenAIPerfException (
@@ -417,7 +429,8 @@ def _convert_generic_json_to_openai_chat_completions_format(
417429 output_tokens_mean : int ,
418430 output_tokens_stddev : int ,
419431 output_tokens_deterministic : bool ,
420- model_name : str = "" ,
432+ model_name : list = [],
433+ model_selection_strategy : ModelSelectionStrategy = ModelSelectionStrategy .ROUND_ROBIN ,
421434 ) -> Dict :
422435 # TODO (TMA-1757): Implement a way to select a role for `text_input`
423436 (
@@ -436,6 +449,7 @@ def _convert_generic_json_to_openai_chat_completions_format(
436449 output_tokens_stddev ,
437450 output_tokens_deterministic ,
438451 model_name ,
452+ model_selection_strategy ,
439453 )
440454
441455 return pa_json
@@ -450,7 +464,8 @@ def _convert_generic_json_to_openai_completions_format(
450464 output_tokens_mean : int ,
451465 output_tokens_stddev : int ,
452466 output_tokens_deterministic : bool ,
453- model_name : str = "" ,
467+ model_name : list = [],
468+ model_selection_strategy : ModelSelectionStrategy = ModelSelectionStrategy .ROUND_ROBIN ,
454469 ) -> Dict :
455470 (
456471 system_role_headers ,
@@ -469,6 +484,7 @@ def _convert_generic_json_to_openai_completions_format(
469484 output_tokens_stddev ,
470485 output_tokens_deterministic ,
471486 model_name ,
487+ model_selection_strategy ,
472488 )
473489
474490 return pa_json
@@ -483,7 +499,8 @@ def _convert_generic_json_to_vllm_format(
483499 output_tokens_mean : int ,
484500 output_tokens_stddev : int ,
485501 output_tokens_deterministic : bool ,
486- model_name : str = "" ,
502+ model_name : list = [],
503+ model_selection_strategy : ModelSelectionStrategy = ModelSelectionStrategy .ROUND_ROBIN ,
487504 ) -> Dict :
488505 (
489506 system_role_headers ,
@@ -503,6 +520,7 @@ def _convert_generic_json_to_vllm_format(
503520 output_tokens_stddev ,
504521 output_tokens_deterministic ,
505522 model_name ,
523+ model_selection_strategy ,
506524 )
507525
508526 return pa_json
@@ -517,7 +535,8 @@ def _convert_generic_json_to_trtllm_format(
517535 output_tokens_mean : int ,
518536 output_tokens_stddev : int ,
519537 output_tokens_deterministic : bool ,
520- model_name : str = "" ,
538+ model_name : list = [],
539+ model_selection_strategy : ModelSelectionStrategy = ModelSelectionStrategy .ROUND_ROBIN ,
521540 ) -> Dict :
522541 (
523542 system_role_headers ,
@@ -537,6 +556,7 @@ def _convert_generic_json_to_trtllm_format(
537556 output_tokens_stddev ,
538557 output_tokens_deterministic ,
539558 model_name ,
559+ model_selection_strategy ,
540560 )
541561
542562 return pa_json
@@ -577,6 +597,17 @@ def _determine_json_feature_roles(
577597
578598 return system_role_headers , user_role_headers , text_input_headers
579599
@classmethod
def _select_model_name(cls, model_name, index: int, model_selection_strategy) -> str:
    """Pick the model name to use for the dataset row at ``index``.

    Args:
        model_name: The candidate model names. A list of names is
            expected; a plain string is accepted as well and treated as
            a single candidate, so that a caller still passing one name
            does not get a single character back from indexing.
        index: Zero-based position of the row being populated. Only
            consulted by the round-robin strategy.
        model_selection_strategy: The ``ModelSelectionStrategy`` member
            that decides how a candidate is chosen.

    Returns:
        The selected model name.

    Raises:
        GenAIPerfException: If the strategy is not supported, or if no
            candidate model names were supplied.
    """
    candidates = [model_name] if isinstance(model_name, str) else model_name

    use_round_robin = model_selection_strategy == ModelSelectionStrategy.ROUND_ROBIN
    if not use_round_robin and model_selection_strategy != ModelSelectionStrategy.RANDOM:
        raise GenAIPerfException(
            f"Model selection strategy '{ model_selection_strategy } ' is unsupported"
        )

    # An empty candidate list would otherwise surface as a bare
    # ZeroDivisionError (round robin) or IndexError (random); report the
    # real problem instead.
    if not candidates:
        raise GenAIPerfException(
            "At least one model name is required to select a model"
        )

    if use_round_robin:
        # Wrap around so rows beyond the number of names reuse them in order.
        return candidates[index % len(candidates)]
    return random.choice(candidates)
610+
580611 @classmethod
581612 def _populate_openai_chat_completions_output_json (
582613 cls ,
@@ -589,11 +620,15 @@ def _populate_openai_chat_completions_output_json(
589620 output_tokens_mean : int ,
590621 output_tokens_stddev : int ,
591622 output_tokens_deterministic : bool ,
592- model_name : str = "" ,
623+ model_name : list = [],
624+ model_selection_strategy : ModelSelectionStrategy = ModelSelectionStrategy .ROUND_ROBIN ,
593625 ) -> Dict :
594626 pa_json = cls ._create_empty_openai_pa_json ()
595627
596628 for index , entry in enumerate (dataset_json ["rows" ]):
629+ iter_model_name = cls ._select_model_name (
630+ model_name , index , model_selection_strategy
631+ )
597632 pa_json ["data" ].append ({"payload" : []})
598633 pa_json ["data" ][index ]["payload" ].append ({"messages" : []})
599634
@@ -613,7 +648,7 @@ def _populate_openai_chat_completions_output_json(
613648 output_tokens_mean ,
614649 output_tokens_stddev ,
615650 output_tokens_deterministic ,
616- model_name ,
651+ iter_model_name ,
617652 )
618653
619654 return pa_json
@@ -631,11 +666,15 @@ def _populate_openai_completions_output_json(
631666 output_tokens_mean : int ,
632667 output_tokens_stddev : int ,
633668 output_tokens_deterministic : bool ,
634- model_name : str = "" ,
669+ model_name : list = [],
670+ model_selection_strategy : ModelSelectionStrategy = ModelSelectionStrategy .ROUND_ROBIN ,
635671 ) -> Dict :
636672 pa_json = cls ._create_empty_openai_pa_json ()
637673
638674 for index , entry in enumerate (dataset_json ["rows" ]):
675+ iter_model_name = cls ._select_model_name (
676+ model_name , index , model_selection_strategy
677+ )
639678 pa_json ["data" ].append ({"payload" : []})
640679 pa_json ["data" ][index ]["payload" ].append ({"prompt" : "" })
641680
@@ -659,7 +698,7 @@ def _populate_openai_completions_output_json(
659698 output_tokens_mean ,
660699 output_tokens_stddev ,
661700 output_tokens_deterministic ,
662- model_name ,
701+ iter_model_name ,
663702 )
664703
665704 return pa_json
@@ -677,11 +716,15 @@ def _populate_vllm_output_json(
677716 output_tokens_mean : int ,
678717 output_tokens_stddev : int ,
679718 output_tokens_deterministic : bool ,
680- model_name : str = "" ,
719+ model_name : list = [],
720+ model_selection_strategy : ModelSelectionStrategy = ModelSelectionStrategy .ROUND_ROBIN ,
681721 ) -> Dict :
682722 pa_json = cls ._create_empty_vllm_pa_json ()
683723
684724 for index , entry in enumerate (dataset_json ["rows" ]):
725+ iter_model_name = cls ._select_model_name (
726+ model_name , index , model_selection_strategy
727+ )
685728 pa_json ["data" ].append ({"text_input" : ["" ]})
686729
687730 for header , content in entry .items ():
@@ -706,7 +749,7 @@ def _populate_vllm_output_json(
706749 output_tokens_mean ,
707750 output_tokens_stddev ,
708751 output_tokens_deterministic ,
709- model_name ,
752+ iter_model_name ,
710753 )
711754
712755 return pa_json
@@ -724,7 +767,8 @@ def _populate_trtllm_output_json(
724767 output_tokens_mean : int ,
725768 output_tokens_stddev : int ,
726769 output_tokens_deterministic : bool ,
727- model_name : str = "" ,
770+ model_name : list = [],
771+ model_selection_strategy : ModelSelectionStrategy = ModelSelectionStrategy .ROUND_ROBIN ,
728772 ) -> Dict :
729773 pa_json = cls ._create_empty_trtllm_pa_json ()
730774 default_max_tokens = (
@@ -733,6 +777,9 @@ def _populate_trtllm_output_json(
733777 )
734778
735779 for index , entry in enumerate (dataset_json ["rows" ]):
780+ iter_model_name = cls ._select_model_name (
781+ model_name , index , model_selection_strategy
782+ )
736783 pa_json ["data" ].append ({"text_input" : ["" ]})
737784
738785 for header , content in entry .items ():
@@ -760,7 +807,7 @@ def _populate_trtllm_output_json(
760807 output_tokens_mean ,
761808 output_tokens_stddev ,
762809 output_tokens_deterministic ,
763- model_name ,
810+ iter_model_name ,
764811 )
765812
766813 return pa_json
0 commit comments