Skip to content

Commit c5d5466

Browse files
committed
Fix importing files
1 parent 19d350c commit c5d5466

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

_unittests/ut_torch_models/test_hghub_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,14 +164,14 @@ def test_download_code_modelid(self):
164164
)
165165
try:
166166
cls = transformers.dynamic_module_utils.get_class_from_dynamic_module(
167-
"modeling_phimoe.Phi4MMImageEmbedding",
167+
"modeling_phimoe.PhiMoERotaryEmbedding",
168168
pretrained_model_name_or_path=os.path.split(files[0])[0],
169169
)
170170
except ImportError as e:
171171
if "flash_attn" in str(e):
172172
raise unittest.SkipTest("missing package {e}")
173173
raise
174-
self.assertEqual(cls.__name__, "Phi4MMImageEmbedding")
174+
self.assertEqual(cls.__name__, "PhiMoERotaryEmbedding")
175175

176176

177177
if __name__ == "__main__":

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,26 @@ def get_untrained_model_with_inputs(
156156
# repository of the model. We need to download it.
157157
pyfiles = download_code_modelid(model_id, verbose=verbose)
158158
if pyfiles:
159+
if "." in archs[0]:
160+
cls_name = archs[0]
161+
else:
162+
modeling = [_ for _ in pyfiles if "/modeling_" in _]
163+
assert len(modeling) == 1, (
164+
f"Unable to guess the main file implemented class {archs[0]!r} "
165+
f"from {pyfiles}, found={modeling}."
166+
)
167+
last_name = os.path.splitext(os.path.split(modeling[0])[-1])[0]
168+
cls_name = f"{last_name}.{archs[0]}"
169+
if verbose:
170+
print(
171+
f"[get_untrained_model_with_inputs] custom code for {cls_name!r}"
172+
)
173+
print(
174+
f"[get_untrained_model_with_inputs] from folder "
175+
f"{os.path.split(pyfiles[0])[0]!r}"
176+
)
159177
cls = transformers.dynamic_module_utils.get_class_from_dynamic_module(
160-
archs[0], pretrained_model_name_or_path=os.path.split(pyfiles[0])[0]
178+
cls_name, pretrained_model_name_or_path=os.path.split(pyfiles[0])[0]
161179
)
162180
model = cls(config)
163181
else:

0 commit comments

Comments
 (0)