1010from onnx_diagnostic .torch_models .hghub .model_inputs import get_untrained_model_with_inputs
1111from onnx_diagnostic .torch_export_patches import torch_export_patches
1212from onnx_diagnostic .torch_export_patches .patch_inputs import use_dyn_not_str
13+ from onnx_diagnostic .export .shape_helper import make_fake_with_dynamic_dimensions
1314
1415
1516class TestTasksTextGeneration (ExtTestCase ):
1617 @hide_stdout ()
1718 @requires_transformers ("4.53" )
1819 @requires_torch ("2.7.99" )
19- def test_tet_generation_gemma3_for_causallm (self ):
20+ def test_text_generation_gemma3_for_causallm (self ):
2021 mid = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
2122 data = get_untrained_model_with_inputs (mid , verbose = 1 , add_second_input = True )
2223 self .assertEqual (data ["task" ], "text-generation" )
@@ -31,20 +32,38 @@ def test_tet_generation_gemma3_for_causallm(self):
3132 @hide_stdout ()
3233 @requires_transformers ("4.53" )
3334 @requires_torch ("2.7.99" )
34- def test_itext_generation_phi_3_mini_128k_instruct (self ):
35+ def test_text_generation_phi_3_mini_128k_instruct (self ):
3536 mid = "microsoft/Phi-3-mini-128k-instruct"
3637 data = get_untrained_model_with_inputs (mid , verbose = 1 , add_second_input = True )
3738 self .assertEqual (data ["task" ], "text-generation" )
3839 model , inputs , ds = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
39- print ("--" , self .string_type (inputs , with_shape = True ))
40- print ("--" , self .string_type (ds ))
4140 model (** torch_deepcopy (inputs ))
4241 model (** data ["inputs2" ])
4342 with torch_export_patches (patch_transformers = True , verbose = 10 , patch_torch = False ):
4443 torch .export .export (
4544 model , (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds ), strict = False
4645 )
4746
47+ @hide_stdout ()
48+ @requires_transformers ("4.53" )
49+ @requires_torch ("2.7.99" )
50+ def test_text_generation_tiny_llm (self ):
51+ mid = "arnir0/Tiny-LLM"
52+ data = get_untrained_model_with_inputs (mid , verbose = 1 , add_second_input = True )
53+ self .assertEqual (data ["task" ], "text-generation" )
54+ model , inputs , ds = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
55+ expected = model (** torch_deepcopy (inputs ))
56+ model (** data ["inputs2" ])
57+ fake = make_fake_with_dynamic_dimensions (inputs , dynamic_shapes = ds )[0 ]
58+ with torch_export_patches (patch_transformers = True , verbose = 10 , patch_torch = False ):
59+ ep = torch .export .export (
60+ model , (), kwargs = fake , dynamic_shapes = use_dyn_not_str (ds ), strict = False
61+ )
62+ # print(ep)
63+ got = ep .module ()(** inputs )
64+ self .assertEqualAny (expected .past_key_values , got .past_key_values )
65+ self .assertEqualArray (expected .logits , got .logits )
66+
4867
4968if __name__ == "__main__" :
5069 unittest .main (verbosity = 2 )
0 commit comments