Skip to content

Commit 1b7f5fa

Browse files
committed
data
1 parent ffd76fe commit 1b7f5fa

File tree

7 files changed

+242
-5
lines changed

7 files changed

+242
-5
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_models.hghub.hub_api
3+
==========================================
4+
5+
.. automodule:: onnx_diagnostic.torch_models.hghub.hub_api
6+
:members:
7+
:no-undoc-members:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_models.hghub.hub_data
3+
===========================================
4+
5+
.. automodule:: onnx_diagnostic.torch_models.hghub.hub_data
6+
:members:
7+
:no-undoc-members:
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
onnx_diagnostic.torch_models.hghub
2+
==================================
3+
4+
.. toctree::
5+
:maxdepth: 1
6+
:caption: submodules
7+
8+
hub_api
9+
hub_data
10+
11+
.. automodule:: onnx_diagnostic.torch_models.hghub
12+
:members:
13+
:no-undoc-members:

_doc/api/torch_models/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ onnx_diagnostic.torch_models
55
:maxdepth: 1
66
:caption: submodules
77

8+
hghub/index
89
llms
910

1011
.. automodule:: onnx_diagnostic.torch_models

_unittests/ut_torch_models/test_hghub.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,43 @@
66
requires_torch,
77
requires_transformers,
88
)
9-
from onnx_diagnostic.torch_models.hghub.hub_api import enumerate_model_list, get_task
9+
from onnx_diagnostic.torch_models.hghub.hub_api import (
10+
enumerate_model_list,
11+
task_from_id,
12+
task_from_arch,
13+
)
1014

1115

1216
class TestHuggingFaceHub(ExtTestCase):
1317

1418
@requires_transformers("4.50") # we limit to some versions of the CI
1519
@requires_torch("2.7")
1620
def test_enumerate_model_list(self):
17-
models = list(enumerate_model_list(2, verbose=1, dump="test_enumerate_model_list.csv"))
21+
models = list(
22+
enumerate_model_list(
23+
2,
24+
verbose=1,
25+
dump="test_enumerate_model_list.csv",
26+
filter="text-generation",
27+
library="transformers",
28+
)
29+
)
1830
self.assertEqual(len(models), 2)
1931
df = pandas.read_csv("test_enumerate_model_list.csv")
2032
self.assertEqual(df.shape, (2, 11))
21-
tasks = [get_task(c) for c in df.id]
33+
tasks = [task_from_id(c) for c in df.id]
2234
self.assertEqual(["text-generation", "text-generation"], tasks)
2335

36+
@requires_transformers("4.50")
37+
@requires_torch("2.7")
38+
def test_task_from_id(self):
39+
task = task_from_id("arnir0/Tiny-LLM", True)
40+
self.assertEqual("text-generation", task)
41+
42+
def test_task_from_arch(self):
43+
task = task_from_arch("LlamaForCausalLM")
44+
self.assertEqual("text-generation", task)
45+
2446
@never_test()
2547
def test_hf_all_models(self):
2648
list(enumerate_model_list(-1, verbose=1, dump="test_hf_all_models.csv"))

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,53 @@
11
from typing import List, Optional, Union
22
import transformers
33
from huggingface_hub import HfApi
4+
from .hub_data import __date__, load_architecture_task
45

56

6-
def get_task(model_id: str) -> str:
7+
def get_pretrained_config(model_id) -> str:
8+
"""Returns the config for a model_id."""
9+
return transformers.AutoConfig.from_pretrained(model_id)
10+
11+
12+
def task_from_arch(arch: str) -> str:
13+
"""
14+
This function relies on stored information. That information needs to be refresh.
15+
16+
:param arch: architecture name
17+
:return: task
18+
19+
.. runpython::
20+
21+
from onnx_diagnostic.torch_models.hub_data import __date__
22+
print("last refresh", __date__)
23+
24+
List of supported architecturs, see
25+
:func:`load_architecture_task
26+
<onnx_diagnostic.torch_models.hghub.hub_data.load_architecture_task>`.
27+
"""
28+
data = load_architecture_task()
29+
assert arch in data, f"Architecture {arch!r} is unknown, last refresh in {__date__}"
30+
return data[arch]
31+
32+
33+
def task_from_id(model_id: str, pretrained: bool = False) -> str:
734
"""
835
Returns the task attached to a model id.
936
1037
:param model_id: model id
38+
:param pretrained: uses the config
1139
:return: task
1240
"""
41+
if pretrained:
42+
config = get_pretrained_config(model_id)
43+
try:
44+
return config.pipeline_tag
45+
except AttributeError:
46+
assert config.architectures is not None and len(config.architectures) == 1, (
47+
f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
48+
f"architectures={config.architectures} in config={config}"
49+
)
50+
return task_from_arch(config.architectures[0])
1351
return transformers.pipelines.get_task(model_id)
1452

1553

@@ -18,7 +56,9 @@ def enumerate_model_list(
1856
task: Optional[str] = None,
1957
library: Optional[str] = None,
2058
tags: Optional[Union[str, List[str]]] = None,
59+
search: Optional[str] = None,
2160
dump: Optional[str] = None,
61+
filter: Optional[str] = None,
2262
verbose: int = 0,
2363
):
2464
"""
@@ -28,11 +68,21 @@ def enumerate_model_list(
2868
:param task: see :meth:`huggingface_hub.HfApi.list_models`
2969
:param tags: see :meth:`huggingface_hub.HfApi.list_models`
3070
:param library: see :meth:`huggingface_hub.HfApi.list_models`
71+
:param search: see :meth:`huggingface_hub.HfApi.list_models`
72+
:param filter: see :meth:`huggingface_hub.HfApi.list_models`
3173
:param dump: dumps the result in this csv file
3274
:param verbose: show progress
3375
"""
3476
api = HfApi()
35-
models = api.list_models(task=task, library=library, tags=tags)
77+
models = api.list_models(
78+
task=task,
79+
library=library,
80+
tags=tags,
81+
search=search,
82+
full=True,
83+
filter=filter,
84+
limit=n if n > 0 else None,
85+
)
3686
seen = 0
3787
found = 0
3888

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import io
2+
import functools
3+
4+
__date__ = "2025-03-26"
5+
6+
__data_arch__ = """
7+
architecture,task
8+
ASTModel,feature-extraction
9+
AlbertModel,feature-extraction
10+
BeitForImageClassification,image-classification
11+
BigBirdModel,feature-extraction
12+
BlenderbotModel,feature-extraction
13+
BloomModel,feature-extraction
14+
CLIPModel,zero-shot-image-classification
15+
CLIPVisionModel,feature-extraction
16+
CamembertModel,feature-extraction
17+
CodeGenModel,feature-extraction
18+
ConvBertModel,feature-extraction
19+
ConvNextForImageClassification,image-classification
20+
ConvNextV2Model,image-feature-extraction
21+
CvtModel,feature-extraction
22+
DPTModel,image-feature-extraction
23+
Data2VecAudioModel,feature-extraction
24+
Data2VecTextModel,feature-extraction
25+
Data2VecVisionModel,image-feature-extraction
26+
DebertaModel,feature-extraction
27+
DebertaV2Model,feature-extraction
28+
DecisionTransformerModel,reinforcement-learning
29+
DeiTModel,image-feature-extraction
30+
DetrModel,image-feature-extraction
31+
Dinov2Model,image-feature-extraction
32+
DistilBertModel,feature-extraction
33+
DonutSwinModel,feature-extraction
34+
ElectraModel,feature-extraction
35+
EsmModel,feature-extraction
36+
GLPNModel,image-feature-extraction
37+
GPTBigCodeModel,feature-extraction
38+
GPTJModel,feature-extraction
39+
GPTNeoModel,feature-extraction
40+
GPTNeoXForCausalLM,text-generation
41+
GemmaForCausalLM,text-generation
42+
GraniteForCausalLM,text-generation
43+
GroupViTModel,feature-extraction
44+
HieraForImageClassification,image-classification
45+
HubertModel,feature-extraction
46+
IBertModel,feature-extraction
47+
ImageGPTModel,image-feature-extraction
48+
LayoutLMModel,feature-extraction
49+
LayoutLMv3Model,feature-extraction
50+
LevitModel,image-feature-extraction
51+
LiltModel,feature-extraction
52+
LlamaForCausalLM,text-generation
53+
LongT5Model,feature-extraction
54+
LongformerModel,feature-extraction
55+
MCTCTModel,feature-extraction
56+
MPNetModel,feature-extraction
57+
MT5Model,feature-extraction
58+
MarianMTModel,text2text-generation
59+
MarkupLMModel,feature-extraction
60+
MaskFormerForInstanceSegmentation,image-segmentation
61+
MegatronBertModel,feature-extraction
62+
MgpstrForSceneTextRecognition,feature-extraction
63+
MistralForCausalLM,text-generation
64+
MobileBertModel,feature-extraction
65+
MobileNetV1Model,image-feature-extraction
66+
MobileNetV2Model,image-feature-extraction
67+
MobileViTForImageClassification,image-classification
68+
ModernBertForMaskedLM,fill-mask
69+
MoonshineForConditionalGeneration,automatic-speech-recognition
70+
MptForCausalLM,text-generation
71+
MusicgenForConditionalGeneration,text-to-audio
72+
NystromformerModel,feature-extraction
73+
OPTModel,feature-extraction
74+
Olmo2ForCausalLM,text-generation
75+
OlmoForCausalLM,text-generation
76+
OwlViTModel,feature-extraction
77+
Owlv2Model,feature-extraction
78+
PatchTSMixerForPrediction,no-pipeline-tag
79+
PatchTSTForPrediction,no-pipeline-tag
80+
PegasusModel,feature-extraction
81+
Phi3ForCausalLM,text-generation
82+
PhiForCausalLM,text-generation
83+
Pix2StructForConditionalGeneration,image-to-text
84+
PoolFormerModel,image-feature-extraction
85+
PvtForImageClassification,image-classification
86+
Qwen2ForCausalLM,text-generation
87+
RTDetrForObjectDetection,object-detection
88+
RegNetModel,image-feature-extraction
89+
RemBertModel,feature-extraction
90+
ResNetForImageClassification,image-classification
91+
RoFormerModel,feature-extraction
92+
RobertaModel,feature-extraction
93+
RtDetrV2ForObjectDetection,object-detection
94+
SEWDModel,feature-extraction
95+
SEWModel,feature-extraction
96+
SamModel,mask-generation
97+
SegformerModel,image-feature-extraction
98+
SiglipModel,zero-shot-image-classification
99+
SiglipVisionModel,image-feature-extraction
100+
Speech2TextModel,feature-extraction
101+
SpeechT5ForTextToSpeech,text-to-audio
102+
SplinterModel,feature-extraction
103+
SqueezeBertModel,feature-extraction
104+
Swin2SRModel,image-feature-extraction
105+
SwinModel,image-feature-extraction
106+
Swinv2Model,image-feature-extraction
107+
TableTransformerModel,image-feature-extraction
108+
UniSpeechForSequenceClassification,audio-classification
109+
ViTForImageClassification,image-classification
110+
ViTMAEModel,image-feature-extraction
111+
ViTMSNForImageClassification,image-classification
112+
VisionEncoderDecoderModel,document-question-answering
113+
VitPoseForPoseEstimation,keypoint-detection
114+
VitsModel,text-to-audio
115+
Wav2Vec2ConformerForCTC,automatic-speech-recognition
116+
Wav2Vec2Model,feature-extraction
117+
WhisperForConditionalGeneration,no-pipeline-tag
118+
XLMModel,feature-extraction
119+
XLMRobertaForCausalLM,text-generation
120+
YolosForObjectDetection,object-detection
121+
YolosModel,image-feature-extraction
122+
"""
123+
124+
125+
@functools.cache
126+
def load_architecture_task():
127+
"""
128+
Returns a dictionary mapping architecture to task.
129+
130+
import pprint
131+
from onnx_diagnostic.torch_models.hghub.hub_data import load_architecture_task
132+
pprint.pprint(load_architecture_task())
133+
"""
134+
import pandas
135+
136+
df = pandas.read_csv(io.StringIO(__data_arch__))
137+
return dict(zip(list(df["architecture"]), list(df["task"])))

0 commit comments

Comments
 (0)