Skip to content

Commit f85793d

Browse files
committed
hub
1 parent 81890a1 commit f85793d

File tree

4 files changed

+106
-0
lines changed

4 files changed

+106
-0
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import unittest
2+
import pandas
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, never_test
4+
from onnx_diagnostic.torch_models.hghub.hub_api import enumerate_model_list
5+
6+
7+
class TestHuggingFaceHub(ExtTestCase):
8+
def test_enumerate_model_list(self):
9+
models = list(enumerate_model_list(2, verbose=1, dump="test_enumerate_model_list.csv"))
10+
self.assertEqual(len(models), 2)
11+
df = pandas.read_csv("test_enumerate_model_list.csv")
12+
self.assertEqual(df.shape, (2, 11))
13+
14+
@never_test()
15+
def test_hf_all_models(self):
16+
list(enumerate_model_list(-1, verbose=1, dump="test_hf_all_models.csv"))
17+
18+
19+
if __name__ == "__main__":
20+
unittest.main(verbosity=2)

onnx_diagnostic/torch_models/hghub/__init__.py

Whitespace-only changes.
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import List, Optional, Union
2+
from huggingface_hub import HfApi
3+
4+
5+
def enumerate_model_list(
6+
n: int = 50,
7+
task: Optional[str] = None,
8+
library: Optional[str] = None,
9+
tags: Optional[Union[str, List[str]]] = None,
10+
dump: Optional[str] = None,
11+
verbose: int = 0,
12+
):
13+
"""
14+
Enumerates models coming from :epkg:`huggingface_hub`.
15+
16+
:param n: number of models to retrieve (-1 for all)
17+
:param task: see :meth:`huggingface_hub.HfApi.list_models`
18+
:param tags: see :meth:`huggingface_hub.HfApi.list_models`
19+
:param library: see :meth:`huggingface_hub.HfApi.list_models`
20+
:param dump: dumps the result in this csv file
21+
:param verbose: show progress
22+
"""
23+
api = HfApi()
24+
models = api.list_models(task=task, library=library, tags=tags)
25+
seen = 0
26+
found = 0
27+
28+
if dump:
29+
with open(dump, "w") as f:
30+
f.write(
31+
",".join(
32+
[
33+
"id",
34+
"author",
35+
"created_at",
36+
"last_modified",
37+
"downloads",
38+
"downloads_all_time",
39+
"likes",
40+
"trending_score",
41+
"private",
42+
"gated",
43+
"tags",
44+
]
45+
)
46+
)
47+
f.write("\n")
48+
49+
for m in models:
50+
seen += 1 # noqa: SIM113
51+
if verbose and seen % 1000 == 0:
52+
print(f"[enumerate_model_list] {seen} models, found {found}")
53+
if verbose > 1:
54+
print(
55+
f"[enumerate_model_list] id={m.id!r}, "
56+
f"library={m.library_name!r}, task={m.task!r}"
57+
)
58+
with open(dump, "a") as f:
59+
f.write(
60+
",".join(
61+
map(
62+
str,
63+
[
64+
m.id,
65+
m.author or "",
66+
str(m.created_at or "").split(" ")[0],
67+
str(m.last_modified or "").split(" ")[0],
68+
m.downloads or "",
69+
m.downloads_all_time or "",
70+
m.likes or "",
71+
m.trending_score or "",
72+
m.private or "",
73+
m.gated or "",
74+
("|".join(m.tags)).replace(",", "_").replace(" ", "_"),
75+
],
76+
)
77+
)
78+
)
79+
f.write("\n")
80+
yield m
81+
found += 1 # noqa: SIM113
82+
if n >= 0:
83+
n -= 1
84+
if n == 0:
85+
break

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
black
22
diffusers>=0.30.0
33
furo
4+
huggingface_hub
45
matplotlib
56
onnx-array-api
67
git+https://github.com/microsoft/onnxscript.git

0 commit comments

Comments
 (0)