Skip to content

Commit 04baebc

Browse files
committed
add an example
1 parent 60f4304 commit 04baebc

File tree

3 files changed

+114
-0
lines changed

3 files changed

+114
-0
lines changed

CHANGELOGS.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
Change Logs
22
===========
33

4+
0.3.0
5+
+++++
6+
7+
* :pr:`23`: dummy inputs for ``image-classification``
8+
* :pr:`22`: API to create an untrained model copying the architecture
9+
of the trained models and dummy inputs for them,
10+
support for ``text-generation``
11+
412
0.2.1
513
+++++
614

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""
2+
.. _l-plot-export-hub-codellama:
3+
4+
Test the export on untrained models
5+
===================================
6+
7+
Checking the exporter on a whole model takes time as it is
8+
usually big but we can create a smaller version with
9+
the same architecture. Then fix export issues on such a
10+
small model is faster.
11+
12+
codellama/CodeLlama-7b-Python-hf
13+
++++++++++++++++++++++++++++++++
14+
15+
Let's grab some information about this model.
16+
This reuses :epkg:`huggingface_hub` API.
17+
"""
18+
19+
import copy
20+
import pprint
21+
import torch
22+
from onnx_diagnostic import doc
23+
from onnx_diagnostic.helpers import string_type
24+
from onnx_diagnostic.torch_models.hghub import (
25+
get_untrained_model_with_inputs,
26+
)
27+
from onnx_diagnostic.torch_models.hghub.hub_api import (
28+
get_model_info,
29+
get_pretrained_config,
30+
task_from_id,
31+
)
32+
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
33+
34+
model_id = "codellama/CodeLlama-7b-Python-hf"
35+
print("info", get_model_info(model_id))
36+
37+
# %%
38+
# The configuration.
39+
40+
print("config", get_pretrained_config(model_id))
41+
42+
# %%
43+
# The task determines the set of inputs which needs
44+
# to be created for this input.
45+
46+
print("task", task_from_id(model_id))
47+
48+
# %%
49+
# Untrained model
50+
# +++++++++++++++
51+
#
52+
# The function :func:`get_untrained_model_with_inputs
53+
# <onnx_diagnostic.torch_models.hghub.get_untrained_model_with_inputs>`.
54+
# It loads the pretrained configuration, extracts the task associated
55+
# to the model and them creates random inputs and dynamic shapes
56+
# for :func:`torch.export.export`.
57+
58+
data = get_untrained_model_with_inputs(model_id, verbose=1)
59+
print("model size:", data["size"])
60+
print("number of weights:", data["n_weights"])
61+
print("fields:", set(data))
62+
63+
# %%
64+
# Inputs
65+
print("inputs:", string_type(data["inputs"], with_shape=True))
66+
67+
# %%
68+
# Dynamic Shapes
69+
print("dynamic shapes:", pprint.pformat(data["dynamic_shapes"]))
70+
71+
# %%
72+
# Let's check the model runs. We still needs to
73+
# copy the inputs before using the models, the cache
74+
# is usually modifed inplace.
75+
# Expected outputs can be used later to compute
76+
# discrepancies.
77+
78+
inputs_copy = copy.deepcopy(data["inputs"])
79+
model = data["model"]
80+
expected_outputs = model(**inputs_copy)
81+
82+
print("outputs:", string_type(expected_outputs, with_shape=True))
83+
84+
# %%
85+
# It works.
86+
#
87+
# Export
88+
# ++++++
89+
#
90+
# The model uses :class:`transformers.cache_utils.DynamicCache`.
91+
# It still requires patches to be exportable (control flow).
92+
# See :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
93+
94+
with bypass_export_some_errors(patch_transformers=True) as f:
95+
ep = torch.export.export(
96+
model, (), kwargs=f(data["inputs"]), dynamic_shapes=data["dynamic_shapes"]
97+
)
98+
print(ep)
99+
100+
101+
# %%
102+
103+
doc.plot_legend(
104+
"untrained\ncodellama/\nCodeLlama-7b-Python-hf", "torch.export.export", "tomato"
105+
)

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ Enlightening Examples
6464
* :ref:`l-plot-export-locale-issue`
6565
* :ref:`l-plot-tiny-llm-export`
6666
* :ref:`l-plot-tiny-llm-export-patched`
67+
* :ref:`l-plot-export-hub-codellama`
6768

6869
**Investigate ONNX models**
6970

0 commit comments

Comments
 (0)