Skip to content

Commit 2bbcc75

Browse files
committed
copy
1 parent 5cb5006 commit 2bbcc75

File tree

4 files changed

+119
-0
lines changed

4 files changed

+119
-0
lines changed

CHANGELOGS.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ Change Logs
44
0.3.0
55
+++++
66

7+
* :pr:`24`: dummy inputs for ``text2text-generation``, add new function
8+
``convert_dynamic_axes_into_dynamic_shapes`` to convert dynamic axes
9+
into dynamic shapes, add support for ``T5ForConditionalGeneration``
710
* :pr:`23`: dummy inputs for ``image-classification``
811
* :pr:`22`: api to create untrained model copying the architecture
912
of the trained models and dummy inputs for them,

_doc/api/torch_export_patches/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ onnx_diagnostic.torch_export_patches
66
:caption: submodules
77

88
patches/index
9+
patch_inputs
10+
911

1012
.. automodule:: onnx_diagnostic.torch_export_patches
1113
:members:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_export_patches.patch_inputs
3+
=================================================
4+
5+
.. automodule:: onnx_diagnostic.torch_export_patches.patch_inputs
6+
:members:
7+
:no-undoc-members:

onnx_diagnostic/torch_export_patches/__init__.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,110 @@
22
bypass_export_some_errors,
33
register_additional_serialization_functions,
44
)
5+
6+
"""
7+
-- Missing dependencies --
8+
9+
def is_torchdynamo_exporting() -> bool:
10+
"Tells if torch is exporting a model."
11+
import torch
12+
13+
if not hasattr(torch.compiler, "is_exporting"):
14+
# torch.compiler.is_exporting requires torch>=2.7
15+
return False
16+
17+
try:
18+
return torch.compiler.is_exporting()
19+
except Exception:
20+
try:
21+
import torch._dynamo as dynamo
22+
23+
return dynamo.is_exporting() # type: ignore
24+
except Exception:
25+
return False
26+
27+
28+
def string_type(anything, **args):
29+
# too long
30+
# from onnx_diagnostic.helpers import string_type
31+
return str(anything)
32+
33+
34+
if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
35+
36+
def make_dynamic_cache(
37+
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
38+
) -> transformers.cache_utils.DynamicCache:
39+
'''
40+
Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
41+
This version is valid for ``transformers >= 4.50``.
42+
43+
:param key_value_pairs: list of pairs of (key, values)
44+
:return: :class:`transformers.cache_utils.DynamicCache`
45+
46+
Example:
47+
48+
::
49+
50+
n_layers = 2
51+
bsize, nheads, slen, dim = 2, 4, 3, 7
52+
53+
past_key_values = make_dynamic_cache(
54+
[
55+
(
56+
torch.randn(bsize, nheads, slen, dim),
57+
torch.randn(bsize, nheads, slen, dim),
58+
)
59+
for i in range(n_layers)
60+
]
61+
)
62+
print(string_type(past_key_values, with_shape=True))
63+
'''
64+
return transformers.cache_utils.DynamicCache(key_value_pairs)
65+
66+
else:
67+
68+
def make_dynamic_cache(
69+
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
70+
) -> transformers.cache_utils.DynamicCache:
71+
'''
72+
Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
73+
This version is valid for ``transformers < 4.50``.
74+
75+
:param key_value_pairs: list of pairs of (key, values)
76+
:return: :class:`transformers.cache_utils.DynamicCache`
77+
78+
Example:
79+
80+
::
81+
82+
n_layers = 2
83+
bsize, nheads, slen, dim = 2, 4, 3, 7
84+
85+
past_key_values = make_dynamic_cache(
86+
[
87+
(
88+
torch.randn(bsize, nheads, slen, dim),
89+
torch.randn(bsize, nheads, slen, dim),
90+
)
91+
for i in range(n_layers)
92+
]
93+
)
94+
print(string_type(past_key_values, with_shape=True))
95+
'''
96+
cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))
97+
for i, (key, value) in enumerate(key_value_pairs):
98+
cache.update(key, value, i)
99+
return cache
100+
101+
102+
def make_encoder_decoder_cache(
103+
self_attention_cache: transformers.cache_utils.DynamicCache,
104+
cross_attention_cache: transformers.cache_utils.DynamicCache,
105+
) -> transformers.cache_utils.EncoderDecoderCache:
106+
"Creates an EncoderDecoderCache."
107+
return transformers.cache_utils.EncoderDecoderCache(
108+
self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache
109+
)
110+
111+
"""

0 commit comments

Comments
 (0)