Skip to content

Commit 8c47d90

Browse files
authored
fix to_any for BaseModelOutput (#173)
* fix ``to_any`` for ``BaseModelOutput``
* doc
1 parent 39648aa commit 8c47d90

File tree

7 files changed

+93
-10
lines changed

7 files changed

+93
-10
lines changed

CHANGELOGS.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
Change Logs
22
===========
33

4+
0.7.3
5+
+++++
6+
7+
* :pr:`173`: fixes function ``to_any`` for ``BaseModelOutput``
8+
9+
410
0.7.2
511
+++++
612

_doc/index.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,7 @@ The function replaces dynamic dimensions defined as strings by
213213
Older versions
214214
==============
215215

216-
* `0.7.2 <../v0.7.2/index.html>`_
217-
* `0.7.1 <../v0.7.1/index.html>`_
218-
* `0.7.0 <../v0.7.0/index.html>`_
216+
* `0.7.3 <../v0.7.3/index.html>`_
219217
* `0.6.3 <../v0.6.3/index.html>`_
220218
* `0.5.0 <../v0.5.0/index.html>`_
221219
* `0.4.4 <../v0.4.4/index.html>`_

_doc/patches.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ and triggered by ``with torch_export_patches(patch_transformers=True)``.
104104
This function does one class,
105105
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization`
106106
does all known classes.
107-
It can be undone with :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`
107+
It can be undone with
108+
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`
108109
or :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_cache_serialization`.
109110
Here is the list of supported caches:
110111

_unittests/ut_tasks/test_tasks.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
has_transformers,
77
requires_transformers,
88
)
9+
from onnx_diagnostic.helpers.torch_helper import to_any
910
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
1011
from onnx_diagnostic.torch_export_patches import torch_export_patches
1112
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
@@ -42,12 +43,13 @@ def test_text_generation(self):
4243
)
4344

4445
@hide_stdout()
45-
def test_automatic_speech_recognition(self):
46+
def test_automatic_speech_recognition_float32(self):
4647
mid = "openai/whisper-tiny"
4748
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
4849
self.assertEqual(data["task"], "automatic-speech-recognition")
4950
self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
5051
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
52+
model(**data["inputs"])
5153
model(**data["inputs2"])
5254
Dim = torch.export.Dim
5355
self.maxDiff = None
@@ -113,6 +115,83 @@ def test_automatic_speech_recognition(self):
113115
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
114116
)
115117

118+
@hide_stdout()
119+
def test_automatic_speech_recognition_float16(self):
120+
mid = "openai/whisper-tiny"
121+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
122+
self.assertEqual(data["task"], "automatic-speech-recognition")
123+
self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
124+
self.assertIn("encoder_outputs:BaseModelOutput", self.string_type(data["inputs"]))
125+
data["inputs"] = to_any(data["inputs"], torch.float16)
126+
self.assertIn("encoder_outputs:BaseModelOutput", self.string_type(data["inputs"]))
127+
data["inputs2"] = to_any(data["inputs2"], torch.float16)
128+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
129+
model = to_any(model, torch.float16)
130+
model(**data["inputs2"])
131+
Dim = torch.export.Dim
132+
self.maxDiff = None
133+
self.assertIn("{0:Dim(batch),1:DYN(seq_length)}", self.string_type(ds))
134+
self.assertEqualAny(
135+
{
136+
"decoder_input_ids": {
137+
0: Dim("batch", min=1, max=1024),
138+
1: "seq_length",
139+
},
140+
"cache_position": {0: "seq_length"},
141+
"encoder_outputs": [{0: Dim("batch", min=1, max=1024)}],
142+
"past_key_values": [
143+
[
144+
[
145+
{0: Dim("batch", min=1, max=1024)},
146+
{0: Dim("batch", min=1, max=1024)},
147+
],
148+
[
149+
{0: Dim("batch", min=1, max=1024)},
150+
{0: Dim("batch", min=1, max=1024)},
151+
],
152+
],
153+
[
154+
[
155+
{0: Dim("batch", min=1, max=1024)},
156+
{0: Dim("batch", min=1, max=1024)},
157+
],
158+
[
159+
{0: Dim("batch", min=1, max=1024)},
160+
{0: Dim("batch", min=1, max=1024)},
161+
],
162+
],
163+
],
164+
},
165+
ds,
166+
)
167+
self.assertEqual(
168+
"#1[T10r3]",
169+
self.string_type(torch.utils._pytree.tree_flatten(inputs["encoder_outputs"])[0]),
170+
)
171+
with torch_export_patches(patch_transformers=True, verbose=10):
172+
model(**inputs)
173+
flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
174+
self.assertIsInstance(flat, list)
175+
self.assertIsInstance(flat[0], torch.Tensor)
176+
self.assertEqual(
177+
"#8[T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4]",
178+
self.string_type(flat),
179+
)
180+
torch.export.export(
181+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
182+
)
183+
with torch_export_patches(patch_transformers=True, verbose=10):
184+
flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
185+
self.assertIsInstance(flat, list)
186+
self.assertIsInstance(flat[0], torch.Tensor)
187+
self.assertEqual(
188+
"#8[T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4]",
189+
self.string_type(flat),
190+
)
191+
torch.export.export(
192+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
193+
)
194+
116195
@hide_stdout()
117196
def test_fill_mask(self):
118197
mid = "google-bert/bert-base-multilingual-cased"

onnx_diagnostic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
Functions, classes to dig into a model when this one is right, slow, wrong...
44
"""
55

6-
__version__ = "0.7.2"
6+
__version__ = "0.7.3"
77
__author__ = "Xavier Dupré"

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
717717
return tuple(to_any(t, to_value) for t in value)
718718
if isinstance(value, set):
719719
return {to_any(t, to_value) for t in value}
720-
if isinstance(value, dict):
720+
if type(value) is dict:
721721
return {k: to_any(t, to_value) for k, t in value.items()}
722722
if value.__class__.__name__ == "DynamicCache":
723723
return make_dynamic_cache(

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,8 @@ def update(
214214
if len(self.key_cache) <= layer_idx:
215215
# There may be skipped layers, fill them with empty lists
216216
for _ in range(len(self.key_cache), layer_idx):
217-
self.key_cache.append(torch.tensor([]))
218-
self.value_cache.append(torch.tensor([]))
217+
self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
218+
self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
219219
self.key_cache.append(key_states)
220220
self.value_cache.append(value_states)
221221
elif not self.key_cache[
@@ -231,7 +231,6 @@ def update(
231231
self.value_cache[layer_idx] = torch.cat(
232232
[self.value_cache[layer_idx], value_states], dim=-2
233233
)
234-
235234
return self.key_cache[layer_idx], self.value_cache[layer_idx]
236235

237236
def crop(self, max_length: int):

0 commit comments

Comments (0)