File tree Expand file tree Collapse file tree 2 files changed +23
-2
lines changed
_unittests/ut_torch_models
onnx_diagnostic/torch_models/hghub Expand file tree Collapse file tree 2 files changed +23
-2
lines changed Original file line number Diff line number Diff line change 11import unittest
22import pandas
3- from onnx_diagnostic .ext_test_case import ExtTestCase , never_test
4- from onnx_diagnostic .torch_models .hghub .hub_api import enumerate_model_list
3+ from onnx_diagnostic .ext_test_case import (
4+ ExtTestCase ,
5+ never_test ,
6+ requires_torch ,
7+ requires_transformers ,
8+ )
9+ from onnx_diagnostic .torch_models .hghub .hub_api import enumerate_model_list , get_task
510
611
712class TestHuggingFaceHub (ExtTestCase ):
13+
14+ @requires_transformers ("4.50" ) # we limit to some versions of the CI
15+ @requires_torch ("2.7" )
816 def test_enumerate_model_list (self ):
917 models = list (enumerate_model_list (2 , verbose = 1 , dump = "test_enumerate_model_list.csv" ))
1018 self .assertEqual (len (models ), 2 )
1119 df = pandas .read_csv ("test_enumerate_model_list.csv" )
1220 self .assertEqual (df .shape , (2 , 11 ))
21+ tasks = [get_task (c ) for c in df .id ]
22+ self .assertEqual (['text-generation' , 'text-generation' ], tasks )
1323
1424 @never_test ()
1525 def test_hf_all_models (self ):
Original file line number Diff line number Diff line change 11from typing import List , Optional , Union
2+ import transformers
23from huggingface_hub import HfApi
34
45
6+ def get_task (model_id : str ) -> str :
7+ """
8+ Returns the task attached to a model id.
9+
10+ :param model_id: model id
11+ :return: task
12+ """
13+ return transformers .pipelines .get_task (model_id )
14+
15+
516def enumerate_model_list (
617 n : int = 50 ,
718 task : Optional [str ] = None ,
You can’t perform that action at this time.
0 commit comments