diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 2220e036..abedf352 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.4.0
 +++++
 
+* :pr:`52`: add support for zero-shot-image-classification
 * :pr:`50`: add support for onnxruntime fusion
 * :pr:`48`: add support for EncoderDecoderCache, test with openai/whisper-tiny
 * :pr:`45`: improve change_dynamic_dimension to fix some dimensions
diff --git a/_unittests/ut_torch_models/test_hghub_model.py b/_unittests/ut_torch_models/test_hghub_model.py
index e0b40310..700948b4 100644
--- a/_unittests/ut_torch_models/test_hghub_model.py
+++ b/_unittests/ut_torch_models/test_hghub_model.py
@@ -96,6 +96,17 @@ def test_get_untrained_model_with_inputs_codellama(self):
         # different expected value for different version of transformers
         self.assertIn((data["size"], data["n_weights"]), [(410532864, 102633216)])
 
+    @hide_stdout()
+    @ignore_errors(OSError)
+    def test_get_untrained_model_with_inputs_clip_vit(self):
+        mid = "openai/clip-vit-base-patch16"
+        data = get_untrained_model_with_inputs(mid, verbose=1)
+        model, inputs = data["model"], data["inputs"]
+        with bypass_export_some_errors(patch_transformers=True):
+            model(**inputs)
+        # different expected value for different version of transformers
+        self.assertIn((data["size"], data["n_weights"]), [(188872708, 47218177)])
+
     @hide_stdout()
     def test_get_untrained_model_with_inputs_text2text_generation(self):
         mid = "sshleifer/tiny-marian-en-de"
diff --git a/_unittests/ut_torch_models/try_tasks.py b/_unittests/ut_torch_models/try_tasks.py
index 05cdb04c..d0fdf22e 100644
--- a/_unittests/ut_torch_models/try_tasks.py
+++ b/_unittests/ut_torch_models/try_tasks.py
@@ -25,6 +25,53 @@ def test_image_classification(self):
         outputs = model(**inputs)
         print("-- outputs", string_type(outputs, with_shape=True, with_min_max=True))
 
+    @never_test()
+    def test_image_classification_resnet(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_torch_models/try_tasks.py -k resnet
+
+        from transformers import ViTImageProcessor, ViTModel
+        from PIL import Image
+        import requests
+
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        processor = ViTImageProcessor.from_pretrained("microsoft/resnet-50")
+        model = ViTModel.from_pretrained("microsoft/resnet-50")
+        inputs = processor(images=image, return_tensors="pt")
+        print()
+        print("-- inputs", string_type(inputs, with_shape=True, with_min_max=True))
+
+        outputs = model(**inputs)
+        print("-- outputs", string_type(outputs, with_shape=True, with_min_max=True))
+
+    @never_test()
+    def test_zero_shot_image_classification(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_torch_models/try_tasks.py -k zero
+        from PIL import Image
+        import requests
+        from transformers import CLIPProcessor, CLIPModel
+
+        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
+        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+        inputs = processor(
+            text=["a photo of a cat", "a photo of a dog"],
+            images=[image, image],
+            return_tensors="pt",
+            padding=True,
+        )
+        print()
+        print("-- inputs", string_type(inputs, with_shape=True, with_min_max=True))
+        outputs = model(**inputs)
+        print("-- outputs", string_type(outputs, with_shape=True, with_min_max=True))
+        logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
+        probs = logits_per_image.softmax(
+            dim=1
+        )  # we can take the softmax to get the label probabilities
+        assert probs is not None
+
     @never_test()
     def test_text2text_generation(self):
         # clear&&NEVERTEST=1 python _unittests/ut_torch_models/try_tasks.py -k text2t
diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
index 0ae199f6..e45b211a 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1,5 +1,4 @@
 import inspect
-import sys
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple
 import torch
@@ -44,56 +43,47 @@ def _patch_make_causal_mask(
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
 
-if sys.version_info[:2] <= (3, 11):
-
-    @dataclass
-    class patched_AttentionMaskConverter:
-        """
-        Patches
-        ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
-        """
-
-        _PATCHES_ = ["_make_causal_mask"]
-        _PATCHED_CLASS_ = AttentionMaskConverter
-
-        @staticmethod
-        def _make_causal_mask(
-            input_ids_shape: torch.Size,
-            dtype: torch.dtype,
-            device: torch.device,
-            past_key_values_length: int = 0,
-            sliding_window: Optional[int] = None,
-        ):
-            """Patched method."""
-            return _patch_make_causal_mask(
-                input_ids_shape, dtype, device, past_key_values_length, sliding_window
-            )
+@dataclass
+class patched_AttentionMaskConverter:
+    """
+    Patches
+    ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
+    """
 
-else:
+    _PATCHES_ = ["_make_causal_mask"]
+    _PATCHED_CLASS_ = AttentionMaskConverter
 
-    @dataclass
-    class patched_AttentionMaskConverter:
-        """
-        Patches
-        ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
+    @staticmethod
+    def _make_causal_mask(
+        *args,
+        **kwargs,
+        # input_ids_shape: torch.Size,
+        # dtype: torch.dtype,
+        # device: torch.device,
+        # past_key_values_length: int = 0,
+        # sliding_window: Optional[int] = None,
+    ):
         """
+        Patched method.
 
-        _PATCHES_ = ["_make_causal_mask"]
-        _PATCHED_CLASS_ = AttentionMaskConverter
-
-        @staticmethod
-        def _make_causal_mask(
-            self,
-            input_ids_shape: torch.Size,
-            dtype: torch.dtype,
-            device: torch.device,
-            past_key_values_length: int = 0,
-            sliding_window: Optional[int] = None,
-        ):
-            """Patched method."""
-            return _patch_make_causal_mask(
-                input_ids_shape, dtype, device, past_key_values_length, sliding_window
-            )
+        This static method may be called with ``AttentionMaskConverter._make_causal_mask``
+        or ``self._make_causal_mask``. That changes the arguments it receives.
+        That should not matter but...
+        """
+        if args:
+            index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
+            names = [
+                "input_ids_shape",
+                "dtype",
+                "device",
+                "past_key_values_length",
+                "sliding_window",
+            ]
+            for i, a in enumerate(args):
+                if i < index:
+                    continue
+                kwargs[names[i - index]] = a
+        return _patch_make_causal_mask(**kwargs)
 
 
 class patched_DynamicCache:
diff --git a/onnx_diagnostic/torch_models/hghub/hub_api.py b/onnx_diagnostic/torch_models/hghub/hub_api.py
index 9f7de9c6..a70c4e30 100644
--- a/onnx_diagnostic/torch_models/hghub/hub_api.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_api.py
@@ -4,7 +4,20 @@
 import transformers
 from huggingface_hub import HfApi, model_info
 from . import hub_data_cached_configs
-from .hub_data import __date__, __data_tasks__, load_architecture_task
+from .hub_data import __date__, __data_tasks__, load_architecture_task, __data_arch_values__
+
+
+@functools.cache
+def get_architecture_default_values(architecture: str):
+    """
+    The configuration may be missing information needed to build the dummy inputs.
+    This function returns the missing pieces.
+    """
+    assert architecture in __data_arch_values__, (
+        f"No known default values for {architecture!r}, "
+        f"expecting one architecture in {', '.join(sorted(__data_arch_values__))}"
+    )
+    return __data_arch_values__[architecture]
 
 
 @functools.cache
diff --git a/onnx_diagnostic/torch_models/hghub/hub_data.py b/onnx_diagnostic/torch_models/hghub/hub_data.py
index af141b6f..6acda29d 100644
--- a/onnx_diagnostic/torch_models/hghub/hub_data.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_data.py
@@ -5,6 +5,8 @@
 
 __date__ = "2025-03-26"
 
+__data_arch_values__ = {"ResNetForImageClassification": dict(image_size=224)}
+
 __data_arch__ = textwrap.dedent(
     """
     architecture,task
@@ -127,25 +129,25 @@
 )
 
 __data_tasks__ = [
+    "audio-classification",
     "automatic-speech-recognition",
-    "image-text-to-text",
-    "image-to-text",
-    "text-generation",
-    "object-detection",
     "document-question-answering",
     "feature-extraction",
-    "text-to-audio",
-    "zero-shot-image-classification",
+    "fill-mask",
+    "image-classification",
+    "image-feature-extraction",
     "image-segmentation",
-    "reinforcement-learning",
+    "image-text-to-text",
+    "image-to-text",
+    "keypoint-detection",
+    "mask-generation",
     "no-pipeline-tag",
-    "image-classification",
+    "object-detection",
+    "reinforcement-learning",
+    "text-generation",
+    "text-to-audio",
     "text2text-generation",
-    "mask-generation",
-    "keypoint-detection",
-    "audio-classification",
-    "image-feature-extraction",
-    "fill-mask",
+    "zero-shot-image-classification",
 ]
 
 __models_testing__ = """
diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
index 85d6cb56..3949c1c3 100644
--- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
@@ -3389,3 +3389,53 @@ def _ccached_openai_whisper_tiny():
             "vocab_size": 51865,
         }
     )
+
+
+def _ccached_openai_clip_vit_base_patch16():
+    "openai/clip-vit-base-patch16"
+    return transformers.CLIPConfig(
+        **{
+            "architectures": ["CLIPModel"],
+            "initializer_factor": 1.0,
+            "logit_scale_init_value": 2.6592,
+            "model_type": "clip",
+            "projection_dim": 512,
+            "text_config": {
+                "attention_dropout": 0.0,
+                "bos_token_id": 0,
+                "dropout": 0.0,
+                "eos_token_id": 2,
+                "hidden_act": "quick_gelu",
+                "hidden_size": 512,
+                "initializer_factor": 1.0,
+                "initializer_range": 0.02,
+                "intermediate_size": 2048,
+                "layer_norm_eps": 1e-05,
+                "max_position_embeddings": 77,
+                "model_type": "clip_text_model",
+                "num_attention_heads": 8,
+                "num_hidden_layers": 12,
+                "projection_dim": 512,
+                "vocab_size": 49408,
+            },
+            "torch_dtype": "float32",
+            "transformers_version": "4.52.0.dev0",
+            "vision_config": {
+                "attention_dropout": 0.0,
+                "dropout": 0.0,
+                "hidden_act": "quick_gelu",
+                "hidden_size": 768,
+                "image_size": 224,
+                "initializer_factor": 1.0,
+                "initializer_range": 0.02,
+                "intermediate_size": 3072,
+                "layer_norm_eps": 1e-05,
+                "model_type": "clip_vision_model",
+                "num_attention_heads": 12,
+                "num_channels": 3,
+                "num_hidden_layers": 12,
+                "patch_size": 16,
+                "projection_dim": 512,
+            },
+        }
+    )
diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py
index 4aa143d8..969b4e95 100644
--- a/onnx_diagnostic/torch_models/hghub/model_inputs.py
+++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -6,7 +6,7 @@
 import torch
 import transformers
 from ...helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
-from .hub_api import task_from_arch, get_pretrained_config
+from .hub_api import task_from_arch, get_pretrained_config, get_architecture_default_values
 
 
 @functools.cache
@@ -87,6 +87,20 @@ def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
                 else len(config.hidden_sizes)
             )
         )
+    elif task == "zero-shot-image-classification":
+        check_hasattr(config, "vision_config", "text_config")
+        check_hasattr(config.vision_config, "num_hidden_layers", "num_attention_heads")
+        check_hasattr(config.text_config, "num_hidden_layers", "num_attention_heads")
+        kwargs = dict(
+            vision_config=dict(
+                num_hidden_layers=min(2, config.vision_config.num_hidden_layers),
+                num_attention_heads=min(2, config.vision_config.num_attention_heads),
+            ),
+            text_config=dict(
+                num_hidden_layers=min(2, config.text_config.num_hidden_layers),
+                num_attention_heads=min(2, config.text_config.num_attention_heads),
+            ),
+        )
     elif task == "text2text-generation":
         kwargs = {}
         if hasattr(config, "num_decoder_layers"):
@@ -114,11 +128,22 @@ def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
     else:
         raise NotImplementedError(f"Input generation for task {task!r} not implemented yet.")
 
-    for k, v in kwargs.items():
-        setattr(config, k, v)
+    update_config(config, kwargs)
     return kwargs
 
 
+def update_config(config: Any, mkwargs: Dict[str, Any]):
+    """Updates a configuration with different values."""
+    for k, v in mkwargs.items():
+        if isinstance(v, dict):
+            assert hasattr(
+                config, k
+            ), f"missing attribute {k!r} in config={config}, cannot update it with {v}"
+            update_config(getattr(config, k), v)
+        else:
+            setattr(config, k, v)
+
+
 def check_hasattr(config: Any, *args: Union[str, Tuple[Any, ...]]):
     """
     Checks the confiugation has all the attributes in ``args``.
@@ -127,7 +152,9 @@ def check_hasattr(config: Any, *args: Union[str, Tuple[Any, ...]]):
     for a in args:
         assert isinstance(a, (str, tuple)), f"unexpected type {type(a)} in {args!r}"
         if isinstance(a, str):
-            assert hasattr(config, a), f"Missing attribute {a!r} in\n{config}"
+            assert (isinstance(config, dict) and a in config) or hasattr(
+                config, a
+            ), f"Missing attribute {a!r} in\n{config}"
         elif isinstance(a, tuple):
             assert any(
                 (isinstance(name, str) and hasattr(config, name))
@@ -228,12 +255,19 @@ def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callabl
         fct = get_inputs_for_text2text_generation  # type: ignore
     elif task == "image-classification":
         if config is not None:
-            check_hasattr(config, "image_size", "num_channels")
-        if config is None or isinstance(config.image_size, int):
+            check_hasattr(config, ("image_size", "architectures"), "num_channels")
+        if config is not None:
+            if hasattr(config, "image_size"):
+                image_size = config.image_size
+            else:
+                assert config.architectures, f"empty architecture in {config}"
+                default_values = get_architecture_default_values(config.architectures[0])
+                image_size = default_values["image_size"]
+        if config is None or isinstance(image_size, int):
             kwargs = dict(
                 batch_size=2,
-                input_width=224 if config is None else config.image_size,
-                input_height=224 if config is None else config.image_size,
+                input_width=224 if config is None else image_size,
+                input_height=224 if config is None else image_size,
                 input_channels=3 if config is None else config.num_channels,
             )
         else:
@@ -244,6 +278,23 @@ def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callabl
                 input_channels=config.num_channels,
            )
         fct = get_inputs_for_image_classification  # type: ignore
+    elif task == "zero-shot-image-classification":
+        if config is not None:
+            check_hasattr(config, "vision_config", "text_config")
+            check_hasattr(config.vision_config, "image_size", "num_channels")
+            check_hasattr(config.text_config, "vocab_size")
+        kwargs = dict(
+            batch_size=2,
+            batch_size_image=3,
+            sequence_length=30,
+            dummy_max_token_id=(
+                49408 if config is None else (config.text_config.vocab_size - 1)
+            ),
+            input_width=224 if config is None else config.vision_config.image_size,
+            input_height=224 if config is None else config.vision_config.image_size,
+            input_channels=3 if config is None else config.vision_config.num_channels,
+        )
+        fct = get_inputs_for_zero_shot_image_classification  # type: ignore
     elif task == "image-text-to-text":
         if config is not None:
             check_hasattr(
@@ -313,7 +364,6 @@ def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callabl
         fct = get_inputs_for_speech_automatic_recognition  # type: ignore
     else:
         raise NotImplementedError(f"Input generation for task {task!r} not implemented yet.")
-
     return kwargs, fct
 
 
@@ -391,14 +441,20 @@ def get_untrained_model_with_inputs(
     )
 
     # updating the configuration
-    if not same_as_pretrained:
-        mkwargs = reduce_model_config(config, task)
-    else:
-        mkwargs = {}
+
+    mkwargs = reduce_model_config(config, task) if not same_as_pretrained else {}
     if model_kwargs:
         for k, v in model_kwargs.items():
-            setattr(config, k, v)
-            mkwargs[k] = v
+            if isinstance(v, dict):
+                if k in mkwargs:
+                    mkwargs[k].update(v)
+                else:
+                    mkwargs[k] = v
+            else:
+                mkwargs[k] = v
+    if mkwargs:
+        update_config(config, mkwargs)
+
     # input kwargs
     kwargs, fct = random_input_kwargs(config, task)
     if inputs_kwargs:
@@ -463,7 +519,7 @@ def get_inputs_for_image_classification(
     :param model: model to get the missing information
     :param config: configuration used to generate the model
     :param batch_size: batch size
-    :param input_channel: input channel
+    :param input_channels: number of input channels
     :param input_width: input width
     :param input_height: input height
     :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
@@ -491,6 +547,67 @@ def get_inputs_for_image_classification(
     return dict(inputs=inputs, dynamic_shapes=shapes)
 
 
+def get_inputs_for_zero_shot_image_classification(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    batch_size: int = 2,
+    sequence_length: int = 30,
+    input_width: int = 224,
+    input_height: int = 224,
+    input_channels: int = 3,
+    batch_size_image=3,
+    **kwargs,
+):
+    """
+    Generates inputs for task ``zero-shot-image-classification``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param dummy_max_token_id: vocabulary size
+    :param batch_size: batch size
+    :param sequence_length: sequence length
+    :param batch_size_image: number of images
+    :param input_channels: number of input channels
+    :param input_width: input width
+    :param input_height: input height
+    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
+    :return: dictionary
+
+    # input_ids:T7s2x7
+    # attention_mask:T7s2x7
+    # pixel_values:T1s2x3x224x224
+    """
+    assert isinstance(
+        input_width, int
+    ), f"Unexpected type for input_width {type(input_width)}{config}"
+    assert isinstance(
+        input_height, int
+    ), f"Unexpected type for input_height {type(input_height)}{config}"
+
+    batch = torch.export.Dim("batch", min=1, max=1024)
+    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "attention_mask": {0: batch, 1: seq_length},
+        "pixel_values": {
+            0: torch.export.Dim("batch_img", min=1, max=1024),
+            # 2: torch.export.Dim("width", min=1, max=4096),
+            # 3: torch.export.Dim("height", min=1, max=4096),
+        },
+    }
+    inputs = dict(
+        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
+            torch.int64
+        ),
+        attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
+        pixel_values=torch.randn(
+            batch_size_image, input_channels, input_width, input_height
+        ).clamp(-1, 1),
+    )
+    return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
 def get_inputs_for_text_generation(
     model: torch.nn.Module,
     config: Optional[Any],
@@ -885,4 +1002,5 @@ def get_get_inputs_function_for_tasks() -> Dict[str, Callable]:
         "image-text-to-text": get_inputs_for_image_text_to_text,
         "text-generation": get_inputs_for_text_generation,
        "text2text-generation": get_inputs_for_text2text_generation,
+        "zero-shot-image-classification": get_inputs_for_zero_shot_image_classification,
     }
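Usage sketch (not part of the patch): the snippet below mirrors the new unit test test_get_untrained_model_with_inputs_clip_vit and shows how the zero-shot-image-classification support added in this PR is exercised end to end. The import locations of get_untrained_model_with_inputs and bypass_export_some_errors are assumptions inferred from this diff and may differ in the actual package layout.

    # Minimal sketch; the import paths below are assumed, not confirmed by this diff.
    from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
    from onnx_diagnostic.torch_export_patches import bypass_export_some_errors  # assumed path

    # Build a reduced, untrained CLIP model together with dummy inputs for the
    # zero-shot-image-classification task (text branch + image branch).
    data = get_untrained_model_with_inputs("openai/clip-vit-base-patch16", verbose=1)
    model, inputs = data["model"], data["inputs"]

    # Run the model once with the transformers patches enabled, as the new unit test does.
    with bypass_export_some_errors(patch_transformers=True):
        model(**inputs)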