Skip to content

Commit 085a1d6

Browse files
committed
lint
1 parent e724b5e commit 085a1d6

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

_unittests/ut_export/test_serialization.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import torch
3-
from onnx_diagnostic.ext_test_case import ExtTestCase
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
44
from onnx_diagnostic.helpers import string_type
55
from onnx_diagnostic.helpers.cache_helper import (
66
make_dynamic_cache,
@@ -19,6 +19,7 @@ def _get_cache(self, n_layers=2, bsize=2, nheads=4, slen=1, dim=7):
1919
]
2020
)
2121

22+
@requires_transformers("4.50")
2223
def test_dynamic_cache(self):
2324
class Model(torch.nn.Module):
2425
def forward(self, cache):
@@ -31,6 +32,7 @@ def forward(self, cache):
3132
exp = torch.export.export(Model(), (cache,), dynamic_shapes=dynamic_shapes)
3233
self.assertNotEmpty(exp)
3334

35+
@requires_transformers("4.50")
3436
def test_dynamic_cache_flat_unflat(self):
3537
class Model(torch.nn.Module):
3638
def forward(self, cache):

_unittests/ut_torch_export_patches/test_patch_inputs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
import torch
33
import transformers
4-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
4+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_transformers
55
from onnx_diagnostic.helpers import string_type
66
from onnx_diagnostic.torch_export_patches.patch_inputs import (
77
convert_dynamic_axes_into_dynamic_shapes,
@@ -10,6 +10,7 @@
1010

1111
class TestPatchInputs(ExtTestCase):
1212
@hide_stdout()
13+
@requires_transformers("4.50")
1314
def test_convert_dynamic_axes_into_dynamic_shapes_1(self):
1415
args = (
1516
torch.randint(0, 10, size=(2, 8)).to(torch.int64),
@@ -55,6 +56,7 @@ def test_convert_dynamic_axes_into_dynamic_shapes_1(self):
5556
)
5657

5758
@hide_stdout()
59+
@requires_transformers("4.50")
5860
def test_convert_dynamic_axes_into_dynamic_shapes_2(self):
5961
args = (
6062
torch.randint(0, 10, size=(2, 8)).to(torch.int64),

0 commit comments

Comments
 (0)