Skip to content

Commit 313d1ae

Browse files
committed
f
1 parent c284e7c commit 313d1ae

File tree

4 files changed

+7
-5
lines changed

4 files changed

+7
-5
lines changed

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_unflatten_flatten_dynamic_cache(self):
8080
"#2[#1[T1s4x4x4],#1[T1s4x4x4]]", self.string_type(unflat, with_shape=True)
8181
)
8282
self.assertEqual(
83-
"DynamicCache[serialized](#2[#1[T1s4x4x4],#1[T1s4x4x4]])",
83+
"DynamicCache(key_cache=#1[T1s4x4x4], value_cache=#1[T1s4x4x4])",
8484
self.string_type(c1, with_shape=True),
8585
)
8686

_unittests/ut_tasks/try_tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def test_automatic_speech_recognition(self):
184184

185185
# generate token ids
186186
print()
187-
with steal_forward(model):
187+
with steal_forward(model.model.decoder):
188188
predicted_ids = model.generate(
189189
input_features, forced_decoder_ids=forced_decoder_ids
190190
)

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def _generic_walker_step(
342342
else None
343343
)
344344
assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
345-
assert set(inputs) is set(ds), (
345+
assert set(inputs) == set(ds), (
346346
f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
347347
f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
348348
)

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An
2424
subtrees = []
2525
for subspec in spec.children_specs:
2626
end += subspec.num_leaves
27-
value = subspec.unflatten(flat[start:end])
28-
if subspec.type is dict:
27+
if use_dict and subspec.type is dict:
28+
value = subspec.unflatten(flat[start:end])
2929
value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
30+
else:
31+
value = flat[start:end]
3032
subtrees.append(value)
3133
start = end
3234
if subspec.type is dict:

0 commit comments

Comments
 (0)