Skip to content

Commit 8b9d673

Browse files
committed
fix ut
1 parent 659896b commit 8b9d673

File tree

4 files changed

+37
-2
lines changed

4 files changed

+37
-2
lines changed

_doc/api/torch_models/index.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
onnx_diagnostic.torch_models
2+
============================
3+
4+
.. toctree::
5+
:maxdepth: 1
6+
:caption: submodules
7+
8+
llms
9+
10+
.. automodule:: onnx_diagnostic.torch_models
11+
:members:
12+
:no-undoc-members:

_doc/api/torch_models/llms.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_models.llms
3+
=================================
4+
5+
.. automodule:: onnx_diagnostic.torch.models.llms
6+
:members:
7+
:no-undoc-members:

_unittests/ut_xrun_doc/test_documentation_examples.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import subprocess
66
import time
77
from onnx_diagnostic import __file__ as onnx_diagnostic_file
8-
from onnx_diagnostic.ext_test_case import ExtTestCase, is_windows
8+
from onnx_diagnostic.ext_test_case import ExtTestCase, is_windows, has_transformers
9+
910

1011
VERBOSE = 0
1112
ROOT = os.path.realpath(os.path.abspath(os.path.join(onnx_diagnostic_file, "..", "..")))
@@ -69,6 +70,13 @@ def add_test_methods(cls):
6970
continue
7071
reason = None
7172

73+
if (
74+
not reason
75+
and name in {"plot_export_tiny_llm.py"}
76+
and not has_transformers("4.50")
77+
):
78+
reason = "transformers<4.50"
79+
7280
if reason:
7381

7482
@unittest.skip(reason)

onnx_diagnostic/ext_test_case.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -774,13 +774,21 @@ def requires_sklearn(version: str, msg: str = "") -> Callable:
774774

775775

776776
def has_torch(version: str) -> bool:
777-
"Returns True if torch verions is higher."
777+
"Returns True if torch transformers is higher."
778778
import packaging.version as pv
779779
import torch
780780

781781
return pv.Version(torch.__version__) >= pv.Version(version)
782782

783783

784+
def has_transformers(version: str) -> bool:
785+
"Returns True if transformers version is higher."
786+
import packaging.version as pv
787+
import transformers
788+
789+
return pv.Version(transformers.__version__) >= pv.Version(version)
790+
791+
784792
def requires_torch(version: str, msg: str = "") -> Callable:
785793
"""Skips a unit test if :epkg:`pytorch` is not recent enough."""
786794
import packaging.version as pv

0 commit comments

Comments
 (0)