Skip to content

Commit 3764234

Browse files
committed
bb
1 parent f85793d commit 3764234

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

_unittests/ut_torch_models/test_hghub.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
11
import unittest
22
import pandas
3-
from onnx_diagnostic.ext_test_case import ExtTestCase, never_test
4-
from onnx_diagnostic.torch_models.hghub.hub_api import enumerate_model_list
3+
from onnx_diagnostic.ext_test_case import (
4+
ExtTestCase,
5+
never_test,
6+
requires_torch,
7+
requires_transformers,
8+
)
9+
from onnx_diagnostic.torch_models.hghub.hub_api import enumerate_model_list, get_task
510

611

712
class TestHuggingFaceHub(ExtTestCase):
13+
14+
@requires_transformers("4.50") # we limit to some versions of the CI
15+
@requires_torch("2.7")
816
def test_enumerate_model_list(self):
917
models = list(enumerate_model_list(2, verbose=1, dump="test_enumerate_model_list.csv"))
1018
self.assertEqual(len(models), 2)
1119
df = pandas.read_csv("test_enumerate_model_list.csv")
1220
self.assertEqual(df.shape, (2, 11))
21+
tasks = [get_task(c) for c in df.id]
22+
self.assertEqual(['text-generation', 'text-generation'], tasks)
1323

1424
@never_test()
1525
def test_hf_all_models(self):

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
11
from typing import List, Optional, Union
2+
import transformers
23
from huggingface_hub import HfApi
34

45

6+
def get_task(model_id: str) -> str:
7+
"""
8+
Returns the task attached to a model id.
9+
10+
:param model_id: model id
11+
:return: task
12+
"""
13+
return transformers.pipelines.get_task(model_id)
14+
15+
516
def enumerate_model_list(
617
n: int = 50,
718
task: Optional[str] = None,

0 commit comments

Comments
 (0)