Skip to content

Commit ce11d89

Browse files
committed
Download python files
1 parent 576e12d commit ce11d89

File tree

3 files changed

+81
-3
lines changed

3 files changed

+81
-3
lines changed

_unittests/ut_torch_models/test_hghub_api.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import os
12
import unittest
23
import pandas
4+
import transformers
35
from onnx_diagnostic.ext_test_case import (
46
ExtTestCase,
57
hide_stdout,
@@ -13,6 +15,7 @@
1315
enumerate_model_list,
1416
get_model_info,
1517
get_pretrained_config,
18+
download_code_modelid,
1619
task_from_id,
1720
task_from_arch,
1821
task_from_tags,
@@ -147,6 +150,24 @@ def test__ccached_config_64(self):
147150
conf = _ccached_hf_internal_testing_tiny_random_beitforimageclassification()
148151
self.assertEqual(conf.auxiliary_channels, 256)
149152

153+
@requires_transformers("4.50")
@requires_torch("2.7")
@ignore_errors(OSError)  # connectivity issues
@hide_stdout()
def test_download_code_modelid(self):
    # Downloads the python files shipped with the model repository and
    # registers their folder in sys.path.
    downloaded = download_code_modelid(
        "microsoft/Phi-3.5-MoE-instruct", verbose=1, add_path_to_sys_path=True
    )

    # Every returned path must point to a file present on disk.
    self.assertTrue(all(map(os.path.exists, downloaded)))

    # Only the file names matter, the cache location depends on the machine.
    basenames = [os.path.basename(path) for path in downloaded]
    expected = ["configuration_phimoe.py", "modeling_phimoe.py", "sample_finetune.py"]
    self.assertEqual(expected, basenames)

    # The downloaded code must be loadable through transformers'
    # dynamic module mechanism, from the folder holding the first file.
    code_folder = os.path.dirname(downloaded[0])
    loaded = transformers.dynamic_module_utils.get_class_from_dynamic_module(
        "modeling_phimoe.Phi4MMImageEmbedding",
        pretrained_model_name_or_path=code_folder,
    )
    self.assertEqual(loaded.__name__, "Phi4MMImageEmbedding")
170+
150171

151172
if __name__ == "__main__":
152173
unittest.main(verbosity=2)

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import json
44
import os
55
import pprint
6+
import sys
67
from typing import Any, Dict, List, Optional, Union
78
import transformers
8-
from huggingface_hub import HfApi, model_info, hf_hub_download
9+
from huggingface_hub import HfApi, model_info, hf_hub_download, list_repo_files
910
from ...helpers.config_helper import update_config
1011
from . import hub_data_cached_configs
1112
from .hub_data import __date__, __data_tasks__, load_architecture_task, __data_arch_values__
@@ -327,3 +328,43 @@ def enumerate_model_list(
327328
n -= 1
328329
if n == 0:
329330
break
331+
332+
333+
def download_code_modelid(
    model_id: str, verbose: int = 0, add_path_to_sys_path: bool = True
) -> List[str]:
    """
    Downloads the python files stored in the repository of a given model id.

    :param model_id: model id
    :param verbose: verbosity, any non null value prints the progress
    :param add_path_to_sys_path: if True, every folder which received a
        downloaded file gets an empty ``__init__.py`` (only if missing)
        and is inserted at the beginning of ``sys.path``
        unless it is already part of it
    :return: list of downloaded files, in the same order as the python
        files appear in the repository listing
    """
    if verbose:
        print(f"[download_code_modelid] retrieve file list for {model_id!r}")
    files = list_repo_files(model_id)
    pyfiles = [name for name in files if os.path.splitext(name)[-1] == ".py"]
    if verbose:
        print(f"[download_code_modelid] python files {pyfiles}")
    absfiles = []
    # A dictionary plays the role of an ordered set: iterating over a set of
    # strings follows a hash order which changes from one process to another,
    # the order of the insertions into sys.path would not be reproducible.
    folders: Dict[str, None] = {}
    for i, name in enumerate(pyfiles):
        if verbose:
            print(f"[download_code_modelid] download file {i+1}/{len(pyfiles)}: {name!r}")
        r = hf_hub_download(repo_id=model_id, filename=name)
        folders[os.path.dirname(r)] = None
        absfiles.append(r)
    if add_path_to_sys_path:
        for p in folders:
            # The folder is created as a package if it is not already one.
            init = os.path.join(p, "__init__.py")
            if not os.path.exists(init):
                with open(init, "w"):
                    pass
            if p in sys.path:
                continue
            if verbose:
                print(f"[download_code_modelid] add {p!r} to 'sys.path'")
            sys.path.insert(0, p)
    return absfiles

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import transformers
77
from ...helpers.config_helper import update_config
88
from ...tasks import reduce_model_config, random_input_kwargs
9-
from .hub_api import task_from_arch, task_from_id, get_pretrained_config
9+
from .hub_api import task_from_arch, task_from_id, get_pretrained_config, download_code_modelid
1010

1111

1212
def _code_needing_rewriting(model: Any) -> Any:
@@ -149,7 +149,23 @@ def get_untrained_model_with_inputs(
149149
model = transformers.AutoModel.from_pretrained(model_id, **mkwargs)
150150
else:
151151
if archs is not None:
152-
model = getattr(transformers, archs[0])(config)
152+
try:
153+
model = getattr(transformers, archs[0])(config)
154+
except AttributeError as e:
155+
# The code of the models is not in transformers but in the
156+
# repository of the model. We need to download it.
157+
pyfiles = download_code_modelid(model_id, verbose=verbose)
158+
if pyfiles:
159+
cls = transformers.dynamic_module_utils.get_class_from_dynamic_module(
160+
archs[0], pretrained_model_name_or_path=os.path.split(pyfiles[0])[0]
161+
)
162+
model = cls(config)
163+
else:
164+
raise AttributeError(
165+
f"Unable to find class 'tranformers.{archs[0]}'. "
166+
f"The code needs to be downloaded, config="
167+
f"\n{pprint.pformat(config)}."
168+
) from e
153169
else:
154170
assert same_as_pretrained and use_pretrained, (
155171
f"Model {model_id!r} cannot be built, the model cannot be built. "

0 commit comments

Comments
 (0)