Skip to content

Commit bef2af2

Browse files
committed
add tiny
1 parent aa6d224 commit bef2af2

File tree

11 files changed

+397
-54
lines changed

11 files changed

+397
-54
lines changed

_doc/api/torch_models/hghub/index.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ onnx_diagnostic.torch_models.hghub
77

88
hub_api
99
hub_data
10+
model_inputs
11+
12+
.. autofunction:: onnx_diagnostic.torch_models.hghub.get_untrained_model_with_inputs
1013

1114
.. automodule:: onnx_diagnostic.torch_models.hghub
1215
:members:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
onnx_diagnostic.torch_models.hghub.model_inputs
3+
===============================================
4+
5+
.. automodule:: onnx_diagnostic.torch_models.hghub.model_inputs
6+
:members:
7+
:no-undoc-members:
8+
:exclude-members: get_untrained_model_with_inputs

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,5 +232,6 @@
232232
"arnir0/Tiny-LLM": "https://huggingface.co/arnir0/Tiny-LLM",
233233
"microsoft/phi-2": "https://huggingface.co/microsoft/phi-2",
234234
"microsoft/Phi-3.5-mini-instruct": "https://huggingface.co/microsoft/Phi-3.5-mini-instruct",
235+
"microsoft/Phi-3.5-vision-instruct": "https://huggingface.co/microsoft/Phi-3.5-vision-instruct",
235236
}
236237
)

_unittests/ut_torch_models/test_hghub.py

Lines changed: 0 additions & 52 deletions
This file was deleted.
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import unittest
2+
import pandas
3+
from onnx_diagnostic.ext_test_case import (
4+
ExtTestCase,
5+
hide_stdout,
6+
never_test,
7+
requires_torch,
8+
requires_transformers,
9+
)
10+
from onnx_diagnostic.torch_models.hghub.hub_api import (
11+
enumerate_model_list,
12+
get_model_info,
13+
get_pretrained_config,
14+
task_from_id,
15+
task_from_arch,
16+
task_from_tags,
17+
)
18+
from onnx_diagnostic.torch_models.hghub.hub_data import load_architecture_task
19+
20+
21+
class TestHuggingFaceHubApi(ExtTestCase):
22+
23+
@requires_transformers("4.50") # we limit to some versions of the CI
24+
@requires_torch("2.7")
25+
def test_enumerate_model_list(self):
26+
models = list(
27+
enumerate_model_list(
28+
2,
29+
verbose=1,
30+
dump="test_enumerate_model_list.csv",
31+
filter="text-generation",
32+
library="transformers",
33+
)
34+
)
35+
self.assertEqual(len(models), 2)
36+
df = pandas.read_csv("test_enumerate_model_list.csv")
37+
self.assertEqual(df.shape, (2, 12))
38+
tasks = [task_from_id(c) for c in df.id]
39+
self.assertEqual(["text-generation", "text-generation"], tasks)
40+
41+
@requires_transformers("4.50")
42+
@requires_torch("2.7")
43+
def test_task_from_id(self):
44+
for name, etask in [
45+
("arnir0/Tiny-LLM", "text-generation"),
46+
("microsoft/phi-2", "text-generation"),
47+
("microsoft/Phi-3.5-mini-instruct", "text-generation"),
48+
("microsoft/Phi-3.5-vision-instruct", "text-generation"),
49+
]:
50+
with self.subTest(name=name, task=etask):
51+
task = task_from_id(name, True)
52+
self.assertEqual(etask, task)
53+
54+
@requires_transformers("4.50")
55+
@requires_torch("2.7")
56+
@hide_stdout()
57+
def test_get_pretrained_config(self):
58+
conf = get_pretrained_config("microsoft/phi-2")
59+
self.assertNotEmpty(conf)
60+
print(conf)
61+
62+
@requires_transformers("4.50")
63+
@requires_torch("2.7")
64+
@hide_stdout()
65+
def test_get_model_info(self):
66+
info = get_model_info("microsoft/phi-2")
67+
self.assertEqual(info.pipeline_tag, "text-generation")
68+
69+
info = get_model_info("microsoft/Phi-3.5-vision-instruct")
70+
self.assertEqual(info.pipeline_tag, "image-text-to-text")
71+
72+
info = get_model_info("microsoft/Phi-4-multimodal-instruct")
73+
self.assertEqual(info.pipeline_tag, "automatic-speech-recognition")
74+
75+
def test_task_from_arch(self):
76+
task = task_from_arch("LlamaForCausalLM")
77+
self.assertEqual("text-generation", task)
78+
79+
@never_test()
80+
def test_hf_all_models(self):
81+
list(enumerate_model_list(-1, verbose=1, dump="test_hf_all_models.csv"))
82+
83+
def test_load_architecture_task(self):
84+
data = load_architecture_task()
85+
print(set(data.values()))
86+
87+
def test_task_from_tags(self):
88+
_tags = [
89+
("text-generation|nlp|code|en|text-generation-inference", "text-generation"),
90+
(
91+
"text-generation|nlp|code|vision|image-text-to-text|conversational",
92+
"image-text-to-text",
93+
),
94+
(
95+
"text-generation|nlp|code|audio|automatic-speech-recognition|speech-summarization|speech-translation|visual-question-answering",
96+
"automatic-speech-recognition",
97+
),
98+
]
99+
for tags, etask in _tags:
100+
with self.subTest(tags=tags, task=etask):
101+
task = task_from_tags(tags, True)
102+
self.assertEqual(etask, task)
103+
104+
105+
if __name__ == "__main__":
106+
unittest.main(verbosity=2)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import unittest
2+
import transformers
3+
from onnx_diagnostic.ext_test_case import (
4+
ExtTestCase,
5+
hide_stdout,
6+
requires_torch,
7+
requires_transformers,
8+
)
9+
from onnx_diagnostic.torch_models.hghub.model_inputs import (
10+
config_class_from_architecture,
11+
get_untrained_model_with_inputs,
12+
)
13+
14+
15+
class TestHuggingFaceHubModel(ExtTestCase):
16+
@requires_transformers("4.50") # we limit to some versions of the CI
17+
@requires_torch("2.7")
18+
def test_config_class_from_architecture(self):
19+
config = config_class_from_architecture("LlamaForCausalLM")
20+
self.assertEqual(config, transformers.LlamaConfig)
21+
22+
@hide_stdout()
23+
def test_get_untrained_model_with_inputs(self):
24+
mid = "arnir0/Tiny-LLM"
25+
data = get_untrained_model_with_inputs(mid, verbose=1)
26+
model, inputs = data["model"], data["inputs"]
27+
model(**inputs)
28+
self.assertEqual(data["size"], 1858125824)
29+
self.assertEqual(data["n_weights"], 464531456)
30+
31+
32+
if __name__ == "__main__":
33+
unittest.main(verbosity=2)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .model_inputs import get_untrained_model_with_inputs

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
1+
import functools
12
from typing import List, Optional, Union
23
import transformers
3-
from huggingface_hub import HfApi
4-
from .hub_data import __date__, load_architecture_task
4+
from huggingface_hub import HfApi, model_info
5+
from .hub_data import __date__, __data_tasks__, load_architecture_task
56

67

78
def get_pretrained_config(model_id) -> str:
89
"""Returns the config for a model_id."""
910
return transformers.AutoConfig.from_pretrained(model_id)
1011

1112

13+
def get_model_info(model_id) -> str:
14+
"""Returns the model info for a model_id."""
15+
return model_info(model_id)
16+
17+
18+
@functools.cache
1219
def task_from_arch(arch: str) -> str:
1320
"""
1421
This function relies on stored information. That information needs to be refresh.
@@ -51,6 +58,20 @@ def task_from_id(model_id: str, pretrained: bool = False) -> str:
5158
return transformers.pipelines.get_task(model_id)
5259

5360

61+
def task_from_tags(tags: Union[str, List[str]]) -> str:
62+
"""
63+
Guesses the task from the list of tags.
64+
If given by a string, ``|`` should be the separater.
65+
"""
66+
if isinstance(tags, str):
67+
tags = tags.split("|")
68+
stags = set(tags)
69+
for task in __data_tasks__:
70+
if task in stags:
71+
return task
72+
raise ValueError(f"Unable to guess the task from tags={tags!r}")
73+
74+
5475
def enumerate_model_list(
5576
n: int = 50,
5677
task: Optional[str] = None,
@@ -92,6 +113,7 @@ def enumerate_model_list(
92113
",".join(
93114
[
94115
"id",
116+
"model_name",
95117
"author",
96118
"created_at",
97119
"last_modified",
@@ -123,6 +145,7 @@ def enumerate_model_list(
123145
str,
124146
[
125147
m.id,
148+
getattr(m, "model_name", "") or "",
126149
m.author or "",
127150
str(m.created_at or "").split(" ")[0],
128151
str(m.last_modified or "").split(" ")[0],

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,27 @@
33

44
__date__ = "2025-03-26"
55

6+
__data_tasks__ = [
7+
"automatic-speech-recognition",
8+
"image-to-text",
9+
"text-generation",
10+
"object-detection",
11+
"document-question-answering",
12+
"feature-extraction",
13+
"text-to-audio",
14+
"zero-shot-image-classification",
15+
"image-segmentation",
16+
"reinforcement-learning",
17+
"no-pipeline-tag",
18+
"image-classification",
19+
"text2text-generation",
20+
"mask-generation",
21+
"keypoint-detection",
22+
"audio-classification",
23+
"image-feature-extraction",
24+
"fill-mask",
25+
]
26+
627
__data_arch__ = """
728
architecture,task
829
ASTModel,feature-extraction

0 commit comments

Comments
 (0)