Skip to content

Commit 8db31e2

Browse files
committed
return configuration as well
1 parent dd5d08f commit 8db31e2

File tree

5 files changed

+103
-38
lines changed

5 files changed

+103
-38
lines changed

_unittests/ut_torch_models/test_hghub_api.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from onnx_diagnostic.ext_test_case import (
44
ExtTestCase,
55
hide_stdout,
6+
long_test,
67
never_test,
78
requires_torch,
89
requires_transformers,
@@ -15,7 +16,10 @@
1516
task_from_arch,
1617
task_from_tags,
1718
)
18-
from onnx_diagnostic.torch_models.hghub.hub_data import load_architecture_task
19+
from onnx_diagnostic.torch_models.hghub.hub_data import (
20+
load_architecture_task,
21+
load_models_testing,
22+
)
1923

2024

2125
class TestHuggingFaceHubApi(ExtTestCase):
@@ -111,6 +115,17 @@ def test_task_from_tags(self):
111115
task = task_from_tags(tags)
112116
self.assertEqual(etask, task)
113117

118+
def test_model_testings(self):
119+
models = load_models_testing()
120+
self.assertNotEmpty(models)
121+
122+
@long_test()
123+
def test_model_testings_and_architectures(self):
124+
models = load_models_testing()
125+
for mid in models:
126+
task = task_from_id(mid)
127+
self.assertNotEmpty(task)
128+
114129

115130
if __name__ == "__main__":
116131
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def test_config_class_from_architecture(self):
2323
def test_get_untrained_model_with_inputs_tiny_llm(self):
2424
mid = "arnir0/Tiny-LLM"
2525
data = get_untrained_model_with_inputs(mid, verbose=1)
26+
self.assertEqual(
27+
set(data),
28+
{"model", "inputs", "dynamic_shapes", "configuration", "size", "n_weights"},
29+
)
2630
model, inputs = data["model"], data["inputs"]
2731
model(**inputs)
2832
self.assertEqual((1858125824, 464531456), (data["size"], data["n_weights"]))

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,32 @@ def task_from_arch(arch: str) -> str:
3737
return data[arch]
3838

3939

40-
def task_from_id(model_id: str, pretrained: bool = False) -> str:
40+
def task_from_id(
41+
model_id: str, pretrained: bool = False, fall_back_to_pretrained: bool = True
42+
) -> str:
4143
"""
4244
Returns the task attached to a model id.
4345
4446
:param model_id: model id
4547
:param pretrained: uses the config
48+
:param fall_back_to_pretrained: falls back to the pretrained config
4649
:return: task
4750
"""
48-
if pretrained:
49-
config = get_pretrained_config(model_id)
51+
if not pretrained:
5052
try:
51-
return config.pipeline_tag
52-
except AttributeError:
53-
assert config.architectures is not None and len(config.architectures) == 1, (
54-
f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
55-
f"architectures={config.architectures} in config={config}"
56-
)
57-
return task_from_arch(config.architectures[0])
58-
return transformers.pipelines.get_task(model_id)
53+
return transformers.pipelines.get_task(model_id)
54+
except RuntimeError:
55+
if not fall_back_to_pretrained:
56+
raise
57+
config = get_pretrained_config(model_id)
58+
try:
59+
return config.pipeline_tag
60+
except AttributeError:
61+
assert config.architectures is not None and len(config.architectures) == 1, (
62+
f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
63+
f"architectures={config.architectures} in config={config}"
64+
)
65+
return task_from_arch(config.architectures[0])
5966

6067

6168
def task_from_tags(tags: Union[str, List[str]]) -> str:

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,9 @@
11
import io
22
import functools
3+
from typing import Dict, List
34

45
__date__ = "2025-03-26"
56

6-
__data_tasks__ = [
7-
"automatic-speech-recognition",
8-
"image-text-to-text",
9-
"image-to-text",
10-
"text-generation",
11-
"object-detection",
12-
"document-question-answering",
13-
"feature-extraction",
14-
"text-to-audio",
15-
"zero-shot-image-classification",
16-
"image-segmentation",
17-
"reinforcement-learning",
18-
"no-pipeline-tag",
19-
"image-classification",
20-
"text2text-generation",
21-
"mask-generation",
22-
"keypoint-detection",
23-
"audio-classification",
24-
"image-feature-extraction",
25-
"fill-mask",
26-
]
27-
287
__data_arch__ = """
298
architecture,task
309
ASTModel,feature-extraction
@@ -143,9 +122,61 @@
143122
YolosModel,image-feature-extraction
144123
"""
145124

125+
__data_tasks__ = [
126+
"automatic-speech-recognition",
127+
"image-text-to-text",
128+
"image-to-text",
129+
"text-generation",
130+
"object-detection",
131+
"document-question-answering",
132+
"feature-extraction",
133+
"text-to-audio",
134+
"zero-shot-image-classification",
135+
"image-segmentation",
136+
"reinforcement-learning",
137+
"no-pipeline-tag",
138+
"image-classification",
139+
"text2text-generation",
140+
"mask-generation",
141+
"keypoint-detection",
142+
"audio-classification",
143+
"image-feature-extraction",
144+
"fill-mask",
145+
]
146+
147+
__models_testing__ = """
148+
hf-internal-testing/tiny-random-BeitForImageClassification
149+
hf-internal-testing/tiny-random-convnext
150+
fxmarty/tiny-random-GemmaForCausalLM
151+
hf-internal-testing/tiny-random-GPTNeoXForCausalLM
152+
hf-internal-testing/tiny-random-GraniteForCausalLM
153+
hf-internal-testing/tiny-random-HieraForImageClassification
154+
fxmarty/tiny-llama-fast-tokenizer
155+
sshleifer/tiny-marian-en-de
156+
hf-internal-testing/tiny-random-MaskFormerForInstanceSegmentation
157+
echarlaix/tiny-random-mistral
158+
hf-internal-testing/tiny-random-mobilevit
159+
hf-internal-testing/tiny-random-MoonshineForConditionalGeneration
160+
hf-internal-testing/tiny-random-OlmoForCausalLM
161+
hf-internal-testing/tiny-random-Olmo2ForCausalLM
162+
echarlaix/tiny-random-PhiForCausalLM
163+
Xenova/tiny-random-Phi3ForCausalLM
164+
fxmarty/pix2struct-tiny-random
165+
fxmarty/tiny-dummy-qwen2
166+
hf-internal-testing/tiny-random-ViTMSNForImageClassification
167+
hf-internal-testing/tiny-random-YolosModel
168+
hf-internal-testing/tiny-xlm-roberta
169+
"""
170+
171+
172+
@functools.cache
173+
def load_models_testing() -> List[str]:
174+
"""Returns model ids for testing."""
175+
return [_.strip() for _ in __models_testing__.split("\n") if _.strip()]
176+
146177

147178
@functools.cache
148-
def load_architecture_task():
179+
def load_architecture_task() -> Dict[str, str]:
149180
"""
150181
Returns a dictionary mapping architecture to task.
151182

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def get_untrained_model_with_inputs(
6262
:param model_kwargs: to change the model generation
6363
:param verbose: display found information
6464
:param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
65-
:return: dictionary
65+
:return: dictionary with a model, inputs, dynamic shapes, and the configuration
6666
6767
Example:
6868
@@ -79,6 +79,7 @@ def get_untrained_model_with_inputs(
7979
print("-- number of parameters:", data["n_weights"])
8080
print("-- inputs:", string_type(data["inputs"], with_shape=True))
8181
print("-- dynamic shapes:", pprint.pformat(data["dynamic_shapes"]))
82+
print("-- configuration:", pprint.pformat(data["configuration"]))
8283
"""
8384
if verbose:
8485
print(f"[get_untrained_model_with_inputs] model_id={model_id!r}")
@@ -139,7 +140,7 @@ def get_untrained_model_with_inputs(
139140
if inputs_kwargs:
140141
kwargs.update(inputs_kwargs)
141142

142-
return get_inputs_for_text_generation(model, **kwargs)
143+
return get_inputs_for_text_generation(model, config, **kwargs)
143144
raise NotImplementedError(f"Input generation for task {task!r} not implemented yet.")
144145

145146

@@ -155,6 +156,7 @@ def compute_model_size(model: torch.nn.Module) -> Tuple[int, int]:
155156

156157
def get_inputs_for_text_generation(
157158
model: torch.nn.Module,
159+
config: Optional[Any],
158160
max_token_id: int,
159161
num_key_value_heads: int,
160162
num_hidden_layers: int,
@@ -167,6 +169,7 @@ def get_inputs_for_text_generation(
167169
):
168170
"""
169171
:param model: model to get the missing information
172+
:param config: configuration used to generate the model
170173
:param head_dim: last dimension of the cache
171174
:param batch_size: batch size
172175
:param sequence_length: sequence length
@@ -216,5 +219,10 @@ def get_inputs_for_text_generation(
216219
)
217220
sizes = compute_model_size(model)
218221
return dict(
219-
model=model, inputs=inputs, dynamic_shapes=shapes, size=sizes[0], n_weights=sizes[1]
222+
model=model,
223+
inputs=inputs,
224+
dynamic_shapes=shapes,
225+
size=sizes[0],
226+
n_weights=sizes[1],
227+
configuration=config,
220228
)

0 commit comments

Comments
 (0)