Skip to content

Commit d8e0dd8

Browse files
authored
Add set of inputs for empty cache (#246)
* Add set of inputs for empty cache * patches * fix * fix
1 parent be1f54f commit d8e0dd8

File tree

14 files changed

+178
-65
lines changed

14 files changed

+178
-65
lines changed

.github/workflows/check-urls.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ jobs:
4242
print_all: false
4343
timeout: 2
4444
retry_count# : 2
45-
exclude_urls: https://hal.archives-,ouvertes.fr/hal-00990252/document,http://badge.fury.io/py/onnx-diagnostic,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://medium.com/@msouza.os/llm-from-scratch-with-pytorch-9f21808c6319,https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L5965,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311
46-
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/,https://codecov.io/,https://huggingface.co/
45+
exclude_urls: https://hal.archives-,ouvertes.fr/hal-00990252/document,http://badge.fury.io/py/onnx-diagnostic,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://medium.com/@msouza.os/llm-from-scratch-with-pytorch-9f21808c6319,https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L5965,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311,https://www.linux.org/
46+
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/,https://codecov.io/,https://huggingface.co/,https://www.linux.org/
4747
# force_pass : true

.github/workflows/ci.yml

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
matrix:
1818
os: [ubuntu-latest]
1919
python: ['3.10', '3.11', '3.12', '3.13']
20-
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.53.3', '4.55.4', '4.56.1', 'main']
20+
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', 'main']
2121
torch: ['2.8', 'main']
2222
exclude:
2323
- python: '3.10'
@@ -26,30 +26,28 @@ jobs:
2626
transformers: 'main'
2727
- python: '3.10'
2828
transformers: '4.52.4'
29-
- python: '3.10'
30-
transformers: '4.53.3'
3129
- python: '3.10'
3230
transformers: '4.55.4'
3331
- python: '3.10'
34-
transformers: '4.56.1'
32+
transformers: '4.56.2'
3533
- python: '3.11'
3634
torch: 'main'
37-
- python: '3.11'
38-
transformers: '4.53.3'
3935
- python: '3.11'
4036
transformers: 'main'
4137
- python: '3.11'
4238
transformers: '4.55.4'
4339
- python: '3.11'
44-
transformers: '4.56.1'
40+
transformers: '4.56.2'
4541
- python: '3.13'
4642
torch: '2.8'
4743
- python: '3.13'
4844
transformers: '4.48.3'
4945
- python: '3.13'
5046
transformers: '4.51.3'
5147
- python: '3.13'
52-
transformers: '4.52.4'
48+
transformers: '4.55.4'
49+
- python: '3.13'
50+
transformers: '4.56.2'
5351
steps:
5452
- uses: actions/checkout@v3
5553

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.7.13
55
++++++
66

7+
* :pr:`247`: supports more gemma models with ModelBuilder
8+
* :pr:`246`: add a set of inputs checking models works for an empty cache on task text-generation
79
* :pr:`237`: dummy inputs for google/gemma-3-4b-it
810
* :pr:`244`: add a patch to bypass the exception raised when the dynamic dimension is in {0,1}
911

_doc/patches.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,10 @@ Here is the list of available patches:
9191

9292
for name, cls in p.__dict__.items():
9393
if name.startswith("patched_") and hasattr(cls, "_PATCHES_"):
94-
print(f"{cls._PATCHED_CLASS_.__name__}: {', '.join(cls._PATCHES_)}")
94+
print(
95+
f"{cls._PATCHED_CLASS_.__name__}: "
96+
f"{', '.join([_ for _ in cls._PATCHES_ if _ is not None])}"
97+
)
9598

9699
Cache serialization
97100
===================

_doc/status/patches_coverage.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ for transformers.
3232

3333
for name, cls in p.__dict__.items():
3434
if name.startswith("patched_") and hasattr(cls, "_PATCHES_"):
35-
print(f"{cls._PATCHED_CLASS_.__name__}: {', '.join(cls._PATCHES_)}")
35+
print(
36+
f"{cls._PATCHED_CLASS_.__name__}: "
37+
f"{', '.join([_ for _ in cls._PATCHES_ if _ is not None])}"
38+
)
3639

3740
Half Automated Rewrites for Control Flows
3841
=========================================

_unittests/ut_tasks/test_tasks.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,27 @@ def test_text_generation(self):
4848
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
4949
)
5050

51+
def test_text_generation_empty_cache(self):
52+
mid = "arnir0/Tiny-LLM"
53+
data = get_untrained_model_with_inputs(mid, add_second_input=True)
54+
model, inputs = data["model"], data["inputs"]
55+
self.assertIn("inputs_empty_cache", data)
56+
empty_inputs = torch_deepcopy(data["inputs_empty_cache"])
57+
model(**torch_deepcopy(empty_inputs))
58+
expected = model(**torch_deepcopy(inputs))
59+
self.assertEqual(
60+
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
61+
)
62+
with torch_export_patches(patch_transformers=True, verbose=1):
63+
ep = torch.export.export(
64+
model,
65+
(),
66+
kwargs=torch_deepcopy(inputs),
67+
dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
68+
)
69+
got = ep.module()(**torch_deepcopy(inputs))
70+
self.assertEqualArrayAny(expected, got)
71+
5172
@hide_stdout()
5273
def test_automatic_speech_recognition_float32(self):
5374
mid = "openai/whisper-tiny"

_unittests/ut_tasks/try_tasks.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from onnx_diagnostic.helpers import string_type
66
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
77
from onnx_diagnostic.helpers.torch_helper import steal_forward
8-
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
98
from onnx_diagnostic.torch_export_patches import torch_export_patches
9+
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
1010

1111

1212
class TestHuggingFaceHubModel(ExtTestCase):
@@ -132,6 +132,52 @@ def test_text2text_generation_static(self):
132132
)
133133
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
134134

135+
@never_test()
136+
def test_text_generation_tiny_llm(self):
137+
# clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k tiny_llm
138+
"""
139+
dict(cache_position:T7s21,
140+
past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),
141+
input_ids:T7s1x21,
142+
position_ids:T7s1x21
143+
attention_mask:T1s1x21)
144+
dict(cache_position:T7s1,
145+
past_key_values:DynamicCache(key_cache=#32[T1s1x8x21x128,...],
146+
value_cache=#32[T1s1x8x21x128,...]),
147+
input_ids:T7s1x21,
148+
position_ids:T7s1x1
149+
attention_mask:T1s1x1)
150+
"""
151+
from transformers import AutoTokenizer, AutoModelForCausalLM
152+
153+
tokenizer = AutoTokenizer.from_pretrained("arnir0/Tiny-LLM")
154+
model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-4-mini-instruct")
155+
156+
text = "def greet(user): print(f'hello <extra_id_0>!')"
157+
input_ids = tokenizer(text, return_tensors="pt").input_ids.reshape((1, -1))
158+
mask = (
159+
torch.tensor([1 for i in range(input_ids.shape[1])])
160+
.to(torch.int64)
161+
.reshape((1, -1))
162+
)
163+
position_ids = torch.arange(input_ids.shape[1], dtype=torch.int64).reshape((1, -1))
164+
165+
# simply generate a single sequence
166+
print()
167+
with (
168+
torch_export_patches(
169+
patch_transformers=True, patch_torch=False, patch_sympy=False
170+
),
171+
steal_forward(model),
172+
):
173+
generated_ids = model.generate(
174+
input_ids=input_ids,
175+
max_length=100,
176+
attention_mask=mask,
177+
position_ids=position_ids,
178+
)
179+
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
180+
135181
@never_test()
136182
def test_text_generation_phi4_mini(self):
137183
# clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi4_mini

_unittests/ut_torch_models/test_validate_whole_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def test_m_validate_model_vit_model(self):
227227
self.assertIsInstance(summary, dict)
228228
self.assertIsInstance(data, dict)
229229
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-3)
230-
self.assertLess(summary["disc_onnx_ort_run2_abs"], 1e-3)
230+
self.assertLess(summary["disc_onnx_ort_run22_abs"], 1e-3)
231231
self.assertEqual("dict(pixel_values:A1s2x3x30x30)", summary["run_feeds_inputs"])
232232
self.assertEqual("dict(pixel_values:A1s3x3x31x31)", summary["run_feeds_inputs2"])
233233
self.assertEqual("#1[A1s2x2]", summary["run_output_inputs"])

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,12 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
779779

780780

781781
def torch_deepcopy(value: Any) -> Any:
782-
"""Makes a deepcopy."""
782+
"""
783+
Makes a deep copy.
784+
785+
:param value: any value
786+
:return: a deep copy
787+
"""
783788
if value is None:
784789
return None
785790
if isinstance(value, (int, float, str)):

onnx_diagnostic/tasks/text_generation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,21 @@ def get_inputs(
269269
add_second_input=0,
270270
**kwargs,
271271
)["inputs"]
272+
res["inputs_empty_cache"] = get_inputs(
273+
model=model,
274+
config=config,
275+
dummy_max_token_id=dummy_max_token_id,
276+
num_hidden_layers=num_hidden_layers,
277+
batch_size=batch_size,
278+
sequence_length=0,
279+
sequence_length2=sequence_length2,
280+
dynamic_rope=dynamic_rope,
281+
num_key_value_heads=num_key_value_heads,
282+
head_dim=head_dim,
283+
cls_cache=cls_cache,
284+
add_second_input=0,
285+
**kwargs,
286+
)["inputs"]
272287
return res
273288

274289

0 commit comments

Comments
 (0)