@@ -48,6 +48,7 @@ def test_text_generation(self):
4848 model , (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds ), strict = False
4949 )
5050
51+ @hide_stdout ()
5152 def test_text_generation_empty_cache (self ):
5253 mid = "arnir0/Tiny-LLM"
5354 data = get_untrained_model_with_inputs (mid , add_second_input = True )
@@ -69,6 +70,28 @@ def test_text_generation_empty_cache(self):
6970 got = ep .module ()(** torch_deepcopy (inputs ))
7071 self .assertEqualArrayAny (expected , got )
7172
73+ @hide_stdout ()
74+ def test_text_generation_batch1 (self ):
75+ mid = "arnir0/Tiny-LLM"
76+ data = get_untrained_model_with_inputs (mid , add_second_input = True )
77+ model , inputs = data ["model" ], data ["inputs" ]
78+ self .assertIn ("inputs_batch1" , data )
79+ empty_inputs = torch_deepcopy (data ["inputs_batch1" ])
80+ model (** torch_deepcopy (empty_inputs ))
81+ expected = model (** torch_deepcopy (inputs ))
82+ self .assertEqual (
83+ {"attention_mask" , "past_key_values" , "input_ids" , "position_ids" }, set (inputs )
84+ )
85+ with torch_export_patches (patch_transformers = True , verbose = 1 ):
86+ ep = torch .export .export (
87+ model ,
88+ (),
89+ kwargs = torch_deepcopy (inputs ),
90+ dynamic_shapes = use_dyn_not_str (data ["dynamic_shapes" ]),
91+ )
92+ got = ep .module ()(** torch_deepcopy (inputs ))
93+ self .assertEqualArrayAny (expected , got )
94+
7295 @hide_stdout ()
7396 def test_automatic_speech_recognition_float32 (self ):
7497 mid = "openai/whisper-tiny"
0 commit comments