Skip to content

Commit 8ce2aaa

Browse files
authored
Fix bug in change_dimension with custom classes (#72)
* fix bug in change_dimension * fix bypass * inputs2 * add second input * change * split * fix * disable a test * fix * ruff * use the latest * fix
1 parent 345b783 commit 8ce2aaa

26 files changed

+459
-87
lines changed

.github/workflows/ci.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,16 @@ jobs:
9494
- name: run tests bypassed
9595
run: PYTHONPATH=. python _unittests/ut_torch_models/test_tiny_llms_bypassed.py
9696

97+
- name: test image_classification
98+
run: PYTHONPATH=. python _unittests/ut_tasks/test_tasks_image_classification.py
99+
100+
- name: test zero_shot_image_classification
101+
run: PYTHONPATH=. python _unittests/ut_tasks/test_tasks_zero_shot_image_classification.py
102+
97103
- name: run tests
98104
run: |
99105
pip install pytest
100-
PYTHONPATH=. UNITTEST_GOING=1 pytest --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py --ignore _unittests/ut_reference/test_backend_onnxruntime_evaluator.py --ignore _unittests/ut_torch_models/test_tiny_llms_bypassed.py
106+
PYTHONPATH=. UNITTEST_GOING=1 pytest --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py --ignore _unittests/ut_reference/test_backend_onnxruntime_evaluator.py --ignore _unittests/ut_torch_models/test_tiny_llms_bypassed.py --ignore _unittests/ut_tasks/test_tasks_zero_shot_image_classification.py --ignore _unittests/ut_tasks/test_tasks_image_classification.py
101107
102108
- name: run backend tests python
103109
run: PYTHONPATH=. UNITTEST_GOING=1 pytest --durations=10 _unittests/ut_reference/test_backend_extended_reference_evaluator.py

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.4.1
55
+++++
66

7+
* :pr:`72`: fix change_dynamic_dimension for custom classes
78
* :pr:`70`: support models options in command lines
89

910
0.4.0

_doc/api/tasks/index.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@ All submodules contains the three following functions:
99
* ``random_input_kwargs(config) -> kwargs, get_inputs``:
1010
produces values ``get_inputs`` can take to generate dummy inputs
1111
suitable for a model defined by its configuration
12-
* ``get_inputs(model, config, *args, **kwargs) -> dict(inputs=..., dynamic_shapes=...)``:
13-
generates the dummy inputs and dynamic shapes for a specific model and configuration.
12+
* ``get_inputs(model, config, *args, add_second_input=False, **kwargs) -> dict(inputs=..., dynamic_shapes=...)``:
13+
generates the dummy inputs and dynamic shapes for a specific model and configuration,
14+
if ``add_second_input`` is True, the function should return a different set of inputs,
15+
with different values for the dynamic dimensions. It is usually better to
16+
rely on the function as the dynamic dimensions may be correlated.
1417

1518
For a specific task, you would write:
1619

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22
import torch
3+
import transformers
34
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
45
from onnx_diagnostic.helpers import string_type
56
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
@@ -742,6 +743,18 @@ def test_couple_input_ds_change_dynamic_dimensions_fixed(self):
742743
self.assertEqual((1, 5, 8), new_input["A"].shape)
743744
self.assertEqual((1, 50), new_input["B"].shape)
744745

746+
def test_couple_input_ds_change_dynamic_dimensions_dynamic_cache(self):
747+
inst = CoupleInputsDynamicShapes(
748+
(),
749+
{"A": make_dynamic_cache([(torch.ones((2, 2, 2, 2)), torch.ones((2, 2, 2, 2)))])},
750+
{"A": [[{0: "batch", 2: "last"}], [{0: "batch", 2: "last"}]]},
751+
)
752+
with bypass_export_some_errors(patch_transformers=True):
753+
new_inputs = inst.change_dynamic_dimensions()
754+
self.assertIsInstance(new_inputs["A"], transformers.cache_utils.DynamicCache)
755+
self.assertEqual((3, 2, 3, 2), new_inputs["A"].key_cache[0].shape)
756+
self.assertEqual((3, 2, 3, 2), new_inputs["A"].value_cache[0].shape)
757+
745758
@requires_transformers("4.51")
746759
def test_dynamic_cache_replace_by_string(self):
747760
n_layers = 2

_unittests/ut_helpers/test_helper.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
dtype_to_tensor_dtype,
4040
)
4141
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
42+
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config
43+
4244

4345
TFLOAT = onnx.TensorProto.FLOAT
4446

@@ -484,6 +486,11 @@ def test_flatten_encoder_decoder_cache(self):
484486
s = string_type(inputs)
485487
self.assertIn("EncoderDecoderCache", s)
486488

489+
def test_string_typeçconfig(self):
490+
conf = get_pretrained_config("microsoft/phi-2")
491+
s = string_type(conf)
492+
self.assertStartsWith("PhiConfig(**{", s)
493+
487494

488495
if __name__ == "__main__":
489496
unittest.main(verbosity=2)

_unittests/ut_tasks/test_tasks.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import torch
3-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers, has_torch
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
44
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
55
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
66
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
@@ -10,11 +10,27 @@ class TestTasks(ExtTestCase):
1010
@hide_stdout()
1111
def test_text2text_generation(self):
1212
mid = "sshleifer/tiny-marian-en-de"
13-
data = get_untrained_model_with_inputs(mid, verbose=1)
13+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
14+
self.assertEqual(data["task"], "text2text-generation")
1415
self.assertIn((data["size"], data["n_weights"]), [(473928, 118482)])
1516
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
1617
raise unittest.SkipTest(f"not working for {mid!r}")
1718
model(**inputs)
19+
model(**data["inputs2"])
20+
with bypass_export_some_errors(patch_transformers=True, verbose=10):
21+
torch.export.export(
22+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
23+
)
24+
25+
@hide_stdout()
26+
def test_text_generation(self):
27+
mid = "arnir0/Tiny-LLM"
28+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
29+
self.assertEqual(data["task"], "text-generation")
30+
self.assertIn((data["size"], data["n_weights"]), [(51955968, 12988992)])
31+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
32+
model(**inputs)
33+
model(**data["inputs2"])
1834
with bypass_export_some_errors(patch_transformers=True, verbose=10):
1935
torch.export.export(
2036
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -23,9 +39,11 @@ def test_text2text_generation(self):
2339
@hide_stdout()
2440
def test_automatic_speech_recognition(self):
2541
mid = "openai/whisper-tiny"
26-
data = get_untrained_model_with_inputs(mid, verbose=1)
42+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
43+
self.assertEqual(data["task"], "automatic-speech-recognition")
2744
self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
2845
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
46+
model(**data["inputs2"])
2947
Dim = torch.export.Dim
3048
self.maxDiff = None
3149
self.assertIn("{0:Dim(batch),1:DYN(seq_length)}", self.string_type(ds))
@@ -90,27 +108,15 @@ def test_automatic_speech_recognition(self):
90108
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
91109
)
92110

93-
@hide_stdout()
94-
def test_imagetext2text_generation(self):
95-
mid = "HuggingFaceM4/tiny-random-idefics"
96-
data = get_untrained_model_with_inputs(mid, verbose=1)
97-
self.assertIn((data["size"], data["n_weights"]), [(12742888, 3185722)])
98-
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
99-
model(**inputs)
100-
if not has_torch("2.10"):
101-
raise unittest.SkipTest("sym_max does not work with dynamic dimension")
102-
with bypass_export_some_errors(patch_transformers=True, verbose=10):
103-
torch.export.export(
104-
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
105-
)
106-
107111
@hide_stdout()
108112
def test_fill_mask(self):
109113
mid = "google-bert/bert-base-multilingual-cased"
110-
data = get_untrained_model_with_inputs(mid, verbose=1)
114+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
115+
self.assertEqual(data["task"], "fill-mask")
111116
self.assertIn((data["size"], data["n_weights"]), [(428383212, 107095803)])
112117
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
113118
model(**inputs)
119+
model(**data["inputs2"])
114120
with bypass_export_some_errors(patch_transformers=True, verbose=10):
115121
torch.export.export(
116122
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -119,10 +125,12 @@ def test_fill_mask(self):
119125
@hide_stdout()
120126
def test_feature_extraction(self):
121127
mid = "facebook/bart-base"
122-
data = get_untrained_model_with_inputs(mid, verbose=1)
128+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
129+
self.assertEqual(data["task"], "feature-extraction")
123130
self.assertIn((data["size"], data["n_weights"]), [(557681664, 139420416)])
124131
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
125132
model(**inputs)
133+
model(**data["inputs2"])
126134
with bypass_export_some_errors(patch_transformers=True, verbose=10):
127135
torch.export.export(
128136
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -131,10 +139,12 @@ def test_feature_extraction(self):
131139
@hide_stdout()
132140
def test_text_classification(self):
133141
mid = "Intel/bert-base-uncased-mrpc"
134-
data = get_untrained_model_with_inputs(mid, verbose=1)
142+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
143+
self.assertEqual(data["task"], "text-classification")
135144
self.assertIn((data["size"], data["n_weights"]), [(154420232, 38605058)])
136145
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
137146
model(**inputs)
147+
model(**data["inputs2"])
138148
with bypass_export_some_errors(patch_transformers=True, verbose=10):
139149
torch.export.export(
140150
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -143,10 +153,12 @@ def test_text_classification(self):
143153
@hide_stdout()
144154
def test_sentence_similary(self):
145155
mid = "sentence-transformers/all-MiniLM-L6-v1"
146-
data = get_untrained_model_with_inputs(mid, verbose=1)
156+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
157+
self.assertEqual(data["task"], "sentence-similarity")
147158
self.assertIn((data["size"], data["n_weights"]), [(62461440, 15615360)])
148159
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
149160
model(**inputs)
161+
model(**data["inputs2"])
150162
with bypass_export_some_errors(patch_transformers=True, verbose=10):
151163
torch.export.export(
152164
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -155,9 +167,11 @@ def test_sentence_similary(self):
155167
@hide_stdout()
156168
def test_falcon_mamba_dev(self):
157169
mid = "tiiuae/falcon-mamba-tiny-dev"
158-
data = get_untrained_model_with_inputs(mid, verbose=1)
170+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
171+
self.assertEqual(data["task"], "text-generation")
159172
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
160173
model(**inputs)
174+
model(**data["inputs2"])
161175
self.assertIn((data["size"], data["n_weights"]), [(138640384, 34660096)])
162176
if not has_transformers("4.55"):
163177
raise unittest.SkipTest("The model has control flow.")
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import unittest
2+
import torch
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
4+
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
5+
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
6+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
7+
8+
9+
class TestTasks(ExtTestCase):
10+
@hide_stdout()
11+
def test_image_classification(self):
12+
mid = "hf-internal-testing/tiny-random-BeitForImageClassification"
13+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
14+
self.assertEqual(data["task"], "image-classification")
15+
self.assertIn((data["size"], data["n_weights"]), [(56880, 14220)])
16+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
17+
model(**inputs)
18+
model(**data["inputs2"])
19+
if not has_transformers("4.51.999"):
20+
raise unittest.SkipTest("Requires transformers>=4.52")
21+
with bypass_export_some_errors(patch_transformers=True, verbose=10):
22+
torch.export.export(
23+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
24+
)
25+
26+
27+
if __name__ == "__main__":
28+
unittest.main(verbosity=2)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import unittest
2+
import torch
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers, has_torch
4+
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
5+
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
6+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
7+
8+
9+
class TestTasks(ExtTestCase):
10+
@hide_stdout()
11+
def test_image_text_to_text(self):
12+
mid = "HuggingFaceM4/tiny-random-idefics"
13+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
14+
self.assertEqual(data["task"], "image-text-to-text")
15+
self.assertIn((data["size"], data["n_weights"]), [(12742888, 3185722)])
16+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
17+
model(**inputs)
18+
model(**data["inputs2"])
19+
if not has_transformers("4.55"):
20+
raise unittest.SkipTest("The model has control flow.")
21+
if not has_torch("2.7.99"):
22+
raise unittest.SkipTest("sym_max does not work with dynamic dimension")
23+
with bypass_export_some_errors(patch_transformers=True, verbose=10):
24+
torch.export.export(
25+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
26+
)
27+
28+
29+
if __name__ == "__main__":
30+
unittest.main(verbosity=2)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import unittest
2+
import torch
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_torch
4+
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
5+
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
6+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
7+
8+
9+
class TestTasks(ExtTestCase):
10+
@requires_torch("2.7.99")
11+
@hide_stdout()
12+
def test_zero_shot_image_classification(self):
13+
mid = "openai/clip-vit-base-patch16"
14+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
15+
self.assertEqual(data["task"], "zero-shot-image-classification")
16+
self.assertIn((data["size"], data["n_weights"]), [(188872708, 47218177)])
17+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
18+
model(**inputs)
19+
model(**data["inputs2"])
20+
with bypass_export_some_errors(patch_transformers=True, verbose=10):
21+
torch.export.export(
22+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
23+
)
24+
25+
26+
if __name__ == "__main__":
27+
unittest.main(verbosity=2)

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,16 +363,20 @@ def _generic_walker_step(
363363
)
364364
if flatten_unflatten:
365365
flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
366-
return cls._generic_walker_step(
366+
res = cls._generic_walker_step(
367367
processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
368368
)
369-
flat, _spec = torch.utils._pytree.tree_flatten(inputs)
369+
# Should we restore the original class?
370+
return res
371+
flat, spec = torch.utils._pytree.tree_flatten(inputs)
370372
if all(isinstance(t, torch.Tensor) for t in flat):
371373
# We need to flatten dynamic shapes as well
372374
ds = flatten_dynamic_shapes(ds)
373-
return cls._generic_walker_step(
375+
res = cls._generic_walker_step(
374376
processor, flat, ds, flatten_unflatten=flatten_unflatten
375377
)
378+
# Then we restore the original class.
379+
return torch.utils._pytree.tree_unflatten(res, spec)
376380

377381
class ChangeDimensionProcessor:
378382
def __init__(self, desired_values):

0 commit comments

Comments
 (0)