Skip to content

Commit b3acf09

Browse files
committed
support for transformers 4.53.0
1 parent 20b2013 commit b3acf09

File tree

5 files changed

+29
-9
lines changed

5 files changed

+29
-9
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
matrix:
1717
os: [ubuntu-latest]
1818
python: ['3.10', '3.11', '3.12', '3.13']
19-
transformers: ['4.48.3', '4.51.3', '4.52.4', 'main']
19+
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.53.0', 'main']
2020
torch: ['2.7', 'main']
2121
exclude:
2222
- python: '3.10'
@@ -28,7 +28,7 @@ jobs:
2828
- python: '3.10'
2929
transformers: 'main'
3030
- python: '3.11'
31-
transformers: '4.52.4'
31+
transformers: '4.53.0'
3232
- python: '3.11'
3333
transformers: 'main'
3434
- python: '3.13'

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.7.2
55
+++++
66

7+
* :pr:`168`, :pr:`169`: introduces patch_diffusers
78
* :pr:`166`: improves handling of StaticCache
89
* :pr:`165`: support for task text-to-image
910
* :pr:`162`: improves graphs rendering for historical data

_unittests/ut_torch_models/test_llm_phi2.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import unittest
22
import torch
3-
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers
3+
from onnx_diagnostic.ext_test_case import (
4+
ExtTestCase,
5+
ignore_warnings,
6+
requires_transformers,
7+
requires_pytorch,
8+
)
49
from onnx_diagnostic.torch_models.llms import get_phi2
510
from onnx_diagnostic.helpers import string_type
611

@@ -13,8 +18,10 @@ def test_get_phi2(self):
1318
model(**inputs)
1419

1520
@ignore_warnings(UserWarning)
16-
@requires_transformers("4.53")
21+
@requires_transformers("4.54")
22+
@requires_pytorch("2.9.99")
1723
def test_export_phi2_1(self):
24+
# exporting vmap does not work
1825
data = get_phi2(num_hidden_layers=2)
1926
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
2027
self.assertEqual(

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ def torch_export_patches(
182182
and show a stack trace indicating the exact location of the issue,
183183
``if stop_if_static > 1``, more methods are replace to catch more
184184
issues
185-
:param patch: if False, disable all patches except the registration of
186-
serialization function
185+
:param patch: if False, disable all patches but keeps the registration of
186+
serialization functions if other patch functions are enabled
187187
:param custom_patches: to apply custom patches,
188188
every patched class must define static attributes
189189
``_PATCHES_``, ``_PATCHED_CLASS_``
@@ -270,7 +270,11 @@ def torch_export_patches(
270270
pass
271271
elif not patch:
272272
fct_callable = lambda x: x # noqa: E731
273-
done = register_cache_serialization(verbose=verbose)
273+
done = register_cache_serialization(
274+
patch_transformers=patch_transformers,
275+
patch_diffusers=patch_diffusers,
276+
verbose=verbose,
277+
)
274278
try:
275279
yield fct_callable
276280
finally:

onnx_diagnostic/torch_models/validate.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,11 +349,19 @@ def validate_model(
349349
:class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
350350
"""
351351
if isinstance(patch, bool):
352-
patch = dict(patch_transformers=True, patch_diffusers=True) if patch else {}
352+
patch = (
353+
dict(patch_transformers=True, patch_diffusers=True, patch=True)
354+
if patch
355+
else dict(patch=False)
356+
)
353357
elif isinstance(patch, str):
354-
patch = {p: True for p in patch.split(",")} # noqa: C420
358+
patch = {"patch": True, **{p: True for p in patch.split(",")}} # noqa: C420
355359
else:
356360
assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
361+
patch = patch.copy()
362+
if "patch" not in patch:
363+
if any(patch.values):
364+
patch["patch"] = True
357365

358366
assert not rewrite or patch, (
359367
f"rewrite={rewrite}, patch={patch}, patch must be True to enable rewriting, "

0 commit comments

Comments
 (0)