
Commit f2f1195

Add zero shot image classification (#52)
* Add zero shot
* zero
* simplifies patch
* fix _make_causal
1 parent dc9d586 commit f2f1195

8 files changed (+309 lines, -77 lines)


CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.4.0
 +++++
 
+* :pr:`52`: add support for zero-shot-image-classification
 * :pr:`50`: add support for onnxruntime fusion
 * :pr:`48`: add support for EncoderDecoderCache, test with openai/whisper-tiny
 * :pr:`45`: improve change_dynamic_dimension to fix some dimensions
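For context, zero-shot image classification pairs an image with a set of candidate captions and scores each pair. A minimal way to try the task outside this repository is the transformers pipeline API; the snippet below is an illustration only, not code from this commit:

    from transformers import pipeline

    classifier = pipeline(
        "zero-shot-image-classification",
        model="openai/clip-vit-base-patch16",
    )
    predictions = classifier(
        "http://images.cocodataset.org/val2017/000000039769.jpg",
        candidate_labels=["a photo of a cat", "a photo of a dog"],
    )
    # predictions is a list of {"label": ..., "score": ...} entries sorted by score
    print(predictions)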

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 11 additions & 0 deletions
@@ -96,6 +96,17 @@ def test_get_untrained_model_with_inputs_codellama(self):
         # different expected value for different version of transformers
         self.assertIn((data["size"], data["n_weights"]), [(410532864, 102633216)])
 
+    @hide_stdout()
+    @ignore_errors(OSError)
+    def test_get_untrained_model_with_inputs_clip_vit(self):
+        mid = "openai/clip-vit-base-patch16"
+        data = get_untrained_model_with_inputs(mid, verbose=1)
+        model, inputs = data["model"], data["inputs"]
+        with bypass_export_some_errors(patch_transformers=True):
+            model(**inputs)
+        # different expected value for different version of transformers
+        self.assertIn((data["size"], data["n_weights"]), [(188872708, 47218177)])
+
     @hide_stdout()
     def test_get_untrained_model_with_inputs_text2text_generation(self):
         mid = "sshleifer/tiny-marian-en-de"

_unittests/ut_torch_models/try_tasks.py

Lines changed: 47 additions & 0 deletions
@@ -25,6 +25,53 @@ def test_image_classification(self):
         outputs = model(**inputs)
         print("-- outputs", string_type(outputs, with_shape=True, with_min_max=True))
 
+    @never_test()
+    def test_image_classification_resnet(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_torch_models/try_tasks.py -k resnet
+
+        from transformers import ViTImageProcessor, ViTModel
+        from PIL import Image
+        import requests
+
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        processor = ViTImageProcessor.from_pretrained("microsoft/resnet-50")
+        model = ViTModel.from_pretrained("microsoft/resnet-50")
+        inputs = processor(images=image, return_tensors="pt")
+        print()
+        print("-- inputs", string_type(inputs, with_shape=True, with_min_max=True))
+
+        outputs = model(**inputs)
+        print("-- outputs", string_type(outputs, with_shape=True, with_min_max=True))
+
+    @never_test()
+    def test_zero_shot_image_classification(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_torch_models/try_tasks.py -k zero
+        from PIL import Image
+        import requests
+        from transformers import CLIPProcessor, CLIPModel
+
+        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
+        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+        inputs = processor(
+            text=["a photo of a cat", "a photo of a dog"],
+            images=[image, image],
+            return_tensors="pt",
+            padding=True,
+        )
+        print()
+        print("-- inputs", string_type(inputs, with_shape=True, with_min_max=True))
+        outputs = model(**inputs)
+        print("-- outputs", string_type(outputs, with_shape=True, with_min_max=True))
+        logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
+        probs = logits_per_image.softmax(
+            dim=1
+        )  # we can take the softmax to get the label probabilities
+        assert probs is not None
+
     @never_test()
     def test_text2text_generation(self):
         # clear&&NEVERTEST=1 python _unittests/ut_torch_models/try_tasks.py -k text2t

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 37 additions & 47 deletions
@@ -1,5 +1,4 @@
 import inspect
-import sys
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple
 import torch
@@ -44,56 +43,47 @@ def _patch_make_causal_mask(
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
 
-if sys.version_info[:2] <= (3, 11):
-
-    @dataclass
-    class patched_AttentionMaskConverter:
-        """
-        Patches
-        ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
-        """
-
-        _PATCHES_ = ["_make_causal_mask"]
-        _PATCHED_CLASS_ = AttentionMaskConverter
-
-        @staticmethod
-        def _make_causal_mask(
-            input_ids_shape: torch.Size,
-            dtype: torch.dtype,
-            device: torch.device,
-            past_key_values_length: int = 0,
-            sliding_window: Optional[int] = None,
-        ):
-            """Patched method."""
-            return _patch_make_causal_mask(
-                input_ids_shape, dtype, device, past_key_values_length, sliding_window
-            )
+@dataclass
+class patched_AttentionMaskConverter:
+    """
+    Patches
+    ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
+    """
 
-else:
+    _PATCHES_ = ["_make_causal_mask"]
+    _PATCHED_CLASS_ = AttentionMaskConverter
 
-    @dataclass
-    class patched_AttentionMaskConverter:
-        """
-        Patches
-        ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
+    @staticmethod
+    def _make_causal_mask(
+        *args,
+        **kwargs,
+        # input_ids_shape: torch.Size,
+        # dtype: torch.dtype,
+        # device: torch.device,
+        # past_key_values_length: int = 0,
+        # sliding_window: Optional[int] = None,
+    ):
         """
+        Patched method.
 
-        _PATCHES_ = ["_make_causal_mask"]
-        _PATCHED_CLASS_ = AttentionMaskConverter
-
-        @staticmethod
-        def _make_causal_mask(
-            self,
-            input_ids_shape: torch.Size,
-            dtype: torch.dtype,
-            device: torch.device,
-            past_key_values_length: int = 0,
-            sliding_window: Optional[int] = None,
-        ):
-            """Patched method."""
-            return _patch_make_causal_mask(
-                input_ids_shape, dtype, device, past_key_values_length, sliding_window
-            )
+        This static method may be called with ``AttentionMaskConverter._make_causal_mask``
+        or ``self._make_causal_mask``. That changes the arguments it receives.
+        That should not matter but...
+        """
+        if args:
+            index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
+            names = [
+                "input_ids_shape",
+                "dtype",
+                "device",
+                "past_key_values_length",
+                "sliding_window",
+            ]
+            for i, a in enumerate(args):
+                if i < index:
+                    continue
+                kwargs[names[i - index]] = a
+        return _patch_make_causal_mask(**kwargs)
 
 
 class patched_DynamicCache:
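The reason for the *args rewrite above is that, once the patch is applied, _make_causal_mask may be reached either as a plain static call or through an instance, in which case args[0] is the converter itself rather than the shape. The standalone toy below (not code from the commit) mirrors the same normalization logic:

    import torch

    def _normalize(args, kwargs):
        # mirrors the patched method: skip a leading `self` when the first
        # positional argument is not the input_ids shape
        names = ["input_ids_shape", "dtype", "device",
                 "past_key_values_length", "sliding_window"]
        index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
        for i, a in enumerate(args):
            if i >= index:
                kwargs[names[i - index]] = a
        return kwargs

    class Dummy:  # stands in for an AttentionMaskConverter instance
        pass

    # static-style call: the shape is the first positional argument
    print(_normalize((torch.Size([1, 4]), torch.float32), {}))
    # instance-style call: a `self` slips in front of the shape and is skipped
    print(_normalize((Dummy(), torch.Size([1, 4]), torch.float32), {}))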

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 14 additions & 1 deletion
@@ -4,7 +4,20 @@
 import transformers
 from huggingface_hub import HfApi, model_info
 from . import hub_data_cached_configs
-from .hub_data import __date__, __data_tasks__, load_architecture_task
+from .hub_data import __date__, __data_tasks__, load_architecture_task, __data_arch_values__
+
+
+@functools.cache
+def get_architecture_default_values(architecture: str):
+    """
+    The configuration may miss information needed to build the dummy inputs.
+    This function returns the missing pieces.
+    """
+    assert architecture in __data_arch_values__, (
+        f"No known default values for {architecture!r}, "
+        f"expecting one architecture in {', '.join(sorted(__data_arch_values__))}"
+    )
+    return __data_arch_values__[architecture]
 
 
 @functools.cache
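A sketch of how the new helper is meant to be used. The import path is inferred from the file location, and the unknown-architecture call is shown only to illustrate the assertion:

    from onnx_diagnostic.torch_models.hghub.hub_api import get_architecture_default_values

    defaults = get_architecture_default_values("ResNetForImageClassification")
    print(defaults)  # {'image_size': 224}, taken from __data_arch_values__ in hub_data.py

    # Any architecture missing from __data_arch_values__ fails loudly:
    # get_architecture_default_values("SomeUnknownArchitecture")  # -> AssertionError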

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 15 additions & 13 deletions
@@ -5,6 +5,8 @@
 
 __date__ = "2025-03-26"
 
+__data_arch_values__ = {"ResNetForImageClassification": dict(image_size=224)}
+
 __data_arch__ = textwrap.dedent(
     """
     architecture,task
@@ -127,25 +129,25 @@
 )
 
 __data_tasks__ = [
+    "audio-classification",
     "automatic-speech-recognition",
-    "image-text-to-text",
-    "image-to-text",
-    "text-generation",
-    "object-detection",
     "document-question-answering",
     "feature-extraction",
-    "text-to-audio",
-    "zero-shot-image-classification",
+    "fill-mask",
+    "image-classification",
+    "image-feature-extraction",
     "image-segmentation",
-    "reinforcement-learning",
+    "image-text-to-text",
+    "image-to-text",
+    "keypoint-detection",
+    "mask-generation",
     "no-pipeline-tag",
-    "image-classification",
+    "object-detection",
+    "reinforcement-learning",
+    "text-generation",
+    "text-to-audio",
     "text2text-generation",
-    "mask-generation",
-    "keypoint-detection",
-    "audio-classification",
-    "image-feature-extraction",
-    "fill-mask",
+    "zero-shot-image-classification",
 ]
 
 __models_testing__ = """

onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

Lines changed: 50 additions & 0 deletions
@@ -3389,3 +3389,53 @@ def _ccached_openai_whisper_tiny():
             "vocab_size": 51865,
         }
     )
+
+
+def _ccached_openai_clip_vit_base_patch16():
+    "openai/clip-vit-base-patch16"
+    return transformers.CLIPConfig(
+        **{
+            "architectures": ["CLIPModel"],
+            "initializer_factor": 1.0,
+            "logit_scale_init_value": 2.6592,
+            "model_type": "clip",
+            "projection_dim": 512,
+            "text_config": {
+                "attention_dropout": 0.0,
+                "bos_token_id": 0,
+                "dropout": 0.0,
+                "eos_token_id": 2,
+                "hidden_act": "quick_gelu",
+                "hidden_size": 512,
+                "initializer_factor": 1.0,
+                "initializer_range": 0.02,
+                "intermediate_size": 2048,
+                "layer_norm_eps": 1e-05,
+                "max_position_embeddings": 77,
+                "model_type": "clip_text_model",
+                "num_attention_heads": 8,
+                "num_hidden_layers": 12,
+                "projection_dim": 512,
+                "vocab_size": 49408,
+            },
+            "torch_dtype": "float32",
+            "transformers_version": "4.52.0.dev0",
+            "vision_config": {
+                "attention_dropout": 0.0,
+                "dropout": 0.0,
+                "hidden_act": "quick_gelu",
+                "hidden_size": 768,
+                "image_size": 224,
+                "initializer_factor": 1.0,
+                "initializer_range": 0.02,
+                "intermediate_size": 3072,
+                "layer_norm_eps": 1e-05,
+                "model_type": "clip_vision_model",
+                "num_attention_heads": 12,
+                "num_channels": 3,
+                "num_hidden_layers": 12,
+                "patch_size": 16,
+                "projection_dim": 512,
+            },
+        }
+    )
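What the cached configuration buys: an untrained CLIPModel can be instantiated offline, without downloading the checkpoint. The direct call to the private helper below is for illustration only; the library normally resolves cached configs through its hub helpers, and the import path is assumed from the file location:

    import transformers
    from onnx_diagnostic.torch_models.hghub.hub_data_cached_configs import (
        _ccached_openai_clip_vit_base_patch16,
    )

    config = _ccached_openai_clip_vit_base_patch16()
    model = transformers.CLIPModel(config)  # randomly initialized, no network access
    print(type(model).__name__, sum(p.numel() for p in model.parameters()))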
