 import unittest
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers, has_torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
@@ -10,11 +10,27 @@ class TestTasks(ExtTestCase):
     @hide_stdout()
     def test_text2text_generation(self):
         mid = "sshleifer/tiny-marian-en-de"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "text2text-generation")
         self.assertIn((data["size"], data["n_weights"]), [(473928, 118482)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         raise unittest.SkipTest(f"not working for {mid!r}")
         model(**inputs)
+        model(**data["inputs2"])
+        with bypass_export_some_errors(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
+    @hide_stdout()
+    def test_text_generation(self):
+        mid = "arnir0/Tiny-LLM"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "text-generation")
+        self.assertIn((data["size"], data["n_weights"]), [(51955968, 12988992)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**inputs)
+        model(**data["inputs2"])
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -23,9 +39,11 @@ def test_text2text_generation(self):
     @hide_stdout()
     def test_automatic_speech_recognition(self):
         mid = "openai/whisper-tiny"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "automatic-speech-recognition")
         self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**data["inputs2"])
         Dim = torch.export.Dim
         self.maxDiff = None
         self.assertIn("{0:Dim(batch),1:DYN(seq_length)}", self.string_type(ds))
@@ -90,27 +108,15 @@ def test_automatic_speech_recognition(self):
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
 
-    @hide_stdout()
-    def test_imagetext2text_generation(self):
-        mid = "HuggingFaceM4/tiny-random-idefics"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
-        self.assertIn((data["size"], data["n_weights"]), [(12742888, 3185722)])
-        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**inputs)
-        if not has_torch("2.10"):
-            raise unittest.SkipTest("sym_max does not work with dynamic dimension")
-        with bypass_export_some_errors(patch_transformers=True, verbose=10):
-            torch.export.export(
-                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
-            )
-
     @hide_stdout()
     def test_fill_mask(self):
         mid = "google-bert/bert-base-multilingual-cased"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "fill-mask")
         self.assertIn((data["size"], data["n_weights"]), [(428383212, 107095803)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
+        model(**data["inputs2"])
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -119,10 +125,12 @@ def test_fill_mask(self):
     @hide_stdout()
     def test_feature_extraction(self):
         mid = "facebook/bart-base"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "feature-extraction")
         self.assertIn((data["size"], data["n_weights"]), [(557681664, 139420416)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
+        model(**data["inputs2"])
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -131,10 +139,12 @@ def test_feature_extraction(self):
     @hide_stdout()
     def test_text_classification(self):
         mid = "Intel/bert-base-uncased-mrpc"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "text-classification")
         self.assertIn((data["size"], data["n_weights"]), [(154420232, 38605058)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
+        model(**data["inputs2"])
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -143,10 +153,12 @@ def test_text_classification(self):
     @hide_stdout()
     def test_sentence_similary(self):
         mid = "sentence-transformers/all-MiniLM-L6-v1"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "sentence-similarity")
         self.assertIn((data["size"], data["n_weights"]), [(62461440, 15615360)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
+        model(**data["inputs2"])
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -155,9 +167,11 @@ def test_sentence_similary(self):
     @hide_stdout()
     def test_falcon_mamba_dev(self):
         mid = "tiiuae/falcon-mamba-tiny-dev"
-        data = get_untrained_model_with_inputs(mid, verbose=1)
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "text-generation")
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
+        model(**data["inputs2"])
         self.assertIn((data["size"], data["n_weights"]), [(138640384, 34660096)])
         if not has_transformers("4.55"):
             raise unittest.SkipTest("The model has control flow.")
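
For reference, a minimal sketch of the pattern shared by the updated tests, outside the unittest harness. The calls, keyword arguments, and dictionary keys ("model", "inputs", "inputs2", "dynamic_shapes") are taken from the diff above; the snippet assumes onnx_diagnostic, torch, and transformers are installed and that the Hugging Face model id can be resolved.

import torch
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

# Build an untrained copy of the model plus two sets of sample inputs.
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=1, add_second_input=True)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]

model(**inputs)           # first input set
model(**data["inputs2"])  # second input set produced by add_second_input=True

# Export under the patch context manager, passing the dynamic shapes
# through use_dyn_not_str as the tests do.
with bypass_export_some_errors(patch_transformers=True, verbose=10):
    torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
    )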