diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2fcec38e..a6772b1c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,7 @@ jobs: matrix: os: [ubuntu-latest] python: ['3.10', '3.11', '3.12', '3.13'] - transformers: ['4.48.3', '4.51.3', '4.52.4', 'main'] + transformers: ['4.48.3', '4.51.3', '4.52.4', '4.53.0', 'main'] torch: ['2.7', 'main'] exclude: - python: '3.10' @@ -28,7 +28,7 @@ jobs: - python: '3.10' transformers: 'main' - python: '3.11' - transformers: '4.52.4' + transformers: '4.53.0' - python: '3.11' transformers: 'main' - python: '3.13' diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 74be09f0..12192a1b 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.7.2 +++++ +* :pr:`168`, :pr:`169`: introduces patch_diffusers * :pr:`166`: improves handling of StaticCache * :pr:`165`: support for task text-to-image * :pr:`162`: improves graphs rendering for historical data diff --git a/_unittests/ut_torch_models/test_llm_phi2.py b/_unittests/ut_torch_models/test_llm_phi2.py index 090c9b6e..c05007a6 100644 --- a/_unittests/ut_torch_models/test_llm_phi2.py +++ b/_unittests/ut_torch_models/test_llm_phi2.py @@ -1,6 +1,11 @@ import unittest import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + ignore_warnings, + requires_transformers, + requires_torch, +) from onnx_diagnostic.torch_models.llms import get_phi2 from onnx_diagnostic.helpers import string_type @@ -13,8 +18,10 @@ def test_get_phi2(self): model(**inputs) @ignore_warnings(UserWarning) - @requires_transformers("4.53") + @requires_transformers("4.54") + @requires_torch("2.9.99") def test_export_phi2_1(self): + # exporting vmap does not work data = get_phi2(num_hidden_layers=2) model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] self.assertEqual( diff --git 
a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 07617be2..058041e5 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -182,8 +182,8 @@ def torch_export_patches( and show a stack trace indicating the exact location of the issue, ``if stop_if_static > 1``, more methods are replace to catch more issues - :param patch: if False, disable all patches except the registration of - serialization function + :param patch: if False, disable all patches but keeps the registration of + serialization functions if other patch functions are enabled :param custom_patches: to apply custom patches, every patched class must define static attributes ``_PATCHES_``, ``_PATCHED_CLASS_`` @@ -270,7 +270,11 @@ def torch_export_patches( pass elif not patch: fct_callable = lambda x: x # noqa: E731 - done = register_cache_serialization(verbose=verbose) + done = register_cache_serialization( + patch_transformers=patch_transformers, + patch_diffusers=patch_diffusers, + verbose=verbose, + ) try: yield fct_callable finally: diff --git a/onnx_diagnostic/torch_models/hghub/hub_api.py b/onnx_diagnostic/torch_models/hghub/hub_api.py index cb450cb7..da05f7c0 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_api.py +++ b/onnx_diagnostic/torch_models/hghub/hub_api.py @@ -140,7 +140,10 @@ def _guess_task_from_config(config: Any) -> Optional[str]: @functools.cache def task_from_arch( - arch: str, default_value: Optional[str] = None, model_id: Optional[str] = None + arch: str, + default_value: Optional[str] = None, + model_id: Optional[str] = None, + subfolder: Optional[str] = None, ) -> str: """ This function relies on stored information. That information needs to be refresh. 
@@ -148,6 +151,7 @@ def task_from_arch( :param arch: architecture name :param default_value: default value in case the task cannot be determined :param model_id: unused unless the architecture does not help. + :param subfolder: subfolder :return: task .. runpython:: @@ -162,7 +166,7 @@ def task_from_arch( data = load_architecture_task() if arch not in data and model_id: # Let's try with the model id. - return task_from_id(model_id) + return task_from_id(model_id, subfolder=subfolder) if default_value is not None: return data.get(arch, default_value) assert arch in data, ( @@ -178,6 +182,7 @@ def task_from_id( default_value: Optional[str] = None, pretrained: bool = False, fall_back_to_pretrained: bool = True, + subfolder: Optional[str] = None, ) -> str: """ Returns the task attached to a model id. @@ -187,7 +192,7 @@ def task_from_id( if the task cannot be determined :param pretrained: uses the config :param fall_back_to_pretrained: falls back to pretrained config - :param exc: raises an exception if True + :param subfolder: subfolder :return: task """ if not pretrained: @@ -196,7 +201,7 @@ def task_from_id( except RuntimeError: if not fall_back_to_pretrained: raise - config = get_pretrained_config(model_id) + config = get_pretrained_config(model_id, subfolder=subfolder) try: return config.pipeline_tag except AttributeError: @@ -206,6 +211,8 @@ def task_from_id( data = load_architecture_task() if model_id in data: return data[model_id] + if type(config) is dict and "_class_name" in config: + return task_from_arch(config["_class_name"], default_value=default_value) if not config.architectures or not config.architectures: # Some hardcoded values until a better solution is found. 
if model_id.startswith("google/bert_"): diff --git a/onnx_diagnostic/torch_models/hghub/hub_data.py b/onnx_diagnostic/torch_models/hghub/hub_data.py index bdf79693..908cb22f 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data.py @@ -22,6 +22,7 @@ BlenderbotModel,feature-extraction BloomModel,feature-extraction CLIPModel,zero-shot-image-classification + CLIPTextModel,feature-extraction CLIPVisionModel,feature-extraction CamembertModel,feature-extraction CodeGenModel,feature-extraction diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py index 6db64040..1961e049 100644 --- a/onnx_diagnostic/torch_models/hghub/model_inputs.py +++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py @@ -106,7 +106,7 @@ def get_untrained_model_with_inputs( print(f"[get_untrained_model_with_inputs] architectures={archs!r}") print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}") if task is None: - task = task_from_arch(archs[0], model_id=model_id) + task = task_from_arch(archs[0], model_id=model_id, subfolder=subfolder) if verbose: print(f"[get_untrained_model_with_inputs] task={task!r}") diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index 6736359f..693de5c7 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -263,7 +263,7 @@ def validate_model( use_pretrained: bool = False, optimization: Optional[str] = None, quiet: bool = False, - patch: bool = False, + patch: Union[bool, str, Dict[str, bool]] = False, rewrite: bool = False, stop_if_static: int = 1, dump_folder: Optional[str] = None, @@ -301,8 +301,10 @@ def validate_model( :param optimization: optimization to apply to the exported model, depend on the the exporter :param quiet: if quiet, catches exception if any issue - :param patch: applies patches (``patch_transformers=True``) before 
exporting, - see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches` + :param patch: applies patches (``patch_transformers=True, patch_diffusers=True``) + if True before exporting + see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`, + a string can be used to specify only one of them :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting, see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches` :param stop_if_static: stops if a dynamic dimension becomes static, @@ -346,8 +348,24 @@ def validate_model( exported model returns the same outputs as the original one, otherwise, :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used. """ - assert not rewrite or patch, ( - f"rewrite={rewrite}, patch={patch}, patch must be True to enable rewriting, " + if isinstance(patch, bool): + patch_kwargs = ( + dict(patch_transformers=True, patch_diffusers=True, patch=True) + if patch + else dict(patch=False) + ) + elif isinstance(patch, str): + patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}} # noqa: C420 + else: + assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}" + patch_kwargs = patch.copy() + if "patch" not in patch_kwargs: + if any(patch_kwargs.values()): + patch_kwargs["patch"] = True + + assert not rewrite or patch_kwargs.get("patch", False), ( + f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} " + f"patch must be True to enable rewriting, " + f"if --no-patch was specified on the command line, --no-rewrite must be added." 
) summary = version_summary() @@ -362,6 +380,7 @@ def validate_model( version_optimization=optimization or "", version_quiet=str(quiet), version_patch=str(patch), + version_patch_kwargs=str(patch_kwargs).replace(" ", ""), version_rewrite=str(rewrite), version_dump_folder=dump_folder or "", version_drop_inputs=str(list(drop_inputs or "")), @@ -397,7 +416,7 @@ def validate_model( print(f"[validate_model] model_options={model_options!r}") print(f"[validate_model] get dummy inputs with input_options={input_options}...") print( - f"[validate_model] rewrite={rewrite}, patch={patch}, " + f"[validate_model] rewrite={rewrite}, patch_kwargs={patch_kwargs}, " f"stop_if_static={stop_if_static}" ) print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}") @@ -573,18 +592,18 @@ def validate_model( f"[validate_model] -- export the model with {exporter!r}, " f"optimization={optimization!r}" ) - if patch: + if patch_kwargs: if verbose: print( f"[validate_model] applies patches before exporting " f"stop_if_static={stop_if_static}" ) with torch_export_patches( # type: ignore - patch_transformers=True, stop_if_static=stop_if_static, verbose=max(0, verbose - 1), rewrite=data.get("rewrite", None), dump_rewriting=(os.path.join(dump_folder, "rewrite") if dump_folder else None), + **patch_kwargs, # type: ignore[arg-type] ) as modificator: data["inputs_export"] = modificator(data["inputs"]) # type: ignore