diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 6970b954..ee206cfb 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -1,6 +1,12 @@
 Change Logs
 ===========
 
+0.7.3
++++++
+
+* :pr:`173`: fixes function to_any for BaseModelOutput
+
+
 0.7.2
 +++++
 
diff --git a/_doc/index.rst b/_doc/index.rst
index 0022090f..322c25cb 100644
--- a/_doc/index.rst
+++ b/_doc/index.rst
@@ -213,9 +213,7 @@ The function replaces dynamic dimensions defined as strings by
 Older versions
 ==============
 
-* `0.7.2 <../v0.7.2/index.html>`_
-* `0.7.1 <../v0.7.1/index.html>`_
-* `0.7.0 <../v0.7.0/index.html>`_
+* `0.7.3 <../v0.7.3/index.html>`_
 * `0.6.3 <../v0.6.3/index.html>`_
 * `0.5.0 <../v0.5.0/index.html>`_
 * `0.4.4 <../v0.4.4/index.html>`_
diff --git a/_doc/patches.rst b/_doc/patches.rst
index a2dde062..d174c3df 100644
--- a/_doc/patches.rst
+++ b/_doc/patches.rst
@@ -104,7 +104,8 @@ and triggered by ``with torch_export_patches(patch_transformers=True)``.
 This function does one class,
 :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization`
 does all known classes.
-It can be undone with :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`
+It can be undone with
+:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`
 or :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_cache_serialization`.
 
 Here is the list of supported caches:
diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py
index 388186a7..80936f30 100644
--- a/_unittests/ut_tasks/test_tasks.py
+++ b/_unittests/ut_tasks/test_tasks.py
@@ -6,6 +6,7 @@
     has_transformers,
     requires_transformers,
 )
+from onnx_diagnostic.helpers.torch_helper import to_any
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
@@ -42,12 +43,13 @@ def test_text_generation(self):
         )
 
     @hide_stdout()
-    def test_automatic_speech_recognition(self):
+    def test_automatic_speech_recognition_float32(self):
         mid = "openai/whisper-tiny"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "automatic-speech-recognition")
         self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**data["inputs"])
         model(**data["inputs2"])
         Dim = torch.export.Dim
         self.maxDiff = None
@@ -113,6 +115,83 @@ def test_automatic_speech_recognition(self):
             model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
         )
 
+    @hide_stdout()
+    def test_automatic_speech_recognition_float16(self):
+        mid = "openai/whisper-tiny"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "automatic-speech-recognition")
+        self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
+        self.assertIn("encoder_outputs:BaseModelOutput", self.string_type(data["inputs"]))
+        data["inputs"] = to_any(data["inputs"], torch.float16)
+        self.assertIn("encoder_outputs:BaseModelOutput", self.string_type(data["inputs"]))
+        data["inputs2"] = to_any(data["inputs2"], torch.float16)
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model = to_any(model, torch.float16)
+        model(**data["inputs2"])
+        Dim = torch.export.Dim
+        self.maxDiff = None
+        self.assertIn("{0:Dim(batch),1:DYN(seq_length)}", self.string_type(ds))
+        self.assertEqualAny(
+            {
+                "decoder_input_ids": {
+                    0: Dim("batch", min=1, max=1024),
+                    1: "seq_length",
+                },
+                "cache_position": {0: "seq_length"},
+                "encoder_outputs": [{0: Dim("batch", min=1, max=1024)}],
+                "past_key_values": [
+                    [
+                        [
+                            {0: Dim("batch", min=1, max=1024)},
+                            {0: Dim("batch", min=1, max=1024)},
+                        ],
+                        [
+                            {0: Dim("batch", min=1, max=1024)},
+                            {0: Dim("batch", min=1, max=1024)},
+                        ],
+                    ],
+                    [
+                        [
+                            {0: Dim("batch", min=1, max=1024)},
+                            {0: Dim("batch", min=1, max=1024)},
+                        ],
+                        [
+                            {0: Dim("batch", min=1, max=1024)},
+                            {0: Dim("batch", min=1, max=1024)},
+                        ],
+                    ],
+                ],
+            },
+            ds,
+        )
+        self.assertEqual(
+            "#1[T10r3]",
+            self.string_type(torch.utils._pytree.tree_flatten(inputs["encoder_outputs"])[0]),
+        )
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            model(**inputs)
+            flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
+            self.assertIsInstance(flat, list)
+            self.assertIsInstance(flat[0], torch.Tensor)
+            self.assertEqual(
+                "#8[T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4]",
+                self.string_type(flat),
+            )
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
+            self.assertIsInstance(flat, list)
+            self.assertIsInstance(flat[0], torch.Tensor)
+            self.assertEqual(
+                "#8[T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4]",
+                self.string_type(flat),
+            )
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
     @hide_stdout()
     def test_fill_mask(self):
         mid = "google-bert/bert-base-multilingual-cased"
diff --git a/onnx_diagnostic/__init__.py b/onnx_diagnostic/__init__.py
index 58b65182..afa1684b 100644
--- a/onnx_diagnostic/__init__.py
+++ b/onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@
 Functions, classes to dig into a model when this one is right, slow, wrong...
""" -__version__ = "0.7.2" +__version__ = "0.7.3" __author__ = "Xavier Dupré" diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index be20beb4..27fe4081 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -717,7 +717,7 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any: return tuple(to_any(t, to_value) for t in value) if isinstance(value, set): return {to_any(t, to_value) for t in value} - if isinstance(value, dict): + if type(value) is dict: return {k: to_any(t, to_value) for k, t in value.items()} if value.__class__.__name__ == "DynamicCache": return make_dynamic_cache( diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 7ddbaa06..d06939ac 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -214,8 +214,8 @@ def update( if len(self.key_cache) <= layer_idx: # There may be skipped layers, fill them with empty lists for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) + self.key_cache.append(torch.tensor([], dtype=key_states.dtype)) + self.value_cache.append(torch.tensor([], dtype=key_states.dtype)) self.key_cache.append(key_states) self.value_cache.append(value_states) elif not self.key_cache[ @@ -231,7 +231,6 @@ def update( self.value_cache[layer_idx] = torch.cat( [self.value_cache[layer_idx], value_states], dim=-2 ) - return self.key_cache[layer_idx], self.value_cache[layer_idx] def crop(self, max_length: int):