4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
@@ -16,7 +16,7 @@ jobs:
       matrix:
         os: [ubuntu-latest]
         python: ['3.10', '3.11', '3.12', '3.13']
-        transformers: ['4.48.3', '4.51.3', '4.52.4', 'main']
+        transformers: ['4.48.3', '4.51.3', '4.52.4', '4.53.0', 'main']
         torch: ['2.7', 'main']
         exclude:
           - python: '3.10'
@@ -28,7 +28,7 @@ jobs:
           - python: '3.10'
             transformers: 'main'
           - python: '3.11'
-            transformers: '4.52.4'
+            transformers: '4.53.0'
           - python: '3.11'
             transformers: 'main'
           - python: '3.13'
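The new `4.53.0` entry widens the test matrix, and the `exclude` entries keep older Python versions from being paired with the newest `transformers` releases. The effective job list is the cross product of the three axes minus the exclusions; a minimal sketch of that expansion (version lists copied from the workflow above, the exclusion list abbreviated to the entries visible in this diff):

```python
from itertools import product

# Version axes copied from the workflow matrix above.
pythons = ["3.10", "3.11", "3.12", "3.13"]
transformers = ["4.48.3", "4.51.3", "4.52.4", "4.53.0", "main"]
torchs = ["2.7", "main"]

# Only the exclude entries visible in this diff; the workflow defines more.
excludes = [
    {"python": "3.10", "transformers": "main"},
    {"python": "3.11", "transformers": "4.53.0"},
    {"python": "3.11", "transformers": "main"},
]

def excluded(combo):
    # GitHub Actions drops a combination when all keys of an exclude entry match.
    return any(all(combo[k] == v for k, v in e.items()) for e in excludes)

jobs = []
for p, tr, to in product(pythons, transformers, torchs):
    combo = {"python": p, "transformers": tr, "torch": to}
    if not excluded(combo):
        jobs.append(combo)
print(f"{len(jobs)} jobs in the matrix")
```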
1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.7.2
 +++++
 
+* :pr:`168`, :pr:`169`: introduces patch_diffusers
 * :pr:`166`: improves handling of StaticCache
 * :pr:`165`: support for task text-to-image
 * :pr:`162`: improves graphs rendering for historical data
11 changes: 9 additions & 2 deletions _unittests/ut_torch_models/test_llm_phi2.py
@@ -1,6 +1,11 @@
 import unittest
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    ignore_warnings,
+    requires_transformers,
+    requires_torch,
+)
 from onnx_diagnostic.torch_models.llms import get_phi2
 from onnx_diagnostic.helpers import string_type
 
@@ -13,8 +18,10 @@ def test_get_phi2(self):
         model(**inputs)
 
     @ignore_warnings(UserWarning)
-    @requires_transformers("4.53")
+    @requires_transformers("4.54")
+    @requires_torch("2.9.99")
     def test_export_phi2_1(self):
+        # exporting vmap does not work
         data = get_phi2(num_hidden_layers=2)
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         self.assertEqual(
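The export test is now gated on `transformers>=4.54` and `torch>=2.9.99` (effectively a nightly build, since exporting vmap does not work yet). As an illustration of how such version gates are commonly built — not the actual `ext_test_case` implementation — a decorator of this kind can be sketched as:

```python
import unittest
from packaging.version import Version  # assumes `packaging` is installed

def requires_package(name: str, min_version: str):
    """Skip the decorated test unless `name` is installed at `min_version` or newer."""
    try:
        module = __import__(name)
        skip = Version(module.__version__) < Version(min_version)
    except ImportError:
        skip = True
    return unittest.skipIf(skip, f"{name}>={min_version} required")

# Usage mirroring the decorators in the test above:
# @requires_package("transformers", "4.54")
# @requires_package("torch", "2.9.99")
# def test_export_phi2_1(self): ...
```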
10 changes: 7 additions & 3 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -182,8 +182,8 @@ def torch_export_patches(
         and show a stack trace indicating the exact location of the issue,
         ``if stop_if_static > 1``, more methods are replaced to catch more
         issues
-    :param patch: if False, disable all patches except the registration of
-        serialization function
+    :param patch: if False, disables all patches but keeps the registration of
+        serialization functions if other patch flags are enabled
     :param custom_patches: to apply custom patches,
         every patched class must define static attributes
         ``_PATCHES_``, ``_PATCHED_CLASS_``
@@ -270,7 +270,11 @@ def torch_export_patches(
         pass
     elif not patch:
         fct_callable = lambda x: x  # noqa: E731
-    done = register_cache_serialization(verbose=verbose)
+    done = register_cache_serialization(
+        patch_transformers=patch_transformers,
+        patch_diffusers=patch_diffusers,
+        verbose=verbose,
+    )
     try:
         yield fct_callable
     finally:
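With this change, `register_cache_serialization` only registers serialization functions for the libraries whose patch flag is set, which is what the updated `patch` docstring describes. A usage sketch as a context manager (the yielded callable adapts the inputs, as `validate.py` does below; the tiny model is purely illustrative):

```python
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

# Purely illustrative model and inputs.
model = torch.nn.Linear(4, 4)
inputs = (torch.randn(2, 4),)

# patch=False would skip the patches but, after this change, still register
# serialization functions for the libraries enabled below.
with torch_export_patches(patch_transformers=True, patch_diffusers=True) as adapt:
    exported = torch.export.export(model, adapt(inputs))
```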
15 changes: 11 additions & 4 deletions onnx_diagnostic/torch_models/hghub/hub_api.py
@@ -140,14 +140,18 @@ def _guess_task_from_config(config: Any) -> Optional[str]:
 
 @functools.cache
 def task_from_arch(
-    arch: str, default_value: Optional[str] = None, model_id: Optional[str] = None
+    arch: str,
+    default_value: Optional[str] = None,
+    model_id: Optional[str] = None,
+    subfolder: Optional[str] = None,
 ) -> str:
     """
     This function relies on stored information. That information needs to be refreshed.
 
     :param arch: architecture name
     :param default_value: default value in case the task cannot be determined
     :param model_id: unused unless the architecture does not help.
+    :param subfolder: subfolder inside the model repository, if any
     :return: task
 
     .. runpython::
@@ -162,7 +166,7 @@ def task_from_arch(
     data = load_architecture_task()
     if arch not in data and model_id:
         # Let's try with the model id.
-        return task_from_id(model_id)
+        return task_from_id(model_id, subfolder=subfolder)
     if default_value is not None:
         return data.get(arch, default_value)
     assert arch in data, (
@@ -178,6 +182,7 @@ def task_from_id(
     default_value: Optional[str] = None,
     pretrained: bool = False,
     fall_back_to_pretrained: bool = True,
+    subfolder: Optional[str] = None,
 ) -> str:
     """
     Returns the task attached to a model id.
@@ -187,7 +192,7 @@ def task_from_id(
         if the task cannot be determined
     :param pretrained: uses the config
     :param fall_back_to_pretrained: falls back to pretrained config
-    :param exc: raises an exception if True
+    :param subfolder: subfolder inside the model repository, if any
     :return: task
     """
     if not pretrained:
@@ -196,7 +201,7 @@ def task_from_id(
         except RuntimeError:
             if not fall_back_to_pretrained:
                 raise
-        config = get_pretrained_config(model_id)
+        config = get_pretrained_config(model_id, subfolder=subfolder)
     try:
         return config.pipeline_tag
     except AttributeError:
@@ -206,6 +211,8 @@ def task_from_id(
     data = load_architecture_task()
     if model_id in data:
         return data[model_id]
+    if type(config) is dict and "_class_name" in config:
+        return task_from_arch(config["_class_name"], default_value=default_value)
     if not config.architectures or not config.architectures:
         # Some hardcoded values until a better solution is found.
         if model_id.startswith("google/bert_"):
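The new branch handles diffusers-style configurations, which are plain dictionaries carrying a `_class_name` key rather than an `architectures` attribute. A simplified sketch of the resulting lookup order (illustrative only, not the full implementation):

```python
from typing import Any, Optional

def resolve_task(arch: Optional[str], config: Any, table: dict) -> Optional[str]:
    # 1. The architecture name is in the stored architecture->task table.
    if arch is not None and arch in table:
        return table[arch]
    # 2. Diffusers configs are plain dicts exposing `_class_name`
    #    (e.g. {"_class_name": "CLIPTextModel"}); look that name up instead.
    if type(config) is dict and "_class_name" in config:
        return table.get(config["_class_name"])
    # 3. Transformers configs may expose a `pipeline_tag` attribute.
    return getattr(config, "pipeline_tag", None)

table = {"CLIPTextModel": "feature-extraction"}
print(resolve_task(None, {"_class_name": "CLIPTextModel"}, table))
# feature-extraction
```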
1 change: 1 addition & 0 deletions onnx_diagnostic/torch_models/hghub/hub_data.py
@@ -22,6 +22,7 @@
 BlenderbotModel,feature-extraction
 BloomModel,feature-extraction
 CLIPModel,zero-shot-image-classification
+CLIPTextModel,feature-extraction
 CLIPVisionModel,feature-extraction
 CamembertModel,feature-extraction
 CodeGenModel,feature-extraction
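The table above maps architecture class names to tasks, one `Architecture,task` pair per line; `CLIPTextModel` is the entry this PR needs to resolve the text encoder of a diffusion pipeline. Parsing such a table into a dictionary is a one-liner (a sketch over the rows visible in this diff):

```python
RAW = """\
BlenderbotModel,feature-extraction
BloomModel,feature-extraction
CLIPModel,zero-shot-image-classification
CLIPTextModel,feature-extraction
CLIPVisionModel,feature-extraction
"""

arch_to_task = dict(line.split(",", 1) for line in RAW.strip().splitlines())
assert arch_to_task["CLIPTextModel"] == "feature-extraction"
```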
2 changes: 1 addition & 1 deletion onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -106,7 +106,7 @@ def get_untrained_model_with_inputs(
         print(f"[get_untrained_model_with_inputs] architectures={archs!r}")
         print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
     if task is None:
-        task = task_from_arch(archs[0], model_id=model_id)
+        task = task_from_arch(archs[0], model_id=model_id, subfolder=subfolder)
     if verbose:
         print(f"[get_untrained_model_with_inputs] task={task!r}")
 
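Forwarding `subfolder` lets task detection work for models stored in a repository subfolder, as diffusers pipelines do. A hedged usage sketch (the model id and subfolder are illustrative and assume `get_untrained_model_with_inputs` exposes a `subfolder` argument, as the surrounding code suggests):

```python
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs

# Illustrative: diffusers pipelines keep their CLIP text encoder in a
# `text_encoder` subfolder of the repository.
data = get_untrained_model_with_inputs(
    "hf-internal-testing/tiny-stable-diffusion-torch",  # example id, not from this PR
    subfolder="text_encoder",
)
model, inputs = data["model"], data["inputs"]
```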
35 changes: 27 additions & 8 deletions onnx_diagnostic/torch_models/validate.py
@@ -263,7 +263,7 @@ def validate_model(
     use_pretrained: bool = False,
     optimization: Optional[str] = None,
     quiet: bool = False,
-    patch: bool = False,
+    patch: Union[bool, str, Dict[str, bool]] = False,
     rewrite: bool = False,
     stop_if_static: int = 1,
     dump_folder: Optional[str] = None,
@@ -301,8 +301,10 @@ def validate_model(
     :param optimization: optimization to apply to the exported model,
         depends on the exporter
     :param quiet: if quiet, catches exception if any issue
-    :param patch: applies patches (``patch_transformers=True``) before exporting,
-        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
+    :param patch: if True, applies patches (``patch_transformers=True,
+        patch_diffusers=True``) before exporting, see
+        :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`;
+        a string can be used to enable only a subset of the patches
     :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
         see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
     :param stop_if_static: stops if a dynamic dimension becomes static,
@@ -346,8 +348,24 @@ def validate_model(
         exported model returns the same outputs as the original one, otherwise,
         :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
     """
-    assert not rewrite or patch, (
-        f"rewrite={rewrite}, patch={patch}, patch must be True to enable rewriting, "
+    if isinstance(patch, bool):
+        patch_kwargs = (
+            dict(patch_transformers=True, patch_diffusers=True, patch=True)
+            if patch
+            else dict(patch=False)
+        )
+    elif isinstance(patch, str):
+        patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}}  # noqa: C420
+    else:
+        assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
+        patch_kwargs = patch.copy()
+    if "patch" not in patch_kwargs:
+        if any(patch_kwargs.values()):
+            patch_kwargs["patch"] = True
+
+    assert not rewrite or patch_kwargs.get("patch", False), (
+        f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs}, "
+        f"patch must be True to enable rewriting, "
         f"if --no-patch was specified on the command line, --no-rewrite must be added."
     )
     summary = version_summary()
@@ -362,6 +380,7 @@ def validate_model(
         version_optimization=optimization or "",
         version_quiet=str(quiet),
         version_patch=str(patch),
+        version_patch_kwargs=str(patch_kwargs).replace(" ", ""),
         version_rewrite=str(rewrite),
         version_dump_folder=dump_folder or "",
         version_drop_inputs=str(list(drop_inputs or "")),
@@ -397,7 +416,7 @@ def validate_model(
         print(f"[validate_model] model_options={model_options!r}")
         print(f"[validate_model] get dummy inputs with input_options={input_options}...")
         print(
-            f"[validate_model] rewrite={rewrite}, patch={patch}, "
+            f"[validate_model] rewrite={rewrite}, patch_kwargs={patch_kwargs}, "
             f"stop_if_static={stop_if_static}"
         )
         print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
@@ -573,18 +592,18 @@ def validate_model(
             f"[validate_model] -- export the model with {exporter!r}, "
             f"optimization={optimization!r}"
         )
-    if patch:
+    if patch_kwargs:
         if verbose:
             print(
                 f"[validate_model] applies patches before exporting "
                 f"stop_if_static={stop_if_static}"
             )
         with torch_export_patches(  # type: ignore
-            patch_transformers=True,
             stop_if_static=stop_if_static,
             verbose=max(0, verbose - 1),
             rewrite=data.get("rewrite", None),
             dump_rewriting=(os.path.join(dump_folder, "rewrite") if dump_folder else None),
+            **patch_kwargs,  # type: ignore[arg-type]
         ) as modificator:
             data["inputs_export"] = modificator(data["inputs"])  # type: ignore
 
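The `patch` argument now accepts a bool, a comma-separated string, or a dict of flags, all normalized into `patch_kwargs` before being splatted into `torch_export_patches`. The normalization can be replayed standalone (same logic as the code added above):

```python
from typing import Dict, Union

def normalize_patch(patch: Union[bool, str, Dict[str, bool]]) -> Dict[str, bool]:
    # Mirrors the normalization added to validate_model above.
    if isinstance(patch, bool):
        return (
            dict(patch_transformers=True, patch_diffusers=True, patch=True)
            if patch
            else dict(patch=False)
        )
    if isinstance(patch, str):
        # e.g. "patch_transformers" or "patch_transformers,patch_diffusers"
        return {"patch": True, **{p: True for p in patch.split(",")}}
    assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
    patch_kwargs = patch.copy()
    if "patch" not in patch_kwargs and any(patch_kwargs.values()):
        patch_kwargs["patch"] = True
    return patch_kwargs

print(normalize_patch(True))                       # both libraries patched
print(normalize_patch("patch_transformers"))       # only transformers
print(normalize_patch({"patch_diffusers": True}))  # dict form, `patch` inferred
```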