Commit 7e5e798
add phi2
1 parent 44cf375
7 files changed, +217 -5 lines changed

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
+"""
+.. _l-plot-export_tiny_phi2:
+
+Untrained microsoft/phi-2
+=========================
+
+:epkg:`microsoft/phi-2` is not a big model but it is still quite big
+when it comes to writing unit tests. Function
+:func:`onnx_diagnostic.torch_models.hghub.get_untrained_model_with_inputs`
+can be used to create a reduced untrained version of a model coming from
+:epkg:`HuggingFace`. It downloads the configuration from the website
+but creates a dummy model with 1 or 2 hidden layers in order to reduce
+the size and get a fast execution. The goal is usually to test
+the export or to compare performance. Prediction quality does not matter.
+
+Create the dummy model
+++++++++++++++++++++++
+"""
+
+import copy
+import pprint
+import warnings
+import torch
+import onnxruntime
+from onnx_diagnostic import doc
+from onnx_diagnostic.helpers import max_diff, string_diff, string_type
+from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
+from onnx_diagnostic.helpers.ort_session import make_feeds
+from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_models.hghub import (
+    get_untrained_model_with_inputs,
+)
+
+warnings.simplefilter("ignore")
+
+
+data = get_untrained_model_with_inputs("microsoft/phi-2")
+untrained_model, inputs, dynamic_shapes, config, size, n_weights = (
+    data["model"],
+    data["inputs"],
+    data["dynamic_shapes"],
+    data["configuration"],
+    data["size"],
+    data["n_weights"],
+)
+
+print(f"model {size / 2**10:1.3f} Kb with {n_weights} parameters.")
+# %%
+# The original model has 2.7 billion parameters. It was divided by more than 10.
+# Let's see the configuration.
+print(config)
+
+
+# %%
+# Inputs:
+
+print(string_type(inputs, with_shape=True))
+
+# %%
+# With min/max values.
+print(string_type(inputs, with_shape=True, with_min_max=True))
+
+# %%
+# And the dynamic shapes.
+pprint.pprint(dynamic_shapes)
+
+# %%
+# We execute the model to produce the expected outputs.
+expected = untrained_model(**copy.deepcopy(inputs))
+print(f"expected: {string_type(expected, with_shape=True, with_min_max=True)}")
+
+
+# %%
+# Export
+# ++++++
+
+
+with bypass_export_some_errors(patch_transformers=True) as modificator:
+
+    # Unnecessary steps but useful in case of an error.
+    # We check the cache is registered.
+    assert is_cache_dynamic_registered()
+
+    # We check there are no discrepancies once the patches are applied.
+    d = max_diff(expected, untrained_model(**copy.deepcopy(inputs)))
+    assert (
+        d["abs"] < 1e-5
+    ), f"The model with patches produces different outputs: {string_diff(d)}"
+
+    # Then we export.
+    ep = torch.export.export(
+        untrained_model,
+        (),
+        kwargs=modificator(copy.deepcopy(inputs)),
+        dynamic_shapes=dynamic_shapes,
+        strict=False,  # mandatory for torch==2.6
+    )
+
+    # We check the exported program produces the same results as well.
+    d = max_diff(expected, ep.module()(**copy.deepcopy(inputs)))
+    assert d["abs"] < 1e-5, f"The exported model produces different outputs: {string_diff(d)}"
+
+# %%
+# Export to ONNX
+# ++++++++++++++
+#
+# The export works. We can export to ONNX now.
+# Patches are still needed because the export
+# applies :meth:`torch.export.ExportedProgram.run_decompositions`,
+# which may export local pieces of the model again.
+
+with bypass_export_some_errors(patch_transformers=True):
+    epo = torch.onnx.export(
+        ep, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes, dynamo=True
+    )
+
+# %%
+# We can save it.
+epo.save("plot_export_tiny_phi2.onnx", external_data=True)
+
+# Or directly get the :class:`onnx.ModelProto`.
+onx = epo.model_proto
+
+
+# %%
+# Discrepancies
+# +++++++++++++
+#
+# We now check the conversion to ONNX.
+# Let's make sure the ONNX model produces the same outputs.
+# It takes flattened inputs.
+
+feeds = make_feeds(onx, copy.deepcopy(inputs), use_numpy=True)
+
+print(f"torch inputs: {string_type(inputs)}")
+print(f"onnxrt inputs: {string_type(feeds)}")
+
+# %%
+# We then create an :class:`onnxruntime.InferenceSession`.
+
+sess = onnxruntime.InferenceSession(
+    onx.SerializeToString(), providers=["CPUExecutionProvider"]
+)
+
+# %%
+# Let's run it.
+got = sess.run(None, feeds)
+
+# %%
+# And finally the discrepancies.
+
+diff = max_diff(expected, got, flatten=True)
+print(f"onnx discrepancies: {string_diff(diff)}")
+
+# %%
+# It looks good.
+
+# %%
+doc.plot_legend("untrained smaller\nmicrosoft/phi-2", "torch.onnx.export", "green")
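
Note on the dynamic shapes printed by the example above: they follow the
structure :func:`torch.export.export` expects, a mapping from input name to
a dictionary from axis index to a dynamic dimension. A minimal sketch of
such a mapping, with hypothetical input names and axes (the real ones come
from get_untrained_model_with_inputs and are not shown in this commit):

import torch

# Hypothetical illustration only: the names and axes are assumptions.
batch = torch.export.Dim("batch")
seq = torch.export.Dim("seq")
hypothetical_dynamic_shapes = {
    "input_ids": {0: batch, 1: seq},       # batch size and sequence length vary
    "attention_mask": {0: batch, 1: seq},  # must vary consistently with input_ids
}
# torch.export.export(model, (), kwargs=inputs,
#                     dynamic_shapes=hypothetical_dynamic_shapes)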

_unittests/ut_helpers/test_helper.py

Lines changed: 12 additions & 0 deletions
@@ -174,6 +174,18 @@ def test_flatten(self):
         d = string_diff(diff)
         self.assertIsInstance(d, str)

+    def test_flatten_cache(self):
+        cache = make_dynamic_cache([(torch.ones((5, 6, 5, 6)), torch.ones((5, 6, 5, 6)) + 2)])
+        flat = flatten_object(cache, drop_keys=True)
+        self.assertEqual(string_type(flat), "(T1r4,T1r4)")
+        cache = dict(
+            cache=make_dynamic_cache(
+                [(torch.ones((5, 6, 5, 6)), torch.ones((5, 6, 5, 6)) + 2)]
+            )
+        )
+        flat = flatten_object(cache, drop_keys=True)
+        self.assertEqual(string_type(flat), "#2[T1r4,T1r4]")
+
     @hide_stdout()
     def test_max_diff_verbose(self):
         inputs = (
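
The two assertions above document how flatten_object and string_type
summarize a cache: a standalone DynamicCache flattens to a tuple, a dict of
caches to a list. A minimal sketch of the same calls outside the test
harness, assuming the import paths used elsewhere in this commit (reading
T1r4 as a rank-4 tensor with ONNX element type 1, float32, is an inference,
not stated in the diff):

import torch
from onnx_diagnostic.helpers.helper import flatten_object, string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

# One (key, value) pair of rank-4 float32 tensors.
cache = make_dynamic_cache([(torch.ones((5, 6, 5, 6)), torch.ones((5, 6, 5, 6)) + 2)])
flat = flatten_object(cache, drop_keys=True)
print(string_type(flat))  # "(T1r4,T1r4)" according to the test above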

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 6 additions & 1 deletion
@@ -5,12 +5,17 @@
 import transformers.cache_utils


-def is_cache_dynamic_registered() -> bool:
+def is_cache_dynamic_registered(fast: bool = False) -> bool:
     """
     Tells whether class :class:`transformers.cache_utils.DynamicCache` can be
     serialized and deserialized. Only then can :func:`torch.export.export`
     export a model.
+
+    :param fast: if True, do not check that the serialization works as well
+    :return: result
     """
+    if fast:
+        return transformers.cache_utils.DynamicCache in torch.utils._pytree.SUPPORTED_NODES
     bsize, nheads, slen, dim = 2, 4, 3, 7
     cache = make_dynamic_cache(
         [
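
A usage sketch for the new fast parameter: the fast path only checks the
pytree registration, while the default path (truncated in the hunk above)
also round-trips an actual cache through serialization.

from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered

# Cheap check: is DynamicCache registered with torch's pytree at all?
if not is_cache_dynamic_registered(fast=True):
    raise RuntimeError("DynamicCache not registered; torch.export.export would fail.")

# Full check (default): also serializes and deserializes a real cache.
assert is_cache_dynamic_registered()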

onnx_diagnostic/helpers/ort_session.py

Lines changed: 30 additions & 1 deletion
@@ -6,7 +6,8 @@
 from torch._C import _from_dlpack
 import onnxruntime
 from onnxruntime.capi import _pybind_state as ORTC
-from .helper import size_type
+from .cache_helper import is_cache_dynamic_registered
+from .helper import size_type, string_type, flatten_object
 from .onnx_helper import (
     torch_dtype_to_onnx_dtype,
     onnx_dtype_to_np_dtype,
@@ -17,6 +18,34 @@
 DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)}


+def make_feeds(
+    proto: onnx.ModelProto, inputs: Any, use_numpy: bool = True
+) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
+    """
+    Serializes the inputs to produce the feeds expected
+    by :class:`onnxruntime.InferenceSession`.
+
+    :param proto: onnx model
+    :param inputs: any kind of inputs
+    :param use_numpy: if True, converts torch tensors into numpy arrays
+    :return: feeds dictionary
+    """
+    flat = flatten_object(inputs, drop_keys=True)
+    assert (
+        not all(isinstance(obj, torch.Tensor) for obj in flat)
+        or not is_cache_dynamic_registered(fast=True)
+        or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
+    ), (
+        f"Unexpected number of flattened objects, "
+        f"{string_type(flat, with_shape=True, limit=20)} != "
+        f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True, limit=20)}"
+    )
+    if use_numpy:
+        flat = [t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat]
+    names = [i.name for i in proto.graph.input]
+    return dict(zip(names, flat))
+
+
 class _InferenceSession:

     @classmethod
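
A usage sketch for the new make_feeds, mirroring the example script above;
it assumes `inputs` is the same nested dictionary built there and that the
ONNX file is the one the script saves:

import onnx
import onnxruntime
from onnx_diagnostic.helpers.ort_session import make_feeds

proto = onnx.load("plot_export_tiny_phi2.onnx")
# make_feeds flattens nested inputs (dicts, caches, ...) and pairs the
# leaves with the graph input names, converting tensors to numpy arrays.
feeds = make_feeds(proto, inputs, use_numpy=True)
sess = onnxruntime.InferenceSession(
    proto.SerializeToString(), providers=["CPUExecutionProvider"]
)
got = sess.run(None, feeds)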

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 5 additions & 1 deletion
@@ -68,6 +68,8 @@ def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0
 def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
     # Cache serialization: to be moved into appropriate packages
     import torch
+    import transformers
+    import packaging.version as pv

     try:
         from transformers.cache_utils import DynamicCache
@@ -108,7 +110,9 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
     # torch.fx._pytree.register_pytree_flatten_spec(
     #     DynamicCache, _flatten_dynamic_cache_for_fx)
     # so we remove it anyway
-    if DynamicCache in torch.fx._pytree.SUPPORTED_NODES:
+    if DynamicCache in torch.fx._pytree.SUPPORTED_NODES and pv.Version(
+        transformers.__version__
+    ) >= pv.Version("2.7"):
         if verbose:
             print("[_register_cache_serialization] DynamicCache is unregistered first.")
         _unregister(DynamicCache)
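
The new guard compares versions through packaging.version rather than plain
strings. A short illustration of why (the version numbers are arbitrary):

import packaging.version as pv

# String comparison gets multi-digit components wrong; Version does not.
assert "4.9.0" > "4.50.0"                          # lexicographic, misleading
assert pv.Version("4.50.0") > pv.Version("4.9.0")  # semantic, correct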

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 2 additions & 2 deletions
@@ -47,7 +47,7 @@ def _patch_make_causal_mask(
 if sys.version_info[:2] <= (3, 11):

     @dataclass
-    class patched_AttentionMaskConverter:
+    class kkpatched_AttentionMaskConverter:
         """
         Patches
         ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
@@ -72,7 +72,7 @@ def _make_causal_mask(
 else:

     @dataclass
-    class patched_AttentionMaskConverter:
+    class kkpatched_AttentionMaskConverter:
         """
         Patches
         ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 3 additions & 0 deletions
@@ -293,6 +293,9 @@ def get_untrained_model_with_inputs(
     kwargs.update(inputs_kwargs)

     model = getattr(transformers, arch)(config)
+    # This line is important. Some models may produce different
+    # outputs even with the same inputs in training mode.
+    model.eval()
     res = fct(model, config, **kwargs)
     res["input_kwargs"] = kwargs
     res["model_kwargs"] = mkwargs
