@@ -11,9 +11,11 @@ def test_text2text_generation(self):
1111 mid = "sshleifer/tiny-marian-en-de"
1212 data = get_untrained_model_with_inputs (mid , verbose = 1 )
1313 self .assertIn ((data ["size" ], data ["n_weights" ]), [(473928 , 118482 )])
14- model , inputs = data ["model" ], data ["inputs" ]
14+ model , inputs , ds = data ["model" ], data ["inputs" ], data [ "dynamic_shapes " ]
1515 raise unittest .SkipTest (f"not working for { mid !r} " )
1616 model (** inputs )
17+ with bypass_export_some_errors (patch_transformers = True , verbose = 10 ):
18+ torch .export .export (model , (), kwargs = inputs , dynamic_shapes = ds , strict = False )
1719
1820 @hide_stdout ()
1921 def test_automatic_speech_recognition (self ):
@@ -86,41 +88,50 @@ def test_imagetext2text_generation(self):
8688 mid = "HuggingFaceM4/tiny-random-idefics"
8789 data = get_untrained_model_with_inputs (mid , verbose = 1 )
8890 self .assertIn ((data ["size" ], data ["n_weights" ]), [(12742888 , 3185722 )])
89- model , inputs = data ["model" ], data ["inputs" ]
91+ model , inputs , ds = data ["model" ], data ["inputs" ], data [ "dynamic_shapes " ]
9092 model (** inputs )
93+ with bypass_export_some_errors (patch_transformers = True , verbose = 10 ):
94+ torch .export .export (model , (), kwargs = inputs , dynamic_shapes = ds , strict = False )
9195
9296 @hide_stdout ()
9397 def test_fill_mask (self ):
9498 mid = "google-bert/bert-base-multilingual-cased"
9599 data = get_untrained_model_with_inputs (mid , verbose = 1 )
96100 self .assertIn ((data ["size" ], data ["n_weights" ]), [(428383212 , 107095803 )])
97- model , inputs = data ["model" ], data ["inputs" ]
101+ model , inputs , ds = data ["model" ], data ["inputs" ], data [ "dynamic_shapes " ]
98102 model (** inputs )
103+ with bypass_export_some_errors (patch_transformers = True , verbose = 10 ):
104+ torch .export .export (model , (), kwargs = inputs , dynamic_shapes = ds , strict = False )
99105
100106 @hide_stdout ()
101107 def test_text_classification (self ):
102108 mid = "Intel/bert-base-uncased-mrpc"
103109 data = get_untrained_model_with_inputs (mid , verbose = 1 )
104110 self .assertIn ((data ["size" ], data ["n_weights" ]), [(154420232 , 38605058 )])
105- model , inputs = data ["model" ], data ["inputs" ]
111+ model , inputs , ds = data ["model" ], data ["inputs" ], data [ "dynamic_shapes " ]
106112 model (** inputs )
113+ with bypass_export_some_errors (patch_transformers = True , verbose = 10 ):
114+ torch .export .export (model , (), kwargs = inputs , dynamic_shapes = ds , strict = False )
107115
108116 @hide_stdout ()
109117 def test_sentence_similary (self ):
110118 mid = "sentence-transformers/all-MiniLM-L6-v1"
111119 data = get_untrained_model_with_inputs (mid , verbose = 1 )
112120 self .assertIn ((data ["size" ], data ["n_weights" ]), [(62461440 , 15615360 )])
113- model , inputs = data ["model" ], data ["inputs" ]
121+ model , inputs , ds = data ["model" ], data ["inputs" ], data [ "dynamic_shapes " ]
114122 model (** inputs )
123+ with bypass_export_some_errors (patch_transformers = True , verbose = 10 ):
124+ torch .export .export (model , (), kwargs = inputs , dynamic_shapes = ds , strict = False )
115125
116126 @hide_stdout ()
117127 def test_falcon_mamba_dev (self ):
118128 mid = "tiiuae/falcon-mamba-tiny-dev"
119129 data = get_untrained_model_with_inputs (mid , verbose = 1 )
120- model , inputs = data ["model" ], data ["inputs" ]
121- print (self .string_type (inputs , with_shape = True ))
130+ model , inputs , ds = data ["model" ], data ["inputs" ], data ["dynamic_shapes" ]
122131 model (** inputs )
123132 self .assertIn ((data ["size" ], data ["n_weights" ]), [(138640384 , 34660096 )])
133+ with bypass_export_some_errors (patch_transformers = True , verbose = 10 ):
134+ torch .export .export (model , (), kwargs = inputs , dynamic_shapes = ds , strict = False )
124135
125136
126137if __name__ == "__main__" :
0 commit comments