Skip to content

Commit 15cd008

Browse files
committed
fix mambacache import
1 parent ef2d94b commit 15cd008

File tree

5 files changed: +30 -11 lines changed

5 files changed

+30
-11
lines changed

_unittests/ut_torch_export_patches/test_onnx_export_errors.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,11 @@ class TestOnnxExportErrors(ExtTestCase):
2222
def test_pytree_flatten_mamba_cache(self):
2323
import torch
2424
import torch.utils._pytree as py_pytree
25-
from transformers.cache_utils import MambaCache
25+
26+
try:
27+
from transformers.models.mamba.cache_mamba import MambaCache
28+
except ImportError:
29+
from transformers.cache_utils import MambaCache
2630

2731
class _config:
2832
def __init__(self):

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,11 @@
44
import transformers
55
import transformers.cache_utils
66

7+
try:
8+
from transformers.models.mamba.cache_mamba import MambaCache
9+
except ImportError:
10+
from transformers.cache_utils import MambaCache
11+
712

813
def flatten_unflatten_for_dynamic_shapes(
914
obj: Any,
@@ -242,10 +247,8 @@ def make_encoder_decoder_cache(
242247
)
243248

244249

245-
def make_mamba_cache(
246-
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
247-
) -> transformers.cache_utils.MambaCache:
248-
"Creates a :class:`transformers.cache_utils.MambaCache`."
250+
def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
251+
"Creates a ``MambaCache``."
249252
dtype = key_value_pairs[0][0].dtype
250253

251254
class _config:
@@ -256,7 +259,7 @@ def __init__(self):
256259
self.num_hidden_layers = len(key_value_pairs)
257260
self.dtype = dtype
258261

259-
cache = transformers.cache_utils.MambaCache(
262+
cache = MambaCache(
260263
_config(),
261264
max_batch_size=key_value_pairs[0][0].shape[0],
262265
device=key_value_pairs[0][0].device,
@@ -286,7 +289,7 @@ def __init__(self):
286289

287290
def make_sliding_window_cache(
288291
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
289-
) -> transformers.cache_utils.MambaCache:
292+
) -> transformers.cache_utils.SlidingWindowCache:
290293
"Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
291294

292295
class _config:

onnx_diagnostic/tasks/text_generation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Any, Callable, Dict, Optional, Tuple, Union
22
import torch
3-
import transformers
43
from ..helpers.cache_helper import (
54
make_dynamic_cache,
65
make_mamba_cache,
@@ -95,9 +94,14 @@ def get_inputs(
9594
cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
9695

9796
if config is not None and config.__class__.__name__ == "FalconMambaConfig":
97+
try:
98+
from transformers.models.mamba.cache_mamba import MambaCache
99+
except ImportError:
100+
from transformers.cache_utils import MambaCache
101+
98102
assert cls_cache in (
99103
"MambaCache",
100-
transformers.cache_utils.MambaCache,
104+
MambaCache,
101105
), f"Unexpected value for cls_cache={cls_cache} and config={config}"
102106
seq_length_multiple = 8
103107
sequence_length = (

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,12 +6,16 @@
66
import transformers
77
from transformers.cache_utils import (
88
DynamicCache,
9-
MambaCache,
109
EncoderDecoderCache,
1110
SlidingWindowCache,
1211
StaticCache,
1312
)
1413

14+
try:
15+
from transformers.models.mamba.cache_mamba import MambaCache
16+
except ImportError:
17+
from transformers.cache_utils import MambaCache
18+
1519
from ..helpers import string_type
1620
from .serialization import _lower_name_with_
1721

onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,15 @@
33
import transformers
44
from transformers.cache_utils import (
55
DynamicCache,
6-
MambaCache,
76
EncoderDecoderCache,
87
SlidingWindowCache,
98
StaticCache,
109
)
10+
11+
try:
12+
from transformers.models.mamba.cache_mamba import MambaCache
13+
except ImportError:
14+
from transformers.cache_utils import MambaCache
1115
from transformers.modeling_outputs import BaseModelOutput
1216
from ...helpers.cache_helper import make_static_cache
1317
from . import make_serialization_function_for_dataclass

0 commit comments

Comments
 (0)