Skip to content

Commit 47f3a61

Browse files
committed
fix ut
1 parent 313d1ae commit 47f3a61

File tree

4 files changed

+20
-9
lines changed

4 files changed

+20
-9
lines changed

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,21 @@ def test_unflatten_flatten_encoder_decoder_cache(self):
114114
self.assertEqual(len(unflat[0]), 2)
115115
self.assertIsInstance(unflat[0][0], list)
116116
self.assertEqual(len(unflat[0][0]), 3)
117+
self.assertEqual(
118+
"#2[#3[T1s4x4x4,T1s4x4x4,T1s4x4x4],#3[T1s4x4x4,T1s4x4x4,T1s4x4x4]]",
119+
self.string_type(unflat[0], with_shape=True),
120+
)
117121
self.assertEqual(
118122
"#2[#2[#3[T1s4x4x4,T1s4x4x4,T1s4x4x4],#3[T1s4x4x4,T1s4x4x4,T1s4x4x4]],"
119123
"#2[#3[T1s5x5x5,T1s5x5x5,T1s5x5x5],#3[T1s5x5x5,T1s5x5x5,T1s5x5x5]]]",
120124
self.string_type(unflat, with_shape=True),
121125
)
122126
self.assertEqual(
123-
"EncoderDecoderCache[serialized]("
124-
"#2[#2[#3[T1s4x4x4,T1s4x4x4,T1s4x4x4],#3[T1s4x4x4,T1s4x4x4,T1s4x4x4]],"
125-
"#2[#3[T1s5x5x5,T1s5x5x5,T1s5x5x5],#3[T1s5x5x5,T1s5x5x5,T1s5x5x5]]])",
127+
"EncoderDecoderCache(self_attention_cache=DynamicCache("
128+
"key_cache=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4], value_cache=#3"
129+
"[T1s4x4x4,T1s4x4x4,T1s4x4x4]), cross_attention_cache=DynamicCache"
130+
"(key_cache=#3[T1s5x5x5,T1s5x5x5,T1s5x5x5], value_cache=#3"
131+
"[T1s5x5x5,T1s5x5x5,T1s5x5x5]))",
126132
self.string_type(c2, with_shape=True),
127133
)
128134

_unittests/ut_torch_export_patches/test_patch_inputs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def test_convert_dynamic_axes_into_dynamic_shapes_1(self):
3838
self.assertEqual(
3939
(
4040
"dict(input_ids:T7s2x8,attention_mask:T7s2x8,position_ids:T7s2x8,"
41-
"past_key_values:DynamicCache[serialized](#2[#1[T1s2x1x3x96],#1[T1s2x1x3x96]]))"
41+
"past_key_values:DynamicCache(key_cache=#1[T1s2x1x3x96], "
42+
"value_cache=#1[T1s2x1x3x96]))"
4243
),
4344
string_type(res[1], with_shape=True),
4445
)
@@ -107,7 +108,8 @@ def test_convert_dynamic_axes_into_dynamic_shapes_2(self):
107108
self.assertEqual(
108109
(
109110
"dict(input_ids:T7s2x8,attention_mask:T7s2x8,position_ids:T7s2x8,"
110-
"past_key_values:DynamicCache[serialized](#2[#1[T1s2x1x3x96],#1[T1s2x1x3x96]]))"
111+
"past_key_values:DynamicCache(key_cache=#1[T1s2x1x3x96], "
112+
"value_cache=#1[T1s2x1x3x96]))"
111113
),
112114
string_type(res[1], with_shape=True),
113115
)

_unittests/ut_torch_export_patches/test_patch_serialization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_base_model_output_unflatten_flatten(self):
160160
bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
161161
with bypass_export_some_errors(patch_transformers=True):
162162
flat, _spec = torch.utils._pytree.tree_flatten(bo)
163-
unflat = flatten_unflatten_for_dynamic_shapes(bo)
163+
unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
164164
self.assertIsInstance(unflat, dict)
165165
self.assertEqual(list(unflat), ["last_hidden_state"])
166166

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,23 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An
1818
for the strict and non strict mode.
1919
:return: the serialized object
2020
"""
21+
if isinstance(obj, torch.Tensor):
22+
return obj
2123
flat, spec = torch.utils._pytree.tree_flatten(obj)
2224
start = 0
2325
end = 0
2426
subtrees = []
2527
for subspec in spec.children_specs:
2628
end += subspec.num_leaves
27-
if use_dict and subspec.type is dict:
29+
if use_dict and (subspec.type is dict or subspec.context):
2830
value = subspec.unflatten(flat[start:end])
2931
value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
3032
else:
31-
value = flat[start:end]
33+
value = subspec.unflatten(flat[start:end])
34+
value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
3235
subtrees.append(value)
3336
start = end
34-
if subspec.type is dict:
37+
if use_dict and (spec.type is dict or spec.context):
3538
# This a dictionary.
3639
return dict(zip(spec.context, subtrees))
3740
# This is a list.

0 commit comments

Comments
 (0)