File tree Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -108,7 +108,7 @@ def flatten_unflatten_for_dynamic_shapes(
108108
109109def is_cache_dynamic_registered (fast : bool = False ) -> bool :
110110 """
111- Tells class :class:`transformers.cache_utils.DynamicCache` can be
111+ Tells if class :class:`transformers.cache_utils.DynamicCache` can be
112112 serialized and deserialized. Only then, :func:`torch.export.export`
113113 can export a model.
114114
Original file line number Diff line number Diff line change 33import onnx
44import torch
55from .helper import string_type , flatten_object
6- from .torch_helper import to_numpy
7- from .cache_helper import is_cache_dynamic_registered
86
97
108def name_type_to_onnx_dtype (name : str ) -> int :
@@ -49,14 +47,16 @@ def make_feeds(
4947 assert (
5048 not check_flatten
5149 or not all (isinstance (obj , torch .Tensor ) for obj in flat )
52- or not is_cache_dynamic_registered (fast = True )
50+ # or not is_cache_dynamic_registered(fast=True)
5351 or len (flat ) == len (torch .utils ._pytree .tree_flatten (inputs )[0 ])
5452 ), (
5553 f"Unexpected number of flattened objects, "
5654 f"{ string_type (flat , with_shape = True )} != "
5755 f"{ string_type (torch .utils ._pytree .tree_flatten (inputs )[0 ], with_shape = True )} "
5856 )
5957 if use_numpy :
58+ from .torch_helper import to_numpy
59+
6060 flat = [to_numpy (t ) if isinstance (t , torch .Tensor ) else t for t in flat ]
6161 names = (
6262 [i .name for i in proto .graph .input ]
You can’t perform that action at this time.
0 commit comments