Commit f9fa6e8

fix issues
1 parent bb64c8e commit f9fa6e8

File tree

2 files changed: +12 -24 lines changed

_doc/recipes/plot_dynamic_shapes_json.py

Lines changed: 12 additions & 2 deletions
@@ -75,11 +75,21 @@ def flatten_unflatten_like_dynamic_shapes(obj):
         value = flatten_unflatten_like_dynamic_shapes(value)
         subtrees.append(value)
         start = end
-    if spec.type is dict or spec.context:
+    if spec.type is dict:
+        # This is a dictionary.
         return dict(zip(spec.context, subtrees))
     if spec.type is tuple:
         return tuple(subtrees)
-    return subtrees
+    if spec.type is list:
+        return list(subtrees)
+    if spec.context:
+        # This is a custom class with attributes.
+        # It is returned as a list.
+        return list(subtrees)
+    raise ValueError(
+        f"Unable to interpret spec type {spec.type} "
+        f"(type is {type(spec.type)}, context is {spec.context})."
+    )


 def _align(inputs, ds):
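
Below is a minimal, hedged sketch of the dispatch this patch introduces, written against torch.utils._pytree directly rather than the repository's own helper. It only assumes that a pytree TreeSpec exposes type and context; the rebuild function and the sample object are illustrative, not part of the commit.

# Sketch (not the repository's code) of the branching added above.
# Builtin containers are restored by spec.type; a registered custom class
# (for example a transformers cache) has a non-empty spec.context and is
# returned as a plain list; anything else raises.
import torch
from torch.utils import _pytree as pytree


def rebuild(spec, subtrees):
    # subtrees are the already unflattened children, as in the patched function
    if spec.type is dict:
        return dict(zip(spec.context, subtrees))
    if spec.type is tuple:
        return tuple(subtrees)
    if spec.type is list:
        return list(subtrees)
    if spec.context:
        # custom registered class with attributes, returned as a list
        return list(subtrees)
    raise ValueError(f"Unable to interpret spec type {spec.type}")


obj = {"mask": torch.ones(4), "past": (torch.ones(2, 3), [torch.zeros(1)])}
leaves, spec = pytree.tree_flatten(obj)
print(spec.type)     # <class 'dict'>
print(spec.context)  # the dictionary keys, here ['mask', 'past']
print(rebuild(spec, ["A", "B"]))  # {'mask': 'A', 'past': 'B'}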

_unittests/ut_export/test_shape_helper.py

Lines changed: 0 additions & 22 deletions
@@ -10,7 +10,6 @@
     make_sliding_window_cache,
     make_encoder_decoder_cache,
     make_static_cache,
-    make_mamba_cache,
 )
 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
@@ -119,27 +118,6 @@ def test_all_dynamic_shape_all_transformers_cache(self):
                     ],
                 ],
             ),
-            (
-                make_mamba_cache(
-                    [
-                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
-                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
-                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
-                    ]
-                ),
-                [
-                    [
-                        {0: "d_0_0", 1: "d_0_1", 2: "d_0_2"},
-                        {0: "d_1_0", 1: "d_1_1", 2: "d_1_2"},
-                        {0: "d_2_0", 1: "d_2_1", 2: "d_2_2"},
-                    ],
-                    [
-                        {0: "d_3_0", 1: "d_3_1", 2: "d_3_2"},
-                        {0: "d_4_0", 1: "d_4_1", 2: "d_4_2"},
-                        {0: "d_5_0", 1: "d_5_1", 2: "d_5_2"},
-                    ],
-                ],
-            ),
         ]
         with torch_export_patches(patch_transformers=True):
             for cache, exds in caches:
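
For context, a hedged sketch of the test pattern that remains after this removal: each caches entry pairs a cache object with the dynamic shapes expected for it, and the pairs are checked inside the torch_export_patches context. The check_entry helper below is a hypothetical placeholder; the real assertion lives in the unchanged part of test_shape_helper.py and is not shown in this diff.

# Hypothetical sketch of the surviving test structure; check_entry stands in
# for the assertion actually performed by the test.
from onnx_diagnostic.torch_export_patches import torch_export_patches


def check_entry(cache, expected_dynamic_shapes):
    # placeholder: the real test compares the dynamic shapes computed for
    # ``cache`` against ``expected_dynamic_shapes``
    ...


caches = [
    # (cache_object, expected_dynamic_shapes) pairs; the make_mamba_cache
    # entry removed by this commit used to be one of them
]

with torch_export_patches(patch_transformers=True):
    for cache, exds in caches:
        check_entry(cache, exds)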
