@@ -181,7 +181,7 @@ def get_tiny_llm(
181181# Let's get the model, inputs and dynamic shapes.
182182
183183experiment = get_tiny_llm ()
184- model , inputs , dynamic_shapes = (
184+ untrained_model , inputs , dynamic_shapes = (
185185 experiment ["model" ],
186186 experiment ["inputs" ],
187187 experiment ["dynamic_shapes" ],
@@ -198,7 +198,26 @@ def get_tiny_llm(
198198# +++++++++++++++
199199
200200try :
201- ep = torch .export .export (model , (), inputs , dynamic_shapes = dynamic_shapes )
201+ ep = torch .export .export (
202+ untrained_model , (), inputs , dynamic_shapes = dynamic_shapes , strict = False
203+ )
204+ print ("It worked:" )
205+ print (ep )
206+ except Exception as e :
207+ # To work, it needs at least PRs:
208+ # * https://github.com/huggingface/transformers/pull/36311
209+ # * https://github.com/huggingface/transformers/pull/36652
210+ print ("It failed:" , e )
211+
212+
213+ # %%
214+ # Back to the original model
215+ # ++++++++++++++++++++++++++
216+ #
217+ # Let's use the same dummy inputs but we use the downloaded model.
218+
219+ try :
220+ ep = torch .export .export (model , (), inputs , dynamic_shapes = dynamic_shapes , strict = False )
202221 print ("It worked:" )
203222 print (ep )
204223except Exception as e :
0 commit comments