Skip to content

Commit 1d2f851

Browse files
committed
Reduce dependency on transformers
1 parent fde0173 commit 1d2f851

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def flatten_unflatten_for_dynamic_shapes(
108108

109109
def is_cache_dynamic_registered(fast: bool = False) -> bool:
110110
"""
111-
Tells class :class:`transformers.cache_utils.DynamicCache` can be
111+
Tells if class :class:`transformers.cache_utils.DynamicCache` can be
112112
serialized and deserialized. Only then, :func:`torch.export.export`
113113
can export a model.
114114

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import onnx
44
import torch
55
from .helper import string_type, flatten_object
6-
from .torch_helper import to_numpy
7-
from .cache_helper import is_cache_dynamic_registered
86

97

108
def name_type_to_onnx_dtype(name: str) -> int:
@@ -49,14 +47,16 @@ def make_feeds(
4947
assert (
5048
not check_flatten
5149
or not all(isinstance(obj, torch.Tensor) for obj in flat)
52-
or not is_cache_dynamic_registered(fast=True)
50+
# or not is_cache_dynamic_registered(fast=True)
5351
or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
5452
), (
5553
f"Unexpected number of flattened objects, "
5654
f"{string_type(flat, with_shape=True)} != "
5755
f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
5856
)
5957
if use_numpy:
58+
from .torch_helper import to_numpy
59+
6060
flat = [to_numpy(t) if isinstance(t, torch.Tensor) else t for t in flat]
6161
names = (
6262
[i.name for i in proto.graph.input]

0 commit comments

Comments
 (0)