We use the dummy example from the model page.
"""

- from typing import Any, Dict
+ import copy
import torch
import transformers
from onnx_diagnostic.helpers import string_type
- from onnx_diagnostic.cache_helpers import make_dynamic_cache
+ from onnx_diagnostic.torch_models.llms import get_tiny_llm


MODEL_NAME = "arnir0/Tiny-LLM"
# We rewrite the forward method to print the cache dimension.


- def string_inputs(args, kwargs):
-     def _cache(a):
-         if len(a.key_cache):
-             return f"n_caches={len(a.key_cache)}, shape={a.key_cache[0].shape}"
-         return f"n_caches={len(a.key_cache)}"
-
-     for a in args:
-         if isinstance(a, transformers.cache_utils.DynamicCache):
-             return _cache(a)
-     for k, a in kwargs.items():
-         if isinstance(a, transformers.cache_utils.DynamicCache):
-             return f"{k}={_cache(a)}"
-     return "no_cache"
-
-
def _forward_(*args, _f=None, **kwargs):
    assert _f is not None
    if not torch.compiler.is_exporting():
@@ -83,100 +68,6 @@ def _forward_(*args, _f=None, **kwargs):
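The rest of ``_forward_`` and the code that attaches it to the model fall in the lines this hunk skips. As a purely illustrative sketch (the actual file may do it differently), such a wrapper is usually bound by monkey-patching ``forward`` so that the original method ends up in the ``_f`` keyword:

# Hypothetical sketch, not taken from the diff: route every call to the model
# through _forward_ while keeping the original forward in the _f default.
def patch_forward(model):
    model.forward = lambda *args, _f=model.forward, **kwargs: _forward_(*args, _f=_f, **kwargs)
    return model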
# Let's create an untrained model.


- def get_tiny_llm(
-     batch_size: int = 2,
-     input_cache: bool = True,
-     common_dynamic_shapes: bool = True,
-     dynamic_rope: bool = False,
-     **kwargs,
- ) -> Dict[str, Any]:
-     """
-     Gets a non initialized model.
-
-     :param batch_size: batch size
-     :param input_cache: generate data for this iteration with or without cache
-     :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
-     :param common_dynamic_shapes: if True returns dynamic shapes as well
-     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
-     :return: dictionary
-     """
-     import transformers
-
-     config = {
-         "architectures": ["LlamaForCausalLM"],
-         "bos_token_id": 1,
-         "eos_token_id": 2,
-         "hidden_act": "silu",
-         "hidden_size": 192,
-         "initializer_range": 0.02,
-         "intermediate_size": 1024,
-         "max_position_embeddings": 1024,
-         "model_type": "llama",
-         "num_attention_heads": 2,
-         "num_hidden_layers": 1,
-         "num_key_value_heads": 1,
-         "pretraining_tp": 1,
-         "rms_norm_eps": 1e-05,
-         "rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None,
-         "tie_word_embeddings": False,
-         "torch_dtype": "float32",
-         "transformers_version": "4.31.0.dev0",
-         "use_cache": True,
-         "vocab_size": 32000,
-     }
-
-     config.update(**kwargs)
-     conf = transformers.LlamaConfig(**config)
-     model = transformers.LlamaForCausalLM(conf)
-     model.eval()
-
-     # now the inputs
-     cache_last_dim = 96
-     sequence_length = 30
-     sequence_length2 = 3
-     num_key_value_heads = 1
-     max_token_id = config["vocab_size"] - 1
-     n_layers = config["num_hidden_layers"]
-
-     batch = torch.export.Dim("batch", min=1, max=1024)
-     seq_length = torch.export.Dim("seq_length", min=1, max=4096)
-     cache_length = torch.export.Dim("cache_length", min=1, max=4096)
-
-     shapes = {
-         "input_ids": {0: batch, 1: seq_length},
-         "attention_mask": {
-             0: batch,
-             1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
-         },
-         "past_key_values": [
-             [{0: batch, 2: cache_length} for _ in range(n_layers)],
-             [{0: batch, 2: cache_length} for _ in range(n_layers)],
-         ],
-     }
-     inputs = dict(
-         input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
-             torch.int64
-         ),
-         attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-             torch.int64
-         ),
-         past_key_values=make_dynamic_cache(
-             [
-                 (
-                     torch.randn(
-                         batch_size, num_key_value_heads, sequence_length, cache_last_dim
-                     ),
-                     torch.randn(
-                         batch_size, num_key_value_heads, sequence_length, cache_last_dim
-                     ),
-                 )
-                 for i in range(n_layers)
-             ]
-         ),
-     )
-     return dict(inputs=inputs, model=model, dynamic_shapes=shapes)
-
-
# %%
# Let's get the model, inputs and dynamic shapes.

@@ -187,9 +78,25 @@ def get_tiny_llm(
    experiment["dynamic_shapes"],
)

+ # %%
+ # Before we run it, we make a copy of the inputs because the cache
+ # gets modified by the execution; after that it no longer matches
+ # the previous input_ids and mask.
+ cloned_inputs = copy.deepcopy(inputs)
+
+
# %% Let's run it.
- expected_output = model(**inputs)
- print("result type", type(expected_output))
+ print("input type", string_type(inputs, with_shape=True))
+
+ expected_output = untrained_model(**inputs)
+
+
+ print("input after the execution", string_type(inputs, with_shape=True))
+ print("result type", string_type(expected_output, with_shape=True))
+
+ ep = torch.export.export(
+     untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes
+ )

# %%
# It works.
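A quick way to sanity-check the result, not shown in the diff: ``torch.export.export`` returns an ``ExportedProgram`` whose ``.module()`` gives back a runnable module, so the exported graph can be called on a fresh copy of the inputs and its output inspected with ``string_type``, just like the eager output above. A minimal sketch, assuming the names defined earlier in the script:

# Sketch only (not part of the diff): run the exported program on a fresh copy
# of the inputs and print the structure of what it returns.
got = ep.module()(**copy.deepcopy(cloned_inputs))
print("exported output", string_type(got, with_shape=True))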
@@ -199,7 +106,7 @@ def get_tiny_llm(

try:
    ep = torch.export.export(
-         untrained_model, (), inputs, dynamic_shapes=dynamic_shapes, strict=False
+         untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes
    )
    print("It worked:")
    print(ep)
@@ -217,7 +124,7 @@ def get_tiny_llm(
# Let's use the same dummy inputs but we use the downloaded model.

try:
-     ep = torch.export.export(model, (), inputs, dynamic_shapes=dynamic_shapes, strict=False)
+     ep = torch.export.export(model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes)
    print("It worked:")
    print(ep)
except Exception as e: