Skip to content

Commit 38dd1a0

Browse files
committed
moves slow import to a location where it works better
1 parent 9c9d52f commit 38dd1a0

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

_scripts/compare_model_execution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
print("-- import onnx-diagnostic.helper")
2424
from onnx_diagnostic.helpers.helper import flatten_object, string_type, max_diff, string_diff
2525

26-
print("-- import onnx-diagnostic.torch_models")
26+
print("-- import onnx-diagnostic.torch_models.hghub")
2727
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
2828

2929
print("-- done")

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,6 @@
44
import transformers
55
import transformers.cache_utils
66

7-
try:
8-
from transformers.models.mamba.modeling_mamba import MambaCache
9-
except ImportError:
10-
from transformers.cache_utils import MambaCache
11-
127

138
class CacheKeyValue:
149
"""
@@ -354,8 +349,15 @@ def make_encoder_decoder_cache(
354349
)
355350

356351

357-
def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
352+
def make_mamba_cache(
353+
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
354+
) -> "MambaCache": # noqa: F821
358355
"Creates a ``MambaCache``."
356+
# import is moved here because this part is slow.
357+
try:
358+
from transformers.models.mamba.modeling_mamba import MambaCache
359+
except ImportError:
360+
from transformers.cache_utils import MambaCache
359361
dtype = key_value_pairs[0][0].dtype
360362

361363
class _config:

0 commit comments

Comments
 (0)