
Commit 659896b

fix a few things
1 parent c850d43 commit 659896b

7 files changed: +159 -116 lines changed


_doc/examples/plot_export_tiny_llm.py

Lines changed: 22 additions & 115 deletions
@@ -15,11 +15,11 @@
 We use the dummy example from the model page.
 """
 
-from typing import Any, Dict
+import copy
 import torch
 import transformers
 from onnx_diagnostic.helpers import string_type
-from onnx_diagnostic.cache_helpers import make_dynamic_cache
+from onnx_diagnostic.torch_models.llms import get_tiny_llm
 
 
 MODEL_NAME = "arnir0/Tiny-LLM"
@@ -30,21 +30,6 @@
 # We rewrite the forward method to print the cache dimension.
 
 
-def string_inputs(args, kwargs):
-    def _cache(a):
-        if len(a.key_cache):
-            return f"n_caches={len(a.key_cache)}, shape={a.key_cache[0].shape}"
-        return f"n_caches={len(a.key_cache)}"
-
-    for a in args:
-        if isinstance(a, transformers.cache_utils.DynamicCache):
-            return _cache(a)
-    for k, a in kwargs.items():
-        if isinstance(a, transformers.cache_utils.DynamicCache):
-            return f"{k}={_cache(a)}"
-    return "no_cache"
-
-
 def _forward_(*args, _f=None, **kwargs):
     assert _f is not None
     if not torch.compiler.is_exporting():
@@ -83,100 +68,6 @@ def _forward_(*args, _f=None, **kwargs):
 # Let's create an untrained model.
 
 
-def get_tiny_llm(
-    batch_size: int = 2,
-    input_cache: bool = True,
-    common_dynamic_shapes: bool = True,
-    dynamic_rope: bool = False,
-    **kwargs,
-) -> Dict[str, Any]:
-    """
-    Gets a non initialized model.
-
-    :param batch_size: batch size
-    :param input_cache: generate data for this iteration with or without cache
-    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
-    :param common_dynamic_shapes: if True returns dynamic shapes as well
-    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
-    :return: dictionary
-    """
-    import transformers
-
-    config = {
-        "architectures": ["LlamaForCausalLM"],
-        "bos_token_id": 1,
-        "eos_token_id": 2,
-        "hidden_act": "silu",
-        "hidden_size": 192,
-        "initializer_range": 0.02,
-        "intermediate_size": 1024,
-        "max_position_embeddings": 1024,
-        "model_type": "llama",
-        "num_attention_heads": 2,
-        "num_hidden_layers": 1,
-        "num_key_value_heads": 1,
-        "pretraining_tp": 1,
-        "rms_norm_eps": 1e-05,
-        "rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None,
-        "tie_word_embeddings": False,
-        "torch_dtype": "float32",
-        "transformers_version": "4.31.0.dev0",
-        "use_cache": True,
-        "vocab_size": 32000,
-    }
-
-    config.update(**kwargs)
-    conf = transformers.LlamaConfig(**config)
-    model = transformers.LlamaForCausalLM(conf)
-    model.eval()
-
-    # now the inputs
-    cache_last_dim = 96
-    sequence_length = 30
-    sequence_length2 = 3
-    num_key_value_heads = 1
-    max_token_id = config["vocab_size"] - 1
-    n_layers = config["num_hidden_layers"]
-
-    batch = torch.export.Dim("batch", min=1, max=1024)
-    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = torch.export.Dim("cache_length", min=1, max=4096)
-
-    shapes = {
-        "input_ids": {0: batch, 1: seq_length},
-        "attention_mask": {
-            0: batch,
-            1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
-        },
-        "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(n_layers)],
-            [{0: batch, 2: cache_length} for _ in range(n_layers)],
-        ],
-    }
-    inputs = dict(
-        input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
-            torch.int64
-        ),
-        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-            torch.int64
-        ),
-        past_key_values=make_dynamic_cache(
-            [
-                (
-                    torch.randn(
-                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
-                    ),
-                    torch.randn(
-                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
-                    ),
-                )
-                for i in range(n_layers)
-            ]
-        ),
-    )
-    return dict(inputs=inputs, model=model, dynamic_shapes=shapes)
-
-
 # %%
 # Let's get the model, inputs and dynamic shapes.
 
@@ -187,9 +78,25 @@ def get_tiny_llm(
     experiment["dynamic_shapes"],
 )
 
+# %%
+# Before we run it, we make a copy of the inputs because the cache
+# is modified in place by the execution. It would then no longer be
+# consistent with the original input_ids and mask.
+cloned_inputs = copy.deepcopy(inputs)
+
+
 # %% Let's run it.
-expected_output = model(**inputs)
-print("result type", type(expected_output))
+print("input type", string_type(inputs, with_shape=True))
+
+expected_output = untrained_model(**inputs)
+
+
+print("input after the execution", string_type(inputs, with_shape=True))
+print("result type", string_type(expected_output, with_shape=True))
+
+ep = torch.export.export(
+    untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes
+)
 
 # %%
 # It works.
@@ -199,7 +106,7 @@ def get_tiny_llm(
 
 try:
     ep = torch.export.export(
-        untrained_model, (), inputs, dynamic_shapes=dynamic_shapes, strict=False
+        untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes
     )
     print("It worked:")
     print(ep)
@@ -217,7 +124,7 @@ def get_tiny_llm(
 # Let's use the same dummy inputs but we use the downloaded model.
 
 try:
-    ep = torch.export.export(model, (), inputs, dynamic_shapes=dynamic_shapes, strict=False)
+    ep = torch.export.export(model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes)
     print("It worked:")
     print(ep)
 except Exception as e:
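
Taken together, the example now follows a simple pattern: build the tiny model and its inputs, deep-copy the inputs (the DynamicCache is mutated in place by a forward call), run the model once, then export with keyword inputs and dynamic shapes. A minimal sketch of that flow, assuming onnx_diagnostic and its dependencies are installed:

    import copy
    import torch
    from onnx_diagnostic.helpers import string_type
    from onnx_diagnostic.torch_models.llms import get_tiny_llm

    # get_tiny_llm returns the model, the inputs and the dynamic shapes in one dictionary
    data = get_tiny_llm()
    model, inputs, dynamic_shapes = data["model"], data["inputs"], data["dynamic_shapes"]

    # keep a pristine copy: the forward call below extends the DynamicCache in place
    cloned_inputs = copy.deepcopy(inputs)
    expected = model(**inputs)
    print(string_type(expected, with_shape=True))

    # export with keyword arguments and the dynamic shapes returned by get_tiny_llm
    ep = torch.export.export(model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes)
    print(ep)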

_unittests/ut_torch_export_patches/test_onnx_export_errors.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 import unittest
-from experimental_experiment.ext_test_case import (
+from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
     requires_torch,
     requires_transformers,
Lines changed: 38 additions & 0 deletions (new file)
@@ -0,0 +1,38 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
+from onnx_diagnostic.torch_models.llms import get_tiny_llm
+from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+
+
+class TestLlms(ExtTestCase):
+    def test_get_tiny_llm(self):
+        data = get_tiny_llm()
+        model, inputs = data["model"], data["inputs"]
+        self.assertIn("DynamicCache", string_type(inputs))
+        model(**inputs)
+
+    @ignore_warnings(UserWarning)
+    def test_export_tiny_llm_1(self):
+        data = get_tiny_llm()
+        model, inputs = data["model"], data["inputs"]
+        ep = torch.export.export(
+            model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"]
+        )
+        assert ep
+        print(ep)
+
+    @ignore_warnings(UserWarning)
+    def test_export_tiny_llm_2_bypassed(self):
+        data = get_tiny_llm()
+        model, inputs = data["model"], data["inputs"]
+        with bypass_export_some_errors():
+            ep = torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"]
+            )
+            assert ep
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
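
The path of this new test file is not shown in this view, but since it ends with a unittest main block it can be executed directly once located. Another option is a keyword-based pytest selection run from the repository root (assuming pytest is available), which picks up exactly these tests:

    python -m pytest _unittests -k "tiny_llm" -v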

onnx_diagnostic/torch_models/__init__.py

Whitespace-only changes.
Lines changed: 94 additions & 0 deletions (new file)
@@ -0,0 +1,94 @@
+from typing import Any, Dict
+import torch
+import transformers
+from ..cache_helpers import make_dynamic_cache
+
+
+def get_tiny_llm(
+    batch_size: int = 2,
+    input_cache: bool = True,
+    dynamic_rope: bool = False,
+    **kwargs,
+) -> Dict[str, Any]:
+    """
+    Gets a non-initialized model.
+
+    :param batch_size: batch size
+    :param input_cache: generate data for this iteration with or without cache
+    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
+    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
+    :return: dictionary
+    """
+    config = {
+        "architectures": ["LlamaForCausalLM"],
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "hidden_act": "silu",
+        "hidden_size": 192,
+        "initializer_range": 0.02,
+        "intermediate_size": 1024,
+        "max_position_embeddings": 1024,
+        "model_type": "llama",
+        "num_attention_heads": 2,
+        "num_hidden_layers": 1,
+        "num_key_value_heads": 1,
+        "pretraining_tp": 1,
+        "rms_norm_eps": 1e-05,
+        "rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None,
+        "tie_word_embeddings": False,
+        "torch_dtype": "float32",
+        "transformers_version": "4.31.0.dev0",
+        "use_cache": True,
+        "vocab_size": 32000,
+    }
+
+    config.update(**kwargs)
+    conf = transformers.LlamaConfig(**config)
+    model = transformers.LlamaForCausalLM(conf)
+    model.eval()
+
+    # now the inputs
+    cache_last_dim = 96
+    sequence_length = 30
+    sequence_length2 = 3
+    num_key_value_heads = 1
+    max_token_id = config["vocab_size"] - 1
+    n_layers = config["num_hidden_layers"]
+
+    batch = torch.export.Dim("batch", min=1, max=1024)
+    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
+    cache_length = torch.export.Dim("cache_length", min=1, max=4096)
+
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "attention_mask": {
+            0: batch,
+            1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
+        },
+        "past_key_values": [
+            [{0: batch, 2: cache_length} for _ in range(n_layers)],
+            [{0: batch, 2: cache_length} for _ in range(n_layers)],
+        ],
+    }
+    inputs = dict(
+        input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
+            torch.int64
+        ),
+        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+            torch.int64
+        ),
+        past_key_values=make_dynamic_cache(
+            [
+                (
+                    torch.randn(
+                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
+                    ),
+                    torch.randn(
+                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
+                    ),
+                )
+                for i in range(n_layers)
+            ]
+        ),
+    )
+    return dict(inputs=inputs, model=model, dynamic_shapes=shapes)
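
The past_key_values entry relies on make_dynamic_cache from onnx_diagnostic.cache_helpers which, as used here, builds a transformers DynamicCache from a list of (key, value) tensor pairs, one pair per layer. Below is a rough, hedged approximation written directly against the public transformers API (DynamicCache.update); the real helper may handle more cases:

    import torch
    from transformers.cache_utils import DynamicCache

    def naive_dynamic_cache(key_value_pairs):
        # approximate sketch of make_dynamic_cache: feed each layer's
        # (key, value) pair into a fresh DynamicCache through update()
        cache = DynamicCache()
        for layer_idx, (key, value) in enumerate(key_value_pairs):
            cache.update(key, value, layer_idx)
        return cache

    # same layout as in get_tiny_llm: (batch, num_key_value_heads, seq_length, cache_last_dim)
    pairs = [(torch.randn(2, 1, 30, 96), torch.randn(2, 1, 30, 96))]
    cache = naive_dynamic_cache(pairs)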

pyproject.toml

Lines changed: 4 additions & 0 deletions
@@ -34,6 +34,10 @@ disable_error_code = ["union-attr"]
 module = ["onnx_diagnostic.torch_export_patches.*"]
 disable_error_code = ["arg-type", "assignment", "attr-defined", "index", "misc", "name-defined", "operator", "return-value"]
 
+[[tool.mypy.overrides]]
+module = ["onnx_diagnostic.torch_models.*"]
+disable_error_code = ["attr-defined", "call-overload", "operator"]
+
 [tool.ruff]
 
 # Exclude a variety of commonly ignored directories.
