Skip to content

Commit 20e59ec

Browse files
committed
changes
1 parent d5ea218 commit 20e59ec

File tree

11 files changed

+91
-136
lines changed

11 files changed

+91
-136
lines changed

CHANGELOGS.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ Change Logs
44
0.4.0
55
+++++
66

7+
* :pr:`58`: add function use_dyn_not_str to replace string by ``torch.export.Dim.DYNAMIC``,
8+
use string instead of ``torch.export.Dim.DYNAMIC`` when returning the dynamic shapes
9+
for a specific models, it is a valid definition for ``torch.onnx.export``
10+
which can reuse the names
711
* :pr:`55`: add support for text-classification
812
* :pr:`54`: add support for fill-mask, refactoring
913
* :pr:`52`: add support for zero-shot-image-classification

_doc/examples/plot_export_hub_codellama.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
task_from_id,
3131
)
3232
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
33+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
3334

3435
model_id = "codellama/CodeLlama-7b-Python-hf"
3536
print("info", get_model_info(model_id))
@@ -96,7 +97,7 @@
9697
model,
9798
(),
9899
kwargs=f(data["inputs"]),
99-
dynamic_shapes=data["dynamic_shapes"],
100+
dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
100101
strict=False,
101102
)
102103
print(ep)

_doc/examples/plot_export_tiny_phi2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
2828
from onnx_diagnostic.helpers.rt_helper import make_feeds
2929
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
30+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
3031
from onnx_diagnostic.torch_models.hghub import (
3132
get_untrained_model_with_inputs,
3233
)
@@ -92,7 +93,7 @@
9293
untrained_model,
9394
(),
9495
kwargs=modificator(copy.deepcopy(inputs)),
95-
dynamic_shapes=dynamic_shapes,
96+
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
9697
strict=False, # mandatory for torch==2.6
9798
)
9899

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -576,17 +576,20 @@ def test_couple_input_ds_cache(self):
576576
Cls(
577577
(),
578578
kwargs,
579-
{"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])},
579+
{
580+
"A": ds_batch,
581+
"B": (ds_batch, [[ds_batch, ds_batch], [ds_batch, ds_batch]]),
582+
},
580583
).invalid_dimensions_for_export(),
581584
)
582585
self.assertEqual(
583-
{"B": (None, [None, {2: "d=[1]"}, None, {2: "d=[1]"}])},
586+
{"B": (None, [[None, {2: "d=[1]"}], [None, {2: "d=[1]"}]])},
584587
Cls(
585588
(),
586589
kwargs,
587590
{
588591
"A": ds_batch,
589-
"B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]),
592+
"B": (ds_batch, [[ds_batch, ds_batch_seq], [ds_batch, ds_batch_seq]]),
590593
},
591594
).invalid_dimensions_for_export(),
592595
)

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from onnx_diagnostic.torch_export_patches.patch_inputs import (
88
convert_dynamic_axes_into_dynamic_shapes,
99
)
10+
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
1011

1112

1213
class TestCacheHelpers(ExtTestCase):
@@ -59,8 +60,9 @@ def test_replace_by(self):
5960
)
6061
self.assertEqual(dynamic_shapes, nds)
6162

62-
cpl = CoupleInputsDynamicShapes(tuple(), kwargs, dynamic_shapes)
63-
res = cpl.replace_string_by()
63+
with bypass_export_some_errors(patch_transformers=True):
64+
cpl = CoupleInputsDynamicShapes(tuple(), kwargs, dynamic_shapes)
65+
res = cpl.replace_string_by()
6466
dsc = res["past_key_values"]
6567
self.assertEqual([[{0: batch, 2: DYN}], [{0: batch, 2: DYN}]], dsc)
6668

_unittests/ut_torch_export_patches/test_patch_inputs.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from onnx_diagnostic.helpers import string_type
66
from onnx_diagnostic.torch_export_patches.patch_inputs import (
77
convert_dynamic_axes_into_dynamic_shapes,
8+
use_dyn_not_str,
89
)
910

1011

@@ -111,6 +112,26 @@ def test_convert_dynamic_axes_into_dynamic_shapes_2(self):
111112
string_type(res[1], with_shape=True),
112113
)
113114

115+
def test_use_dyn_not_str(self):
116+
batch = torch.export.Dim("batch")
117+
dynamic_shapes = dict(
118+
input_ids={0: batch, 1: "seq"},
119+
attention_mask={0: batch, 1: "seq"},
120+
position_ids={0: batch, 1: "seq"},
121+
past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]],
122+
)
123+
res = use_dyn_not_str(dynamic_shapes)
124+
DYN = torch.export.Dim.DYNAMIC
125+
self.assertEqual(
126+
dict(
127+
input_ids={0: batch, 1: DYN},
128+
attention_mask={0: batch, 1: DYN},
129+
position_ids={0: batch, 1: DYN},
130+
past_key_values=[[{0: batch, 2: DYN}], [{0: batch, 2: DYN}]],
131+
),
132+
res,
133+
)
134+
114135

115136
if __name__ == "__main__":
116137
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def get_parser_validate() -> ArgumentParser:
302302

303303
def _cmd_validate(argv: List[Any]):
304304
from .helpers import string_type
305-
from .torch_models.test_helper import get_inputs_for_task, validate_model, _ds_clean
305+
from .torch_models.test_helper import get_inputs_for_task, validate_model
306306
from .tasks import supported_tasks
307307

308308
parser = get_parser_validate()
@@ -320,7 +320,7 @@ def _cmd_validate(argv: List[Any]):
320320
print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
321321
print("-- dynamic_shapes")
322322
for k, v in data["dynamic_shapes"].items():
323-
print(f" + {k.ljust(max_length)}: {_ds_clean(v)}")
323+
print(f" + {k.ljust(max_length)}: {string_type(v)}")
324324
else:
325325
# Let's skip any invalid combination if known to be unsupported
326326
if (

onnx_diagnostic/helpers/helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def string_type(
256256
print(f"[string_type] L:{type(obj)}")
257257
return f"{{...}}#{len(obj)}" if with_shape else "{...}"
258258
# dict
259-
if isinstance(obj, dict):
259+
if isinstance(obj, dict) and type(obj) is dict:
260260
if len(obj) == 0:
261261
if verbose:
262262
print(f"[string_type] M:{type(obj)}")
@@ -276,7 +276,7 @@ def string_type(
276276
)
277277
for v in obj.values()
278278
):
279-
# This is dyanmic shapes
279+
# This is dynamic shapes
280280
rows = []
281281
for k, v in obj.items():
282282
if isinstance(v, str):

onnx_diagnostic/torch_export_patches/__init__.py

Lines changed: 0 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -2,110 +2,3 @@
22
bypass_export_some_errors,
33
register_additional_serialization_functions,
44
)
5-
6-
"""
7-
-- Missing dependencies --
8-
9-
def is_torchdynamo_exporting() -> bool:
10-
"Tells if torch is exporting a model."
11-
import torch
12-
13-
if not hasattr(torch.compiler, "is_exporting"):
14-
# torch.compiler.is_exporting requires torch>=2.7
15-
return False
16-
17-
try:
18-
return torch.compiler.is_exporting()
19-
except Exception:
20-
try:
21-
import torch._dynamo as dynamo
22-
23-
return dynamo.is_exporting() # type: ignore
24-
except Exception:
25-
return False
26-
27-
28-
def string_type(anything, **args):
29-
# too long
30-
# from onnx_diagnostic.helpers import string_type
31-
return str(anything)
32-
33-
34-
if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
35-
36-
def make_dynamic_cache(
37-
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
38-
) -> transformers.cache_utils.DynamicCache:
39-
'''
40-
Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
41-
This version is valid for ``transformers >= 4.50``.
42-
43-
:param key_value_pairs: list of pairs of (key, values)
44-
:return: :class:`transformers.cache_utils.DynamicCache`
45-
46-
Example:
47-
48-
::
49-
50-
n_layers = 2
51-
bsize, nheads, slen, dim = 2, 4, 3, 7
52-
53-
past_key_values = make_dynamic_cache(
54-
[
55-
(
56-
torch.randn(bsize, nheads, slen, dim),
57-
torch.randn(bsize, nheads, slen, dim),
58-
)
59-
for i in range(n_layers)
60-
]
61-
)
62-
print(string_type(past_key_values, with_shape=True))
63-
'''
64-
return transformers.cache_utils.DynamicCache(key_value_pairs)
65-
66-
else:
67-
68-
def make_dynamic_cache(
69-
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
70-
) -> transformers.cache_utils.DynamicCache:
71-
'''
72-
Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
73-
This version is valid for ``transformers < 4.50``.
74-
75-
:param key_value_pairs: list of pairs of (key, values)
76-
:return: :class:`transformers.cache_utils.DynamicCache`
77-
78-
Example:
79-
80-
::
81-
82-
n_layers = 2
83-
bsize, nheads, slen, dim = 2, 4, 3, 7
84-
85-
past_key_values = make_dynamic_cache(
86-
[
87-
(
88-
torch.randn(bsize, nheads, slen, dim),
89-
torch.randn(bsize, nheads, slen, dim),
90-
)
91-
for i in range(n_layers)
92-
]
93-
)
94-
print(string_type(past_key_values, with_shape=True))
95-
'''
96-
cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))
97-
for i, (key, value) in enumerate(key_value_pairs):
98-
cache.update(key, value, i)
99-
return cache
100-
101-
102-
def make_encoder_decoder_cache(
103-
self_attention_cache: transformers.cache_utils.DynamicCache,
104-
cross_attention_cache: transformers.cache_utils.DynamicCache,
105-
) -> transformers.cache_utils.EncoderDecoderCache:
106-
"Creates an EncoderDecoderCache."
107-
return transformers.cache_utils.EncoderDecoderCache(
108-
self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache
109-
)
110-
111-
"""

onnx_diagnostic/torch_export_patches/patch_inputs.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,21 @@ def convert_dynamic_axes_into_dynamic_shapes(
183183
)
184184

185185
return (), updated_kwargs, dynamic_shapes
186+
187+
188+
def use_dyn_not_str(dynamic_shapes: Any) -> Any:
189+
"""
190+
Some functions returns dynamic shapes as string.
191+
This functions replaces them with ``torch.export.Dim.DYNAMIC``.
192+
"""
193+
if isinstance(dynamic_shapes, list):
194+
return [use_dyn_not_str(a) for a in dynamic_shapes]
195+
if isinstance(dynamic_shapes, tuple):
196+
return tuple(use_dyn_not_str(a) for a in dynamic_shapes)
197+
if isinstance(dynamic_shapes, dict):
198+
return {k: use_dyn_not_str(v) for k, v in dynamic_shapes.items()}
199+
if isinstance(dynamic_shapes, set):
200+
return {use_dyn_not_str(a) for a in dynamic_shapes}
201+
if isinstance(dynamic_shapes, str):
202+
return torch.export.Dim.DYNAMIC
203+
return dynamic_shapes

0 commit comments

Comments
 (0)