Skip to content

Commit 6989647

Browse files
committed
tests
1 parent 1a95582 commit 6989647

File tree

4 files changed

+5
-2
lines changed

4 files changed

+5
-2
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.4.0
55
+++++
66

7+
* :pr:`65`: support SlidingWindowCache
78
* :pr:`63`: support option ``--trained``
89
* :pr:`61`: improves dynamic shapes for EncoderDecoderCache
910
* :pr:`58`: add function use_dyn_not_str to replace string by ``torch.export.Dim.DYNAMIC``,

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
("py:class", "transformers.cache_utils.DynamicCache"),
124124
("py:class", "transformers.cache_utils.EncoderDecoderCache"),
125125
("py:class", "transformers.cache_utils.MambaCache"),
126+
("py:class", "transformers.cache_utils.SlidingWindowCache"),
126127
("py:class", "transformers.configuration_utils.PretrainedConfig"),
127128
("py:func", "torch.export._draft_export.draft_export"),
128129
("py:func", "torch._export.tools.report_exportability"),

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
import torch
33
import transformers
4-
from onnx_diagnostic.ext_test_case import ExtTestCase
4+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
55
from onnx_diagnostic.helpers import string_type
66
from onnx_diagnostic.helpers.cache_helper import (
77
flatten_unflatten_for_dynamic_shapes,
@@ -134,6 +134,7 @@ def test_unflatten_flatten_encoder_decoder_cache(self):
134134
self.string_type(c2, with_shape=True),
135135
)
136136

137+
@requires_transformers("4.51") # the structure changes
137138
def test_make_mamba_cache(self):
138139
cache = make_mamba_cache(
139140
[

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def validate_model(
346346

347347
for k in ["task", "size", "n_weights"]:
348348
summary[f"model_{k.replace('_','')}"] = data[k]
349-
summary["model_inputs_opionts"] = input_options or ""
349+
summary["model_inputs_opionts"] = str(input_options or "")
350350
summary["model_inputs"] = string_type(data["inputs"], with_shape=True)
351351
summary["model_shapes"] = string_type(str(data["dynamic_shapes"]))
352352
summary["model_class"] = data["model"].__class__.__name__

0 commit comments

Comments
 (0)