
Commit 209be4b

Reduce dependency on transformers (#259)
* Reduce dependency on transformers
* disable patch on idefics
* fix example
* spell
* disable patch
1 parent fde0173 commit 209be4b

File tree

6 files changed: +9 -9 lines changed

_doc/examples/plot_export_hub_codellama.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@
 # It still requires patches to be exportable (control flow).
 # See :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`

-with torch_export_patches(patch_transformers=True) as f:
+with torch_export_patches(patch_torch=False, patch_transformers=True) as f:
     ep = torch.export.export(
         model,
         (),

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 1 addition & 2 deletions
@@ -22,9 +22,8 @@ def test_image_text_to_text_idefics(self):
         self.assertEqual(data["task"], "image-text-to-text")
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**torch_deepcopy(inputs))
-        print("***", self.string_type(data["inputs2"], with_shape=True))
         model(**data["inputs2"])
-        with torch_export_patches(patch_transformers=True, verbose=10):
+        with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )

_unittests/ut_tasks/test_tasks_mask_generation.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def test_mask_generation(self):
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
-        with torch_export_patches(patch_transformers=True, verbose=1):
+        with torch_export_patches(patch_torch=False, patch_transformers=True, verbose=1):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
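
The same change appears in the example above and in both tests: torch_export_patches is now called with patch_torch=False so that only the transformers-level patches are applied while torch.export.export runs. A minimal sketch of that pattern, assuming onnx_diagnostic and transformers are installed; the tiny model and its inputs below are placeholders, not taken from the commit:

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches


    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return x * 2


    model, inputs = TinyModel(), {"x": torch.randn(2, 3)}

    # Only transformers-level patches are enabled; torch-level patches stay off.
    with torch_export_patches(patch_torch=False, patch_transformers=True):
        ep = torch.export.export(model, (), kwargs=inputs, strict=False)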

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ def flatten_unflatten_for_dynamic_shapes(

 def is_cache_dynamic_registered(fast: bool = False) -> bool:
     """
-    Tells class :class:`transformers.cache_utils.DynamicCache` can be
+    Tells if class :class:`transformers.cache_utils.DynamicCache` can be
     serialized and deserialized. Only then, :func:`torch.export.export`
     can export a model.
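
The docstring fix is wording only; the helper still reports whether transformers' DynamicCache is registered for serialization so that torch.export.export can handle it. A small usage sketch; the guard below is illustrative and not part of the commit:

    from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered

    # Illustrative check before attempting an export that carries a DynamicCache.
    if is_cache_dynamic_registered():
        print("DynamicCache can be serialized, export may proceed")
    else:
        print("DynamicCache is not registered for serialization")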

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 3 additions & 3 deletions
@@ -3,8 +3,6 @@
 import onnx
 import torch
 from .helper import string_type, flatten_object
-from .torch_helper import to_numpy
-from .cache_helper import is_cache_dynamic_registered


 def name_type_to_onnx_dtype(name: str) -> int:
@@ -49,14 +47,16 @@ def make_feeds(
     assert (
         not check_flatten
         or not all(isinstance(obj, torch.Tensor) for obj in flat)
-        or not is_cache_dynamic_registered(fast=True)
+        # or not is_cache_dynamic_registered(fast=True)
         or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
     ), (
         f"Unexpected number of flattened objects, "
         f"{string_type(flat, with_shape=True)} != "
         f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
     )
     if use_numpy:
+        from .torch_helper import to_numpy
+
         flat = [to_numpy(t) if isinstance(t, torch.Tensor) else t for t in flat]
     names = (
         [i.name for i in proto.graph.input]
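
Dropping the module-level imports of to_numpy and is_cache_dynamic_registered means rt_helper no longer loads those helpers at import time; to_numpy is now imported only inside the use_numpy branch of make_feeds, which is presumably part of reducing the import-time dependency on transformers named in the commit title. A generic sketch of that deferred-import pattern, with hypothetical names that are not from the commit:

    def to_feeds(values, use_numpy=False):
        # Deferred import: the optional dependency is loaded only when this
        # branch actually runs, so importing the module stays cheap.
        if use_numpy:
            import numpy as np

            return [np.asarray(v) for v in values]
        return list(values)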

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 2 additions & 1 deletion
@@ -1380,7 +1380,8 @@ def forward(self, x, seq_len=None):

 def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
     t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
-    freqs = torch.einsum("i,j->ij", t, inv_freq)
+    # freqs = torch.einsum("i,j->ij", t, inv_freq)
+    freqs = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
     emb = torch.cat((freqs, freqs), dim=-1)
     return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
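
The patched rotary-embedding helper now avoids torch.einsum: torch.einsum("i,j->ij", t, inv_freq) is the outer product of two 1-D tensors, and the reshape-and-broadcast form computes the same matrix. A quick standalone check of that equivalence:

    import torch

    t = torch.arange(6, dtype=torch.float32)
    inv_freq = torch.rand(4)

    by_einsum = torch.einsum("i,j->ij", t, inv_freq)
    by_broadcast = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))

    # Both are the (len(t), len(inv_freq)) outer product t[i] * inv_freq[j].
    assert torch.allclose(by_einsum, by_broadcast)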

0 commit comments
