
Commit 28410cd ("fix issues")
1 parent 29012cd

File tree: 5 files changed, 97 additions & 12 deletions

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 61 additions & 1 deletion
@@ -1,8 +1,13 @@
 import unittest
 import torch
+import transformers
 from onnx_diagnostic.ext_test_case import ExtTestCase
 from onnx_diagnostic.helpers import string_type
-from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+from onnx_diagnostic.helpers.cache_helper import (
+    make_dynamic_cache,
+    make_encoder_decoder_cache,
+    flatten_unflatten_for_dynamic_shapes,
+)
 from onnx_diagnostic.export import CoupleInputsDynamicShapes
 from onnx_diagnostic.torch_export_patches.patch_inputs import (
     convert_dynamic_axes_into_dynamic_shapes,
@@ -66,6 +71,61 @@ def test_replace_by(self):
         dsc = res["past_key_values"]
         self.assertEqual([[{0: batch, 2: DYN}], [{0: batch, 2: DYN}]], dsc)
 
+    def test_unflatten_flatten_dynamic_cache(self):
+        with bypass_export_some_errors(patch_transformers=True):
+            c1 = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
+            self.assertIsInstance(c1, transformers.cache_utils.DynamicCache)
+            unflat = flatten_unflatten_for_dynamic_shapes(c1)
+            self.assertEqual(
+                "#2[#1[T1s4x4x4],#1[T1s4x4x4]]", self.string_type(unflat, with_shape=True)
+            )
+            self.assertEqual(
+                "DynamicCache[serialized](#2[#1[T1s4x4x4],#1[T1s4x4x4]])",
+                self.string_type(c1, with_shape=True),
+            )
+
+    def test_unflatten_flatten_encoder_decoder_cache(self):
+        with bypass_export_some_errors(patch_transformers=True):
+            c2 = make_encoder_decoder_cache(
+                make_dynamic_cache(
+                    [
+                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                    ]
+                ),
+                make_dynamic_cache(
+                    [
+                        (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
+                        (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
+                        (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
+                    ]
+                ),
+            )
+            self.assertIsInstance(c2, transformers.cache_utils.EncoderDecoderCache)
+            flat, _spec = torch.utils._pytree.tree_flatten(c2)
+            self.assertIsInstance(flat, list)
+            self.assertEqual(len(flat), 12)
+            self.assertIsInstance(flat[0], torch.Tensor)
+            unflat = flatten_unflatten_for_dynamic_shapes(c2)
+            self.assertIsInstance(unflat, list)
+            self.assertEqual(len(unflat), 2)
+            self.assertIsInstance(unflat[0], list)
+            self.assertEqual(len(unflat[0]), 2)
+            self.assertIsInstance(unflat[0][0], list)
+            self.assertEqual(len(unflat[0][0]), 3)
+            self.assertEqual(
+                "#2[#2[#3[T1s4x4x4,T1s4x4x4,T1s4x4x4],#3[T1s4x4x4,T1s4x4x4,T1s4x4x4]],"
+                "#2[#3[T1s5x5x5,T1s5x5x5,T1s5x5x5],#3[T1s5x5x5,T1s5x5x5,T1s5x5x5]]]",
+                self.string_type(unflat, with_shape=True),
+            )
+            self.assertEqual(
+                "EncoderDecoderCache[serialized]("
+                "#2[#2[#3[T1s4x4x4,T1s4x4x4,T1s4x4x4],#3[T1s4x4x4,T1s4x4x4,T1s4x4x4]],"
+                "#2[#3[T1s5x5x5,T1s5x5x5,T1s5x5x5],#3[T1s5x5x5,T1s5x5x5,T1s5x5x5]]])",
+                self.string_type(c2, with_shape=True),
+            )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
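
The two new tests pin down the nested structure flatten_unflatten_for_dynamic_shapes is expected to produce for caches. A minimal standalone sketch of the same behaviour, assuming onnx_diagnostic and transformers are installed and the serialization patches are active, could look like this:

# Sketch only: reproduces what the tests above assert, outside the test harness.
import torch
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,
    flatten_unflatten_for_dynamic_shapes,
)
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
    bypass_export_some_errors,
)

with bypass_export_some_errors(patch_transformers=True):
    cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
    # tree_flatten yields the raw tensors plus a spec describing the cache layout...
    flat, spec = torch.utils._pytree.tree_flatten(cache)
    # ...and the helper rebuilds the nested lists that dynamic_shapes must mirror:
    # [[key tensors], [value tensors]] for a one-layer DynamicCache.
    print(len(flat), flatten_unflatten_for_dynamic_shapes(cache))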

_unittests/ut_torch_export_patches/test_patch_serialization.py

Lines changed: 14 additions & 1 deletion
@@ -2,7 +2,11 @@
 import torch
 from transformers.modeling_outputs import BaseModelOutput
 from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
-from onnx_diagnostic.helpers.cache_helper import make_encoder_decoder_cache, make_dynamic_cache
+from onnx_diagnostic.helpers.cache_helper import (
+    make_encoder_decoder_cache,
+    make_dynamic_cache,
+    flatten_unflatten_for_dynamic_shapes,
+)
 from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
     bypass_export_some_errors,
 )
@@ -151,6 +155,15 @@ def forward(self, cache):
         with bypass_export_some_errors():
             torch.export.export(model, (bo,), dynamic_shapes=(ds,))
 
+    @ignore_warnings(UserWarning)
+    def test_base_model_output_unflatten_flatten(self):
+        bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
+        with bypass_export_some_errors(patch_transformers=True):
+            flat, _spec = torch.utils._pytree.tree_flatten(bo)
+            unflat = flatten_unflatten_for_dynamic_shapes(bo)
+            self.assertIsInstance(unflat, dict)
+            self.assertEqual(list(unflat), ["last_hidden_state"])
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
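
The added test checks that a BaseModelOutput round-trips into a plain dict keyed by its field names, which is the shape the automatic_speech_recognition change below relies on. A standalone sketch under the same assumptions as the test:

# Sketch only: BaseModelOutput is expected to unflatten into a dict of its fields.
import torch
from transformers.modeling_outputs import BaseModelOutput
from onnx_diagnostic.helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
    bypass_export_some_errors,
)

with bypass_export_some_errors(patch_transformers=True):
    bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
    unflat = flatten_unflatten_for_dynamic_shapes(bo)
    print(type(unflat), list(unflat))  # expected: <class 'dict'> ['last_hidden_state']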

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 13 additions & 8 deletions
@@ -316,16 +316,19 @@ def _generic_walker_step(
             return processor(inputs, ds)
         if isinstance(inputs, (int, float, str)):
             return None
-        if isinstance(inputs, (tuple, list, dict)):
-            assert type(ds) is type(
-                inputs
-            ), f"Type mismatch between inputs {type(inputs)} and ds={type(ds)}"
+        if type(inputs) in (tuple, list, dict):
+            # The type check must be strict: some custom classes inherit from these.
+            assert type(inputs) is type(ds), (
+                f"Input type and dynamic shape type must match but "
+                f"type(inputs)={type(inputs)}, type(ds)={type(ds)}, "
+                f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
+            )
             assert len(ds) == len(inputs), (
                 f"Length mismatch between inputs {len(inputs)} "
                 f"and ds={len(ds)}\n"
                 f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
             )
-            if isinstance(inputs, (tuple, list)):
+            if type(inputs) in (tuple, list):
                 value = []
                 for i, d in zip(inputs, ds):
                     value.append(
@@ -338,9 +341,11 @@ def _generic_walker_step(
                     if any(v is not None for v in value)
                     else None
                 )
-            assert set(inputs) == set(
-                ds
-            ), f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}"
+            assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
+            assert set(inputs) == set(ds), (
+                f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
+                f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
+            )
             dvalue = {}
             for k, v in inputs.items():
                 t = cls._generic_walker_step(
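
The switch from isinstance to exact type() checks is what keeps container subclasses out of the plain tuple/list/dict branch. As an illustration (not part of the commit), transformers' ModelOutput classes derive from OrderedDict, so the old check would have walked them as ordinary dicts instead of letting their registered pytree handling deal with them:

# Illustration only: a ModelOutput passes isinstance(..., dict) but not the strict check.
import torch
from transformers.modeling_outputs import BaseModelOutput

bo = BaseModelOutput(last_hidden_state=torch.rand((2, 2, 2)))
print(isinstance(bo, (tuple, list, dict)))   # True: ModelOutput subclasses OrderedDict
print(type(bo) in (tuple, list, dict))       # False: the stricter test skips this branch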

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 8 additions & 1 deletion
@@ -19,8 +19,15 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any) -> Any:
     subtrees = []
     for subspec in spec.children_specs:
         end += subspec.num_leaves
-        subtrees.append(subspec.unflatten(flat[start:end]))
+        value = subspec.unflatten(flat[start:end])
+        if not isinstance(value, (torch.Tensor, list)):
+            value = flatten_unflatten_for_dynamic_shapes(value)
+        subtrees.append(value)
         start = end
+    if spec.context:
+        # This is a dictionary.
+        return dict(zip(spec.context, subtrees))
+    # This is a list.
     return subtrees
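
With this change the helper becomes recursive and distinguishes specs that carry keys from those that do not: when spec.context is non-empty (a dict-like node, including patched classes such as BaseModelOutput) it rebuilds a dict, otherwise a list. A small sketch with plain containers, assuming the default pytree nodes store dict keys in spec.context:

# Sketch only: dict nodes come back as dicts, tuple/list nodes as lists of leaves.
import torch
from onnx_diagnostic.helpers.cache_helper import flatten_unflatten_for_dynamic_shapes

obj = {"a": torch.rand(2), "b": (torch.rand(3), torch.rand(4))}
restructured = flatten_unflatten_for_dynamic_shapes(obj)
# Expected: {'a': <tensor>, 'b': [<tensor>, <tensor>]} -- the tuple becomes a list,
# which is the container form torch.export's dynamic_shapes accepts.
print({k: type(v) for k, v in restructured.items()})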

onnx_diagnostic/tasks/automatic_speech_recognition.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ def get_inputs(
     shapes = {
         "decoder_input_ids": {0: batch, 1: seq_length},
         "cache_position": {0: seq_length},
-        "encoder_outputs": [{0: batch}],
+        "encoder_outputs": {"last_hidden_state": {0: batch}},
         "past_key_values": [
             [
                 [{0: batch} for _ in range(num_hidden_layers)],
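
Since encoder_outputs is a BaseModelOutput and now unflattens to a dict, its dynamic-shapes entry has to be a dict with the matching key; the stricter assertions in dynamic_shapes.py above would reject the old single-element list. A hypothetical, minimal reconstruction of the affected entries, with batch and seq_length shown as torch.export.Dim placeholders:

import torch

# Hypothetical placeholders; the task builds its own symbolic dimensions.
batch = torch.export.Dim("batch")
seq_length = torch.export.Dim("seq_length")

shapes = {
    "decoder_input_ids": {0: batch, 1: seq_length},
    "cache_position": {0: seq_length},
    # Mirrors the dict returned by flatten_unflatten_for_dynamic_shapes for a
    # BaseModelOutput: one entry per field instead of a positional list.
    "encoder_outputs": {"last_hidden_state": {0: batch}},
}
print(shapes["encoder_outputs"])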
