Skip to content

Commit 244c5ef

Browse files
committed
fix import issues
1 parent 15cd008 commit 244c5ef

File tree

5 files changed

+5
-5
lines changed

5 files changed

+5
-5
lines changed

_unittests/ut_torch_export_patches/test_onnx_export_errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_pytree_flatten_mamba_cache(self):
2424
import torch.utils._pytree as py_pytree
2525

2626
try:
27-
from transformers.models.mamba.cache_mamba import MambaCache
27+
from transformers.models.mamba.modeling_mamba import MambaCache
2828
except ImportError:
2929
from transformers.cache_utils import MambaCache
3030

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import transformers.cache_utils
66

77
try:
8-
from transformers.models.mamba.cache_mamba import MambaCache
8+
from transformers.models.mamba.modeling_mamba import MambaCache
99
except ImportError:
1010
from transformers.cache_utils import MambaCache
1111

onnx_diagnostic/tasks/text_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def get_inputs(
9595

9696
if config is not None and config.__class__.__name__ == "FalconMambaConfig":
9797
try:
98-
from transformers.models.mamba.cache_mamba import MambaCache
98+
from transformers.models.mamba.modeling_mamba import MambaCache
9999
except ImportError:
100100
from transformers.cache_utils import MambaCache
101101

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313

1414
try:
15-
from transformers.models.mamba.cache_mamba import MambaCache
15+
from transformers.models.mamba.modeling_mamba import MambaCache
1616
except ImportError:
1717
from transformers.cache_utils import MambaCache
1818

onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
)
1010

1111
try:
12-
from transformers.models.mamba.cache_mamba import MambaCache
12+
from transformers.models.mamba.modeling_mamba import MambaCache
1313
except ImportError:
1414
from transformers.cache_utils import MambaCache
1515
from transformers.modeling_outputs import BaseModelOutput

0 commit comments

Comments
 (0)