Skip to content

Commit 2be018c

Browse files
committed
fix
1 parent 6a336be commit 2be018c

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ Change Logs
44
0.7.13
55
++++++
66

7-
* :pr:`256`: add a set of inputs checking models works for an empty cache on task text-generation
7+
* :pr:`247`: supports more gemma models with ModelBuilder
8+
* :pr:`246`: add a set of inputs checking models works for an empty cache on task text-generation
89
* :pr:`237`: dummy inputs for google/gemma-3-4b-it
910
* :pr:`244`: add a patch to bypass the exception raised when the dynamic dimension is in {0,1}
1011

_unittests/ut_tasks/test_tasks.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,19 @@ def test_text_generation_empty_cache(self):
5454
model, inputs = data["model"], data["inputs"]
5555
self.assertIn("inputs_empty_cache", data)
5656
empty_inputs = torch_deepcopy(data["inputs_empty_cache"])
57-
expected = model(**empty_inputs)
57+
model(**torch_deepcopy(empty_inputs))
58+
expected = model(**torch_deepcopy(inputs))
5859
self.assertEqual(
5960
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
6061
)
6162
with torch_export_patches(patch_transformers=True, verbose=1):
6263
ep = torch.export.export(
6364
model,
6465
(),
65-
kwargs=inputs,
66+
kwargs=torch_deepcopy(inputs),
6667
dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
6768
)
68-
got = ep.module()(**inputs)
69+
got = ep.module()(**torch_deepcopy(inputs))
6970
self.assertEqualArrayAny(expected, got)
7071

7172
@hide_stdout()

onnx_diagnostic/torch_models/validate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,8 @@ def validate_model(
415415
``orteval10`` increases the verbosity.
416416
417417
.. versionchanged:: 0.7.13
418-
*inputs2* not only means a second set of inputs such as ``input_empty_cache``
418+
*inputs2* not only means a second set of inputs but many
419+
such as ``input_empty_cache``
419420
which refers to a set of inputs using an empty cache.
420421
"""
421422
validation_begin = time.perf_counter()

0 commit comments

Comments
 (0)