Skip to content

Commit c47979b

Browse files
committed
fix
1 parent 721d436 commit c47979b

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

_unittests/ut_tasks/test_tasks_text_generation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def test_text_generation_tiny_llm(self):
5454
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
5555
self.assertEqual(data["task"], "text-generation")
5656
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
57+
inputs_copied = torch_deepcopy(inputs)
5758
expected = model(**torch_deepcopy(inputs))
5859
model(**data["inputs2"])
5960
fake = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes=ds)[0]
@@ -62,7 +63,7 @@ def test_text_generation_tiny_llm(self):
6263
model, (), kwargs=fake, dynamic_shapes=use_dyn_not_str(ds), strict=False
6364
)
6465
# print(ep)
65-
got = ep.module()(**inputs)
66+
got = ep.module()(**inputs_copied)
6667
self.assertEqualAny(expected.past_key_values, got.past_key_values)
6768
self.assertEqualArray(expected.logits, got.logits)
6869

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,11 @@ def make_dynamic_cache(
182182
layer.device = k.device
183183
layer.keys = k
184184
layer.values = v
185+
layer.is_initialized = True
186+
assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
187+
f"Unexpected number of layers in the cache ({len(cache.layers)}), "
188+
f"{len(key_value_pairs)} expected."
189+
)
185190
return finalize_cache(cache)
186191

187192
cache = transformers.cache_utils.DynamicCache(key_value_pairs)

0 commit comments

Comments
 (0)