Skip to content

Commit a065e76

Browse files
committed
fix
1 parent 9a36531 commit a065e76

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

_unittests/ut_export/test_shape_helper.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import unittest
22
import torch
3-
from onnx_diagnostic.ext_test_case import ExtTestCase
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, requires_torch
44
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
55
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
66

77

88
class TestShapeHelper(ExtTestCase):
9+
@requires_transformers("4.42")
10+
@requires_torch("2.7.99")
911
def test_all_dynamic_shape_from_inputs(self):
1012
ds = all_dynamic_shape_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
1113
self.assertEqual([{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ds)
@@ -20,6 +22,8 @@ def test_all_dynamic_shape_from_inputs(self):
2022
ds,
2123
)
2224

25+
@requires_transformers("4.42")
26+
@requires_torch("2.7.99")
2327
def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
2428
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
2529
print(self.string_type(data["inputs"], with_shape=True))

0 commit comments

Comments
 (0)