2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -16,7 +16,7 @@ jobs:
matrix:
os: [ubuntu-latest]
python: ['3.11', '3.12']
transformers: ['4.48.3', '4.51.1', 'main']
transformers: ['4.48.3', '4.51.2', 'main']
torch: ['2.6', 'main']

steps:
1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.4.0
+++++

* :pr:`48`: add support for EncoderDecoderCache, test with openai/whisper-tiny
* :pr:`45`: improve change_dynamic_dimension to fix some dimensions

0.3.0
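For context on the EncoderDecoderCache entry above: the helper pairs two DynamicCache instances, one for self-attention and one for cross-attention. A minimal sketch of how the new tests construct one, taken from test_patch_serialization.py later in this diff (the shapes are illustrative):

import torch
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,
    make_encoder_decoder_cache,
)

# one (key, value) pair per layer; shapes are illustrative
cache = make_encoder_decoder_cache(
    make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),  # self-attention
    make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),  # cross-attention
)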
1 change: 1 addition & 0 deletions _doc/api/helpers/index.rst
@@ -13,6 +13,7 @@ onnx_diagnostic.helpers
memory_peak
onnx_helper
ort_session
rt_helper
torch_test_helper

.. autofunction:: onnx_diagnostic.helpers.max_diff
7 changes: 7 additions & 0 deletions _doc/api/helpers/rt_helper.rst
@@ -0,0 +1,7 @@

onnx_diagnostic.helpers.rt_helper
=================================

.. automodule:: onnx_diagnostic.helpers.rt_helper
:members:
:no-undoc-members:
2 changes: 1 addition & 1 deletion _doc/examples/plot_export_tiny_phi2.py
@@ -25,7 +25,7 @@
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
from onnx_diagnostic.helpers.ort_session import make_feeds
from onnx_diagnostic.helpers.rt_helper import make_feeds
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_models.hghub import (
get_untrained_model_with_inputs,
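The change above moves make_feeds from onnx_diagnostic.helpers.ort_session into the new onnx_diagnostic.helpers.rt_helper module documented earlier in this PR; callers only need to update the import:

# before this PR:
# from onnx_diagnostic.helpers.ort_session import make_feeds
# after this PR:
from onnx_diagnostic.helpers.rt_helper import make_feeds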
2 changes: 1 addition & 1 deletion _unittests/ut_helpers/test_ort_session_tinyllm.py
@@ -7,10 +7,10 @@
from onnxruntime.capi import _pybind_state as ORTC
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.rt_helper import make_feeds
from onnx_diagnostic.helpers.ort_session import (
InferenceSessionForNumpy,
InferenceSessionForTorch,
make_feeds,
)
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_models.llms import get_tiny_llm
43 changes: 43 additions & 0 deletions _unittests/ut_torch_export_patches/test_onnx_export_errors.py
@@ -121,6 +121,49 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]),
)

@ignore_warnings(UserWarning)
def test_exportable_dynamic_shapes_constraints(self):
import torch

class CustomCache:
def __init__(self, shape=None):
self.cache = [torch.zeros((shape)), torch.zeros((shape))] if shape else []

def flatten_cache(cache):
return [cache.cache], ["cache"]

def unflatten_cache(values, context, output_type=None):
cache = CustomCache()
cache.cache = values[0]
return cache

def flatten_with_keys_cache(d):
values, context = flatten_cache(d)
return [
(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)
], context

torch.utils._pytree.register_pytree_node(
CustomCache,
flatten_cache,
unflatten_cache,
serialized_type_name=f"{CustomCache.__module__}.{CustomCache.__name__}",
flatten_with_keys_fn=flatten_with_keys_cache,
)

class Model(torch.nn.Module):
def forward(self, x, cache):
return cache.cache[0][0, :] + x

model = Model()
model.eval()
x, cache = torch.rand((2, 4)), CustomCache((2, 4))
model(x, cache)
DYN = torch.export.Dim.DYNAMIC
torch.export.export(
model, (x, cache), dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}]])
)


if __name__ == "__main__":
unittest.main(verbosity=2)
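The test above registers a user-defined cache with register_pytree_node by hand. For the transformers caches covered by this PR, bypass_export_some_errors performs an equivalent registration while the context manager is active, as the serialization tests in the next file exercise. A minimal sketch for a single-layer DynamicCache, assembled from those tests (shapes illustrative):

import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors

class Model(torch.nn.Module):
    def forward(self, cache):
        return cache.key_cache[0]  # read the first layer's key tensor

cache = make_dynamic_cache([(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))])
DYN = torch.export.Dim.DYNAMIC
# one {0: DYN} per layer, once for key_cache and once for value_cache
ds = [[{0: DYN}], [{0: DYN}]]
with bypass_export_some_errors():  # cache serialization is registered while active
    torch.export.export(Model(), (cache,), dynamic_shapes=(ds,))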
156 changes: 156 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_serialization.py
@@ -0,0 +1,156 @@
import unittest
import torch
from transformers.modeling_outputs import BaseModelOutput
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
from onnx_diagnostic.helpers.cache_helper import make_encoder_decoder_cache, make_dynamic_cache
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
bypass_export_some_errors,
)
from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy


class TestPatchSerialization(ExtTestCase):
@ignore_warnings(UserWarning)
def test_encoder_decoder_cache_flatten(self):
cache = make_encoder_decoder_cache(
make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
)
with bypass_export_some_errors():
flat, _spec = torch.utils._pytree.tree_flatten(cache)
self.assertEqual(
"#4[T1s4x4x4,T1s4x4x4,T1s5x5x5,T1s5x5x5]",
self.string_type(flat, with_shape=True),
)
cache2 = torch.utils._pytree.tree_unflatten(flat, _spec)
self.assertEqual(
self.string_type(cache, with_shape=True, with_min_max=True),
self.string_type(cache2, with_shape=True, with_min_max=True),
)

@ignore_warnings(UserWarning)
def test_encoder_decoder_cache_deepcopy(self):
cache = make_encoder_decoder_cache(
make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
)
with bypass_export_some_errors():
cache2 = torch_deepcopy([cache])
self.assertEqualAny([cache], cache2)

@ignore_warnings(UserWarning)
def test_encoder_decoder_cache_export(self):
class Model(torch.nn.Module):
def forward(self, cache):
return cache.self_attention_cache.key_cache[0]

cache1 = make_dynamic_cache(
[(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)]
)
cache2 = make_dynamic_cache(
[(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)]
)

cache = make_encoder_decoder_cache(cache1, cache2)
model = Model()
model(cache)
DYN = torch.export.Dim.DYNAMIC
ds = [
[[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]],
[[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]],
]

with bypass_export_some_errors(patch_transformers=True):
torch.export.export(model, (cache,), dynamic_shapes=(ds,))

@ignore_warnings(UserWarning)
def test_dynamic_cache_flatten(self):
cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
with bypass_export_some_errors():
flat, _spec = torch.utils._pytree.tree_flatten(cache)
self.assertEqual(
"#2[T1s4x4x4,T1s4x4x4]",
self.string_type(flat, with_shape=True),
)
cache2 = torch.utils._pytree.tree_unflatten(flat, _spec)
self.assertEqual(
self.string_type(cache, with_shape=True, with_min_max=True),
self.string_type(cache2, with_shape=True, with_min_max=True),
)

@ignore_warnings(UserWarning)
def test_dynamic_cache_export(self):
class Model(torch.nn.Module):
def forward(self, cache):
return cache.key_cache[0]

cache = make_dynamic_cache(
[(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)]
)
model = Model()
model(cache)
DYN = torch.export.Dim.DYNAMIC
ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]]

with bypass_export_some_errors():
torch.export.export(model, (cache,), dynamic_shapes=(ds,))

@ignore_warnings(UserWarning)
def test_dynamic_cache_deepcopy(self):
cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
with bypass_export_some_errors():
cache2 = torch_deepcopy([cache])
self.assertEqualAny([cache], cache2)

@ignore_warnings(UserWarning)
def test_base_model_output_deepcopy(self):
bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
self.assertEqual(bo.__class__.__name__, "BaseModelOutput")
with bypass_export_some_errors():
bo2 = torch_deepcopy([bo])
self.assertIsInstance(bo2, list)
self.assertEqual(bo2[0].__class__.__name__, "BaseModelOutput")
self.assertEqualAny([bo], bo2)

@ignore_warnings(UserWarning)
def test_base_model_output_string_type(self):
bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
with bypass_export_some_errors():
self.assertEqual(
"BaseModelOutput(last_hidden_state:T1s4x4x4)",
self.string_type(bo, with_shape=True),
)

@ignore_warnings(UserWarning)
def test_base_model_output_flatten(self):
bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
with bypass_export_some_errors():
flat, _spec = torch.utils._pytree.tree_flatten(bo)
self.assertEqual(
"#1[T1s4x4x4]",
self.string_type(flat, with_shape=True),
)
bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
self.assertEqual(
self.string_type(bo, with_shape=True, with_min_max=True),
self.string_type(bo2, with_shape=True, with_min_max=True),
)

@ignore_warnings(UserWarning)
def test_base_model_output_export(self):
class Model(torch.nn.Module):
def forward(self, cache):
return cache.last_hidden_state[0]

bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
model = Model()
model(bo)
DYN = torch.export.Dim.DYNAMIC
ds = [{0: DYN}]

with bypass_export_some_errors():
torch.export.export(model, (bo,), dynamic_shapes=(ds,))


if __name__ == "__main__":
unittest.main(verbosity=2)
69 changes: 69 additions & 0 deletions _unittests/ut_torch_models/test_hghub_model.py
@@ -1,5 +1,6 @@
import pprint
import unittest
import torch
import transformers
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
@@ -14,6 +15,7 @@
)
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config
from onnx_diagnostic.torch_models.hghub.hub_data import load_models_testing
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors


class TestHuggingFaceHubModel(ExtTestCase):
@@ -104,6 +106,72 @@ def test_get_untrained_model_with_inputs_text2text_generation(self):
raise unittest.SkipTest(f"not working for {mid!r}")
model(**inputs)

@hide_stdout()
def test_get_untrained_model_with_inputs_automatic_speech_recognition(self):
mid = "openai/whisper-tiny"
data = get_untrained_model_with_inputs(mid, verbose=1)
self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
Dim = torch.export.Dim
self.maxDiff = None
self.assertIn("{0:Dim(batch),1:Dim(seq_length)}", self.string_type(ds))
self.assertEqualAny(
{
"decoder_input_ids": {
0: Dim("batch", min=1, max=1024),
1: Dim("seq_length", min=1, max=4096),
},
"cache_position": {0: Dim("seq_length", min=1, max=4096)},
"encoder_outputs": [{0: Dim("batch", min=1, max=1024)}],
"past_key_values": [
[
[
{0: Dim("batch", min=1, max=1024)},
{0: Dim("batch", min=1, max=1024)},
],
[
{0: Dim("batch", min=1, max=1024)},
{0: Dim("batch", min=1, max=1024)},
],
],
[
[
{0: Dim("batch", min=1, max=1024)},
{0: Dim("batch", min=1, max=1024)},
],
[
{0: Dim("batch", min=1, max=1024)},
{0: Dim("batch", min=1, max=1024)},
],
],
],
},
ds,
)
model(**inputs)
self.assertEqual(
"#1[T1r3]",
self.string_type(torch.utils._pytree.tree_flatten(inputs["encoder_outputs"])[0]),
)
with bypass_export_some_errors(patch_transformers=True, verbose=10):
flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
self.assertIsInstance(flat, list)
self.assertIsInstance(flat[0], torch.Tensor)
self.assertEqual(
"#8[T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4]",
self.string_type(flat),
)
torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
with bypass_export_some_errors(patch_transformers=True, verbose=10):
flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
self.assertIsInstance(flat, list)
self.assertIsInstance(flat[0], torch.Tensor)
self.assertEqual(
"#8[T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4]",
self.string_type(flat),
)
torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)

@hide_stdout()
def test_get_untrained_model_with_inputs_imagetext2text_generation(self):
mid = "HuggingFaceM4/tiny-random-idefics"
@@ -131,6 +199,7 @@ def _diff(c1, c2):
for mid in load_models_testing():
with self.subTest(mid=mid):
if mid in {
"hf-internal-testing/tiny-random-BeitForImageClassification",
"hf-internal-testing/tiny-random-MaskFormerForInstanceSegmentation",
"hf-internal-testing/tiny-random-MoonshineForConditionalGeneration",
"fxmarty/pix2struct-tiny-random",