Skip to content

Commit bb64c8e

Browse files
committed
fix ut
1 parent 25ad8f6 commit bb64c8e

File tree

4 files changed

+8
-8
lines changed

4 files changed

+8
-8
lines changed

_unittests/ut_export/test_shape_helper.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def test_all_dynamic_shape_from_cache(self):
2525
ds = all_dynamic_shape_from_inputs(cache)
2626
self.assertEqual([[{0: "d_0_0", 1: "d_0_1"}], [{0: "d_1_0", 1: "d_1_1"}]], ds)
2727

28-
@requires_transformers("4.52")
2928
@requires_torch("2.7.99")
3029
def test_all_dynamic_shape_all_transformers_cache(self):
3130
caches = [

_unittests/ut_torch_export_patches/test_patch_serialization_diffusers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def test_unet_2d_condition_output(self):
4949
# flatten_unflatten
5050
flat, _spec = torch.utils._pytree.tree_flatten(bo)
5151
unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
52-
self.assertIsInstance(unflat, dict)
53-
self.assertEqual(list(unflat), ["sample"])
52+
self.assertIsInstance(unflat, list)
53+
self.assertEqual("#1[T1r3]", self.string_type(unflat))
5454

5555
# export
5656
class Model(torch.nn.Module):

_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ def test_base_model_output_unflatten_flatten(self):
163163
with torch_export_patches(patch_transformers=True):
164164
flat, _spec = torch.utils._pytree.tree_flatten(bo)
165165
unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
166-
self.assertIsInstance(unflat, dict)
167-
self.assertEqual(list(unflat), ["last_hidden_state"])
166+
self.assertIsInstance(unflat, list)
167+
self.assertEqual("#1[T1r3]", self.string_type(unflat))
168168

169169
@ignore_warnings(UserWarning)
170170
def test_base_sliding_window_cache_unflatten_flatten(self):
@@ -260,8 +260,10 @@ def test_static_cache(self):
260260
# flatten_unflatten
261261
flat, _spec = torch.utils._pytree.tree_flatten(bo)
262262
unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
263-
self.assertIsInstance(unflat, dict)
264-
self.assertEqual(list(unflat), ["key_cache", "value_cache"])
263+
self.assertIsInstance(unflat, list)
264+
self.assertEqual(
265+
"#2[#3[T1r4,T1r4,T1r4],#3[T1r4,T1r4,T1r4]]", self.string_type(unflat)
266+
)
265267

266268
# export
267269
class Model(torch.nn.Module):

onnx_diagnostic/export/shape_helper.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
99
All dimensions are considered as dynamic.
1010
``dim_prefix`` can be a string (the function uses it as a prefix),
1111
or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
12-
For caches, ``transformers>=4.52```is better.
1312
1413
.. runpython::
1514
:showcode:

0 commit comments

Comments
 (0)