Skip to content

Commit 73f5f8f

Browse files
authored
include patch_diffusers=True when using --patch (#169)
* include patch_diffusers=True when using --patch * support for transformers 4.53.0 * fix misspelling * patch
1 parent 4dd2258 commit 73f5f8f

File tree

8 files changed

+59
-20
lines changed

8 files changed

+59
-20
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
matrix:
1717
os: [ubuntu-latest]
1818
python: ['3.10', '3.11', '3.12', '3.13']
19-
transformers: ['4.48.3', '4.51.3', '4.52.4', 'main']
19+
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.53.0', 'main']
2020
torch: ['2.7', 'main']
2121
exclude:
2222
- python: '3.10'
@@ -28,7 +28,7 @@ jobs:
2828
- python: '3.10'
2929
transformers: 'main'
3030
- python: '3.11'
31-
transformers: '4.52.4'
31+
transformers: '4.53.0'
3232
- python: '3.11'
3333
transformers: 'main'
3434
- python: '3.13'

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.7.2
55
+++++
66

7+
* :pr:`168`, :pr:`169`: introduces patch_diffusers
78
* :pr:`166`: improves handling of StaticCache
89
* :pr:`165`: support for task text-to-image
910
* :pr:`162`: improves graphs rendering for historical data

_unittests/ut_torch_models/test_llm_phi2.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import unittest
22
import torch
3-
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers
3+
from onnx_diagnostic.ext_test_case import (
4+
ExtTestCase,
5+
ignore_warnings,
6+
requires_transformers,
7+
requires_torch,
8+
)
49
from onnx_diagnostic.torch_models.llms import get_phi2
510
from onnx_diagnostic.helpers import string_type
611

@@ -13,8 +18,10 @@ def test_get_phi2(self):
1318
model(**inputs)
1419

1520
@ignore_warnings(UserWarning)
16-
@requires_transformers("4.53")
21+
@requires_transformers("4.54")
22+
@requires_torch("2.9.99")
1723
def test_export_phi2_1(self):
24+
# exporting vmap does not work
1825
data = get_phi2(num_hidden_layers=2)
1926
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
2027
self.assertEqual(

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ def torch_export_patches(
182182
and show a stack trace indicating the exact location of the issue,
183183
``if stop_if_static > 1``, more methods are replaced to catch more
184184
issues
185-
:param patch: if False, disable all patches except the registration of
186-
serialization function
185+
:param patch: if False, disables all patches but keeps the registration of
186+
serialization functions if other patch functions are enabled
187187
:param custom_patches: to apply custom patches,
188188
every patched class must define static attributes
189189
``_PATCHES_``, ``_PATCHED_CLASS_``
@@ -270,7 +270,11 @@ def torch_export_patches(
270270
pass
271271
elif not patch:
272272
fct_callable = lambda x: x # noqa: E731
273-
done = register_cache_serialization(verbose=verbose)
273+
done = register_cache_serialization(
274+
patch_transformers=patch_transformers,
275+
patch_diffusers=patch_diffusers,
276+
verbose=verbose,
277+
)
274278
try:
275279
yield fct_callable
276280
finally:

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,18 @@ def _guess_task_from_config(config: Any) -> Optional[str]:
140140

141141
@functools.cache
142142
def task_from_arch(
143-
arch: str, default_value: Optional[str] = None, model_id: Optional[str] = None
143+
arch: str,
144+
default_value: Optional[str] = None,
145+
model_id: Optional[str] = None,
146+
subfolder: Optional[str] = None,
144147
) -> str:
145148
"""
146149
This function relies on stored information. That information needs to be refreshed.
147150
148151
:param arch: architecture name
149152
:param default_value: default value in case the task cannot be determined
150153
:param model_id: unused unless the architecture does not help.
154+
:param subfolder: subfolder
151155
:return: task
152156
153157
.. runpython::
@@ -162,7 +166,7 @@ def task_from_arch(
162166
data = load_architecture_task()
163167
if arch not in data and model_id:
164168
# Let's try with the model id.
165-
return task_from_id(model_id)
169+
return task_from_id(model_id, subfolder=subfolder)
166170
if default_value is not None:
167171
return data.get(arch, default_value)
168172
assert arch in data, (
@@ -178,6 +182,7 @@ def task_from_id(
178182
default_value: Optional[str] = None,
179183
pretrained: bool = False,
180184
fall_back_to_pretrained: bool = True,
185+
subfolder: Optional[str] = None,
181186
) -> str:
182187
"""
183188
Returns the task attached to a model id.
@@ -187,7 +192,7 @@ def task_from_id(
187192
if the task cannot be determined
188193
:param pretrained: uses the config
189194
:param fall_back_to_pretrained: falls back to pretrained config
190-
:param exc: raises an exception if True
195+
:param subfolder: subfolder
191196
:return: task
192197
"""
193198
if not pretrained:
@@ -196,7 +201,7 @@ def task_from_id(
196201
except RuntimeError:
197202
if not fall_back_to_pretrained:
198203
raise
199-
config = get_pretrained_config(model_id)
204+
config = get_pretrained_config(model_id, subfolder=subfolder)
200205
try:
201206
return config.pipeline_tag
202207
except AttributeError:
@@ -206,6 +211,8 @@ def task_from_id(
206211
data = load_architecture_task()
207212
if model_id in data:
208213
return data[model_id]
214+
if type(config) is dict and "_class_name" in config:
215+
return task_from_arch(config["_class_name"], default_value=default_value)
209216
if not config.architectures or not config.architectures:
210217
# Some hardcoded values until a better solution is found.
211218
if model_id.startswith("google/bert_"):

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
BlenderbotModel,feature-extraction
2323
BloomModel,feature-extraction
2424
CLIPModel,zero-shot-image-classification
25+
CLIPTextModel,feature-extraction
2526
CLIPVisionModel,feature-extraction
2627
CamembertModel,feature-extraction
2728
CodeGenModel,feature-extraction

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def get_untrained_model_with_inputs(
106106
print(f"[get_untrained_model_with_inputs] architectures={archs!r}")
107107
print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
108108
if task is None:
109-
task = task_from_arch(archs[0], model_id=model_id)
109+
task = task_from_arch(archs[0], model_id=model_id, subfolder=subfolder)
110110
if verbose:
111111
print(f"[get_untrained_model_with_inputs] task={task!r}")
112112

onnx_diagnostic/torch_models/validate.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def validate_model(
263263
use_pretrained: bool = False,
264264
optimization: Optional[str] = None,
265265
quiet: bool = False,
266-
patch: bool = False,
266+
patch: Union[bool, str, Dict[str, bool]] = False,
267267
rewrite: bool = False,
268268
stop_if_static: int = 1,
269269
dump_folder: Optional[str] = None,
@@ -301,8 +301,10 @@ def validate_model(
301301
:param optimization: optimization to apply to the exported model,
302302
depends on the exporter
303303
:param quiet: if quiet, catches exception if any issue
304-
:param patch: applies patches (``patch_transformers=True``) before exporting,
305-
see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
304+
:param patch: applies patches (``patch_transformers=True, patch_diffusers=True``)
305+
if True before exporting
306+
see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
307+
a string can be used to specify only one of them
306308
:param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
307309
see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
308310
:param stop_if_static: stops if a dynamic dimension becomes static,
@@ -346,8 +348,24 @@ def validate_model(
346348
exported model returns the same outputs as the original one, otherwise,
347349
:class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
348350
"""
349-
assert not rewrite or patch, (
350-
f"rewrite={rewrite}, patch={patch}, patch must be True to enable rewriting, "
351+
if isinstance(patch, bool):
352+
patch_kwargs = (
353+
dict(patch_transformers=True, patch_diffusers=True, patch=True)
354+
if patch
355+
else dict(patch=False)
356+
)
357+
elif isinstance(patch, str):
358+
patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}} # noqa: C420
359+
else:
360+
assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
361+
patch_kwargs = patch.copy()
362+
if "patch" not in patch_kwargs:
363+
if any(patch_kwargs.values()):
364+
patch_kwargs["patch"] = True
365+
366+
assert not rewrite or patch_kwargs.get("patch", False), (
367+
f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
368+
f"patch must be True to enable rewriting, "
351369
f"if --no-patch was specified on the command line, --no-rewrite must be added."
352370
)
353371
summary = version_summary()
@@ -362,6 +380,7 @@ def validate_model(
362380
version_optimization=optimization or "",
363381
version_quiet=str(quiet),
364382
version_patch=str(patch),
383+
version_patch_kwargs=str(patch_kwargs).replace(" ", ""),
365384
version_rewrite=str(rewrite),
366385
version_dump_folder=dump_folder or "",
367386
version_drop_inputs=str(list(drop_inputs or "")),
@@ -397,7 +416,7 @@ def validate_model(
397416
print(f"[validate_model] model_options={model_options!r}")
398417
print(f"[validate_model] get dummy inputs with input_options={input_options}...")
399418
print(
400-
f"[validate_model] rewrite={rewrite}, patch={patch}, "
419+
f"[validate_model] rewrite={rewrite}, patch_kwargs={patch_kwargs}, "
401420
f"stop_if_static={stop_if_static}"
402421
)
403422
print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
@@ -573,18 +592,18 @@ def validate_model(
573592
f"[validate_model] -- export the model with {exporter!r}, "
574593
f"optimization={optimization!r}"
575594
)
576-
if patch:
595+
if patch_kwargs:
577596
if verbose:
578597
print(
579598
f"[validate_model] applies patches before exporting "
580599
f"stop_if_static={stop_if_static}"
581600
)
582601
with torch_export_patches( # type: ignore
583-
patch_transformers=True,
584602
stop_if_static=stop_if_static,
585603
verbose=max(0, verbose - 1),
586604
rewrite=data.get("rewrite", None),
587605
dump_rewriting=(os.path.join(dump_folder, "rewrite") if dump_folder else None),
606+
**patch_kwargs, # type: ignore[arg-type]
588607
) as modificator:
589608
data["inputs_export"] = modificator(data["inputs"]) # type: ignore
590609

0 commit comments

Comments
 (0)