Commit aae5d21

add example
1 parent 3600e70 commit aae5d21

1 file changed: 208 additions, 0 deletions
"""
Export LLM with dynamic shapes
==============================

We focus on the model
`Tiny-LLM <https://huggingface.co/arnir0/Tiny-LLM>`_.
To avoid downloading any weights, we write a function creating a
random model based on the same architecture.

Guess the cache dimension
+++++++++++++++++++++++++

The first step is to guess the dummy inputs.
Let's use the true model for that.
We use the dummy example from the model page.
"""

from typing import Any, Dict
import torch
import transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.cache_helpers import make_dynamic_cache


MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# %%
# We rewrite the forward method to print the cache dimension.


def string_inputs(args, kwargs):
    def _cache(a):
        if len(a.key_cache):
            return f"n_caches={len(a.key_cache)}, shape={a.key_cache[0].shape}"
        return f"n_caches={len(a.key_cache)}"

    for a in args:
        if isinstance(a, transformers.cache_utils.DynamicCache):
            return _cache(a)
    for k, a in kwargs.items():
        if isinstance(a, transformers.cache_utils.DynamicCache):
            return f"{k}={_cache(a)}"
    return "no_cache"


def _forward_(*args, _f=None, **kwargs):
    assert _f is not None
    if not torch.compiler.is_exporting():
        print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
    res = _f(*args, **kwargs)
    if not torch.compiler.is_exporting():
        print("->", string_type(res, with_shape=True, with_min_max=True))
    return res


keep_model_forward = model.forward
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(
    *args, _f=_f, **kwargs
)
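
# %%
# The original forward is kept in ``keep_model_forward`` and handed to the wrapper
# through the ``_f`` default argument, so the wrapper can call it without recursing
# into itself. It is restored further below once generation is done.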

# %%
# Let's run the model.
prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")

outputs = model.generate(
    inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True
)

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
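
# %%
# Every step of ``generate`` goes through the patched forward, so the ``<-`` / ``->``
# lines printed above expose the shapes the model actually receives: the cache passed
# as ``past_key_values`` grows along dimension 2 of ``key_cache`` as tokens are
# generated. That is the cache dimension we need to make dynamic for the export.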

# %%
# Let's restore the forward as it was.
model.forward = keep_model_forward

# %%
# The model creation
# ++++++++++++++++++
#
# Let's create an untrained model.


def get_tiny_llm(
    batch_size: int = 2,
    input_cache: bool = True,
    common_dynamic_shapes: bool = True,
    dynamic_rope: bool = False,
    **kwargs,
) -> Dict[str, Any]:
    """
    Gets an untrained model with random weights.

    :param batch_size: batch size
    :param input_cache: generate data for this iteration with or without cache
    :param common_dynamic_shapes: if True, returns dynamic shapes as well
    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
    :return: dictionary
    """
    import transformers

    config = {
        "architectures": ["LlamaForCausalLM"],
        "bos_token_id": 1,
        "eos_token_id": 2,
        "hidden_act": "silu",
        "hidden_size": 192,
        "initializer_range": 0.02,
        "intermediate_size": 1024,
        "max_position_embeddings": 1024,
        "model_type": "llama",
        "num_attention_heads": 2,
        "num_hidden_layers": 1,
        "num_key_value_heads": 1,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None,
        "tie_word_embeddings": False,
        "torch_dtype": "float32",
        "transformers_version": "4.31.0.dev0",
        "use_cache": True,
        "vocab_size": 32000,
    }

    config.update(**kwargs)
    conf = transformers.LlamaConfig(**config)
    model = transformers.LlamaForCausalLM(conf)
    model.eval()

    # now the inputs
    cache_last_dim = 96
    sequence_length = 30
    sequence_length2 = 3
    num_key_value_heads = 1
    max_token_id = config["vocab_size"] - 1
    n_layers = config["num_hidden_layers"]

    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
    cache_length = torch.export.Dim("cache_length", min=1, max=4096)

    shapes = {
        "input_ids": {0: batch, 1: seq_length},
        "attention_mask": {
            0: batch,
            1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
        },
        "past_key_values": [
            [{0: batch, 2: cache_length} for _ in range(n_layers)],
            [{0: batch, 2: cache_length} for _ in range(n_layers)],
        ],
    }
    inputs = dict(
        input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
            torch.int64
        ),
        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
            torch.int64
        ),
        past_key_values=make_dynamic_cache(
            [
                (
                    torch.randn(
                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
                    ),
                    torch.randn(
                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
                    ),
                )
                for _ in range(n_layers)
            ]
        ),
    )
    return dict(inputs=inputs, model=model, dynamic_shapes=shapes)


# %%
# Let's get the model, inputs and dynamic shapes.

experiment = get_tiny_llm()
model, inputs, dynamic_shapes = (
    experiment["model"],
    experiment["inputs"],
    experiment["dynamic_shapes"],
)
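
# %%
# Before calling the model, it helps to see what the dummy inputs look like.
# The next two lines are a small sketch added for this walkthrough (not in the
# original code); they only rely on ``string_type``, already imported above.
print("inputs:", string_type(inputs, with_shape=True))
print("dynamic shapes:", dynamic_shapes)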

# %%
# Let's run it.
expected_output = model(**inputs)
print("result type", type(expected_output))
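
# %%
# As a quick, hedged check (added here, not part of the original code), we can also
# print the logits shape and the cache the model returns.
print("logits:", expected_output.logits.shape)
print("cache:", string_type(expected_output.past_key_values, with_shape=True))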

# %%
# It works.
#
# ExportedProgram
# +++++++++++++++

try:
    ep = torch.export.export(model, (), inputs, dynamic_shapes=dynamic_shapes)
    print("It worked:")
    print(ep)
except Exception as e:
    # To work, it needs at least the following PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
