Skip to content

Commit 4aa5db4

Browse files
authored
Download python files for huggingface models (#159)
* Download python files * doc * api * Fix importing files
1 parent 576e12d commit 4aa5db4

File tree

4 files changed

+106
-3
lines changed

4 files changed

+106
-3
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.7.1
55
+++++
66

7+
* :pr:`159`: adds support for models with custom code hosted on Hugging Face
8+
* :pr:`158`: fix uses of pretrained version
79
* :pr:`156`, :pr:`157`: add plots and other options to deal with the unpredictable
810
* :pr:`155`: better aggregation of historical data
911
* :pr:`151`, :pr:`153`: adds command line ``agg``, class CubeLogsPerformance to produce timeseries

_unittests/ut_torch_models/test_hghub_api.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import os
12
import unittest
23
import pandas
4+
import transformers
35
from onnx_diagnostic.ext_test_case import (
46
ExtTestCase,
57
hide_stdout,
@@ -13,6 +15,7 @@
1315
enumerate_model_list,
1416
get_model_info,
1517
get_pretrained_config,
18+
download_code_modelid,
1619
task_from_id,
1720
task_from_arch,
1821
task_from_tags,
@@ -147,6 +150,29 @@ def test__ccached_config_64(self):
147150
conf = _ccached_hf_internal_testing_tiny_random_beitforimageclassification()
148151
self.assertEqual(conf.auxiliary_channels, 256)
149152

153+
@requires_transformers("4.50")
@requires_torch("2.7")
@ignore_errors(OSError)  # connectivity issues
@hide_stdout()
def test_download_code_modelid(self):
    """
    Checks that ``download_code_modelid`` fetches the python files of a
    model repository and that a class can be loaded from the download folder.

    The test is skipped when loading the class fails with an ``ImportError``
    mentioning ``flash_attn``; any other ``ImportError`` is re-raised.
    """
    model_id = "microsoft/Phi-3.5-MoE-instruct"
    files = download_code_modelid(model_id, verbose=1, add_path_to_sys_path=True)
    # Every returned path must point to a file present on disk.
    self.assertTrue(all(os.path.exists(f) for f in files))
    pyf = [os.path.split(name)[-1] for name in files]
    self.assertEqual(
        ["configuration_phimoe.py", "modeling_phimoe.py", "sample_finetune.py"], pyf
    )
    try:
        # The folder of the first downloaded file is used as the local
        # "pretrained" location from which the module is imported.
        cls = transformers.dynamic_module_utils.get_class_from_dynamic_module(
            "modeling_phimoe.PhiMoERotaryEmbedding",
            pretrained_model_name_or_path=os.path.split(files[0])[0],
        )
    except ImportError as e:
        if "flash_attn" in str(e):
            # f-string so the skip reason shows the actual import error
            # instead of the literal text "{e}".
            raise unittest.SkipTest(f"missing package {e}")
        raise
    self.assertEqual(cls.__name__, "PhiMoERotaryEmbedding")
175+
150176

151177
if __name__ == "__main__":
152178
unittest.main(verbosity=2)

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import json
44
import os
55
import pprint
6+
import sys
67
from typing import Any, Dict, List, Optional, Union
78
import transformers
8-
from huggingface_hub import HfApi, model_info, hf_hub_download
9+
from huggingface_hub import HfApi, model_info, hf_hub_download, list_repo_files
910
from ...helpers.config_helper import update_config
1011
from . import hub_data_cached_configs
1112
from .hub_data import __date__, __data_tasks__, load_architecture_task, __data_arch_values__
@@ -327,3 +328,43 @@ def enumerate_model_list(
327328
n -= 1
328329
if n == 0:
329330
break
331+
332+
333+
def download_code_modelid(
334+
model_id: str, verbose: int = 0, add_path_to_sys_path: bool = True
335+
) -> List[str]:
336+
"""
337+
Downloads the code for a given model id.
338+
339+
:param model_id: model id
340+
:param verbose: verbosity
341+
:param add_path_to_sys_path: add folder where the files are downloaded to sys.path
342+
:return: list of downloaded files
343+
"""
344+
if verbose:
345+
print(f"[download_code_modelid] retrieve file list for {model_id!r}")
346+
files = list_repo_files(model_id)
347+
pyfiles = [name for name in files if os.path.splitext(name)[-1] == ".py"]
348+
if verbose:
349+
print(f"[download_code_modelid] python files {pyfiles}")
350+
absfiles = []
351+
paths = set()
352+
for i, name in enumerate(pyfiles):
353+
if verbose:
354+
print(f"[download_code_modelid] download file {i+1}/{len(pyfiles)}: {name!r}")
355+
r = hf_hub_download(repo_id=model_id, filename=name)
356+
p = os.path.split(r)[0]
357+
paths.add(p)
358+
absfiles.append(r)
359+
if add_path_to_sys_path:
360+
for p in paths:
361+
init = os.path.join(p, "__init__.py")
362+
if not os.path.exists(init):
363+
with open(init, "w"):
364+
pass
365+
if p in sys.path:
366+
continue
367+
if verbose:
368+
print(f"[download_code_modelid] add {p!r} to 'sys.path'")
369+
sys.path.insert(0, p)
370+
return absfiles

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import transformers
77
from ...helpers.config_helper import update_config
88
from ...tasks import reduce_model_config, random_input_kwargs
9-
from .hub_api import task_from_arch, task_from_id, get_pretrained_config
9+
from .hub_api import task_from_arch, task_from_id, get_pretrained_config, download_code_modelid
1010

1111

1212
def _code_needing_rewriting(model: Any) -> Any:
@@ -149,7 +149,41 @@ def get_untrained_model_with_inputs(
149149
model = transformers.AutoModel.from_pretrained(model_id, **mkwargs)
150150
else:
151151
if archs is not None:
152-
model = getattr(transformers, archs[0])(config)
152+
try:
153+
model = getattr(transformers, archs[0])(config)
154+
except AttributeError as e:
155+
# The code of the models is not in transformers but in the
156+
# repository of the model. We need to download it.
157+
pyfiles = download_code_modelid(model_id, verbose=verbose)
158+
if pyfiles:
159+
if "." in archs[0]:
160+
cls_name = archs[0]
161+
else:
162+
modeling = [_ for _ in pyfiles if "/modeling_" in _]
163+
assert len(modeling) == 1, (
164+
f"Unable to guess the main file implemented class {archs[0]!r} "
165+
f"from {pyfiles}, found={modeling}."
166+
)
167+
last_name = os.path.splitext(os.path.split(modeling[0])[-1])[0]
168+
cls_name = f"{last_name}.{archs[0]}"
169+
if verbose:
170+
print(
171+
f"[get_untrained_model_with_inputs] custom code for {cls_name!r}"
172+
)
173+
print(
174+
f"[get_untrained_model_with_inputs] from folder "
175+
f"{os.path.split(pyfiles[0])[0]!r}"
176+
)
177+
cls = transformers.dynamic_module_utils.get_class_from_dynamic_module(
178+
cls_name, pretrained_model_name_or_path=os.path.split(pyfiles[0])[0]
179+
)
180+
model = cls(config)
181+
else:
182+
raise AttributeError(
183+
f"Unable to find class 'tranformers.{archs[0]}'. "
184+
f"The code needs to be downloaded, config="
185+
f"\n{pprint.pformat(config)}."
186+
) from e
153187
else:
154188
assert same_as_pretrained and use_pretrained, (
155189
f"Model {model_id!r} cannot be built, the model cannot be built. "

0 commit comments

Comments
 (0)