Skip to content

Commit 857ef10

Browse files
committed
fix api for huggingface hub
1 parent 1be0d26 commit 857ef10

File tree

3 files changed

+9
-13
lines changed

3 files changed

+9
-13
lines changed

_unittests/ut_torch_models/test_hghub_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def test_enumerate_model_list(self):
3939
verbose=1,
4040
dump="test_enumerate_model_list.csv",
4141
filter="image-classification",
42-
library="transformers",
4342
)
4443
)
4544
self.assertEqual(len(models), 2)

onnx_diagnostic/helpers/log_helper.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,8 @@ def _to_images_bar(
285285
nn = df.shape[1] // n_cols
286286
nn += int(df.shape[1] % n_cols != 0)
287287
ratio = float(os.environ.get("FIGSIZEH", "1"))
288-
fig, axs = plt.subplots(nn, n_cols, figsize=(6 * n_cols, nn * df.shape[0] / 3 * ratio))
288+
figsize = (6 * n_cols, nn * (2 + df.shape[0] / 15) * ratio)
289+
fig, axs = plt.subplots(nn, n_cols, figsize=figsize)
289290
pos = 0
290291
imgs = []
291292
for c in self._make_loop(df.columns, verbose):
@@ -332,10 +333,12 @@ def rotate_align(ax, angle=15, align="right"):
332333
n_cols = len(groups)
333334

334335
title_suffix = f"\n{title_suffix}" if title_suffix else ""
336+
ratio = float(os.environ.get("FIGSIZEH", "1"))
337+
figsize = (5 * n_cols, max(len(g) for g in groups) * (2 + df.shape[1] / 2) * ratio)
335338
fig, axs = plt.subplots(
336339
df.shape[1],
337340
n_cols,
338-
figsize=(5 * n_cols, max(len(g) for g in groups) * df.shape[1] / 2),
341+
figsize=figsize,
339342
sharex=True,
340343
sharey="row" if n_cols > 1 else False,
341344
)

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -289,31 +289,25 @@ def task_from_tags(tags: Union[str, List[str]]) -> str:
289289

290290
def enumerate_model_list(
291291
n: int = 50,
292-
task: Optional[str] = None,
293-
library: Optional[str] = None,
294-
tags: Optional[Union[str, List[str]]] = None,
292+
pipeline_tag: Optional[str] = None,
295293
search: Optional[str] = None,
296294
dump: Optional[str] = None,
297-
filter: Optional[str] = None,
295+
filter: Optional[Union[str, List[str]]] = None,
298296
verbose: int = 0,
299297
):
300298
"""
301299
Enumerates models coming from :epkg:`huggingface_hub`.
302300
303301
:param n: number of models to retrieve (-1 for all)
304-
:param task: see :meth:`huggingface_hub.HfApi.list_models`
305-
:param tags: see :meth:`huggingface_hub.HfApi.list_models`
306-
:param library: see :meth:`huggingface_hub.HfApi.list_models`
302+
:param pipeline_tag: see :meth:`huggingface_hub.HfApi.list_models`
307303
:param search: see :meth:`huggingface_hub.HfApi.list_models`
308304
:param filter: see :meth:`huggingface_hub.HfApi.list_models`
309305
:param dump: dumps the result in this csv file
310306
:param verbose: show progress
311307
"""
312308
api = HfApi()
313309
models = api.list_models(
314-
task=task,
315-
library=library,
316-
tags=tags,
310+
pipeline_tag=pipeline_tag,
317311
search=search,
318312
full=True,
319313
filter=filter,

0 commit comments

Comments
 (0)