@@ -573,6 +573,7 @@ def validate_model(
573573 if verbose :
574574 print (f"[validate_model] new inputs: { string_type (data ['inputs' ])} " )
575575 print (f"[validate_model] new dynamic_hapes: { string_type (data ['dynamic_shapes' ])} " )
576+ # NOTE: The dynamic_shapes is always the same across inputs sets
576577 if inputs2 :
577578 assert (
578579 "inputs2" in data
@@ -583,6 +584,14 @@ def validate_model(
583584 model = data ["model" ],
584585 dynamic_shapes = data ["dynamic_shapes" ],
585586 )
587+ # NOTE: text-generation tests 3rd inputs for multi-turn conversation
588+ if "inputs3" in data :
589+ data ["inputs3" ], _ = filter_inputs (
590+ data ["inputs3" ],
591+ drop_names = drop_inputs ,
592+ model = data ["model" ],
593+ dynamic_shapes = data ["dynamic_shapes" ],
594+ )
586595
587596 if not empty (dtype ):
588597 if isinstance (dtype , str ):
@@ -594,6 +603,8 @@ def validate_model(
594603 summary ["model_dtype" ] = str (dtype )
595604 if "inputs2" in data :
596605 data ["inputs2" ] = to_any (data ["inputs2" ], dtype ) # type: ignore
606+ if "inputs3" in data :
607+ data ["inputs3" ] = to_any (data ["inputs3" ], dtype ) # type: ignore
597608
598609 if not empty (device ):
599610 if verbose :
@@ -603,6 +614,8 @@ def validate_model(
603614 summary ["model_device" ] = str (device )
604615 if "inputs2" in data :
605616 data ["inputs2" ] = to_any (data ["inputs2" ], device ) # type: ignore
617+ if "inputs3" in data :
618+ data ["inputs3" ] = to_any (data ["inputs3" ], device ) # type: ignore
606619
607620 for k in ["task" , "size" , "n_weights" ]:
608621 summary [f"model_{ k .replace ('_' ,'' )} " ] = data [k ]
@@ -638,10 +651,12 @@ def validate_model(
638651 _validate_do_run_model (
639652 data , summary , "inputs" , "run" , "run_expected" , verbose , repeat , warmup , quiet
640653 )
641- if inputs2 :
642- _validate_do_run_model (
643- data , summary , "inputs2" , "run2" , "run_expected2" , verbose , 1 , 0 , quiet
644- )
654+ _validate_do_run_model (
655+ data , summary , "inputs2" , "run2" , "run_expected2" , verbose , 1 , 0 , quiet
656+ )
657+ _validate_do_run_model (
658+ data , summary , "inputs3" , "run3" , "run_expected3" , verbose , 1 , 0 , quiet
659+ )
645660
646661 if exporter :
647662 print (
@@ -899,6 +914,10 @@ def _validate_do_run_model(
899914 if verbose :
900915 print (f"[validate_model] -- run the model inputs={ key !r} ..." )
901916 print (f"[validate_model] { key } ={ string_type (data [key ], with_shape = True )} " )
917+ if key not in data :
918+ if verbose :
919+ print (f"[validate_model] input; { key !r} not defined, skip." )
920+ return
902921 # We make a copy of the input just in case the model modifies them inplace
903922 hash_inputs = string_type (data [key ], with_shape = True )
904923 inputs = torch_deepcopy (data [key ])
@@ -1329,6 +1348,9 @@ def _mk(key, flavour=flavour):
13291348 keys = [("inputs" , "run_expected" , "" )]
13301349 if inputs2 :
13311350 keys .append (("inputs2" , "run_expected2" , "2" ))
1351+ # text-generation tests multi-turn conversation as 3rd inputs
1352+ if "inputs3" in data :
1353+ keys .append (("inputs3" , "run_expected3" , "3" ))
13321354 for k_input , k_expected , suffix in keys :
13331355 # make_feeds
13341356 if verbose :
0 commit comments