Skip to content

Commit 1212c69

Browse files
committed
fix unittest
1 parent 2e11044 commit 1212c69

File tree

6 files changed

+30
-11
lines changed

6 files changed

+30
-11
lines changed

_unittests/ut_helpers/test_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ def test_flatten_encoder_decoder_cache(self):
584584
self.assertIn("EncoderDecoderCache", s)
585585

586586
def test_string_typeçconfig(self):
587-
conf = get_pretrained_config("microsoft/phi-2")
587+
conf = get_pretrained_config("microsoft/phi-2", use_only_preinstalled=True)
588588
s = string_type(conf)
589589
self.assertStartsWith("PhiConfig(**{", s)
590590

_unittests/ut_torch_models/test_hghub_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,16 @@ def test_task_from_id_long(self):
7272
@requires_torch("2.7")
7373
@hide_stdout()
7474
def test_get_pretrained_config(self):
75-
conf = get_pretrained_config("microsoft/phi-2")
75+
conf = get_pretrained_config("microsoft/phi-2", use_only_preinstalled=True)
7676
self.assertNotEmpty(conf)
7777

7878
@requires_transformers("4.50")
7979
@requires_torch("2.7")
8080
@hide_stdout()
8181
def test_get_pretrained_config_options(self):
82-
conf = get_pretrained_config("microsoft/phi-2", num_key_value_heads=16)
82+
conf = get_pretrained_config(
83+
"microsoft/phi-2", num_key_value_heads=16, use_only_preinstalled=True
84+
)
8385
self.assertNotEmpty(conf)
8486
self.assertEqual(conf.num_key_value_heads, 16)
8587

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,11 @@ def _diff(c1, c2):
129129
try:
130130
model(**inputs)
131131
except Exception as e:
132-
diff = _diff(get_pretrained_config(mid), data["configuration"])
132+
cf = get_pretrained_config(mid, use_only_preinstalled=True)
133+
diff = _diff(cf, data["configuration"])
133134
raise AssertionError(
134135
f"Computation failed due to {e}.\n--- pretrained\n"
135-
f"{pprint.pformat(get_pretrained_config(mid))}\n"
136-
f"--- modified\n{data['configuration']}\n"
136+
f"{pprint.pformat(cf)}\n--- modified\n{data['configuration']}\n"
137137
f"--- diff\n{diff}"
138138
) from e
139139
# different expected value for different version of transformers

_unittests/ut_torch_models/test_validate_whole_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def test_validate_phi35_mini_instruct(self):
283283
@ignore_warnings(FutureWarning)
284284
@requires_transformers("4.51")
285285
def test_validate_phi35_4k_mini_instruct(self):
286-
mid = "microsoft/Phi-3.5-mini-4k-instruct"
286+
mid = "microsoft/Phi-3-mini-4k-instruct"
287287
summary, data = validate_model(
288288
mid,
289289
do_run=True,

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import functools
33
import json
44
import os
5+
import pprint
56
from typing import Any, Dict, List, Optional, Union
67
import transformers
78
from huggingface_hub import HfApi, model_info, hf_hub_download
@@ -33,10 +34,14 @@ def _retrieve_cached_configurations() -> Dict[str, transformers.PretrainedConfig
3334
return res
3435

3536

36-
def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.PretrainedConfig]:
37+
def get_cached_configuration(
38+
name: str, exc: bool = False, **kwargs
39+
) -> Optional[transformers.PretrainedConfig]:
3740
"""
3841
Returns cached configuration to avoid having to many accesses to internet.
3942
It returns None if not Cache. The list of cached models follows.
43+
If *exc* is True or if environment variable ``NOHTTP`` is defined,
44+
the function raises an exception if *name* is not found.
4045
4146
.. runpython::
4247
@@ -54,8 +59,9 @@ def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.Pretr
5459
conf = copy.deepcopy(conf)
5560
update_config(conf, kwargs)
5661
return conf
57-
if os.environ.get("NOHTTP", ""):
58-
raise AssertionError(f"Unable to find {name!r} in {sorted(cached)}")
62+
assert not exc and not os.environ.get(
63+
"NOHTTP", ""
64+
), f"Unable to find {name!r} in {pprint.pformat(sorted(cached))}"
5965
return None
6066

6167

@@ -64,6 +70,7 @@ def get_pretrained_config(
6470
trust_remote_code: bool = True,
6571
use_preinstalled: bool = True,
6672
subfolder: Optional[str] = None,
73+
use_only_preinstalled: bool = False,
6774
**kwargs,
6875
) -> Any:
6976
"""
@@ -77,13 +84,20 @@ def get_pretrained_config(
7784
:func:`get_cached_configuration`, the cached list is mostly for
7885
unit tests
7986
:param subfolder: subfolder for the given model id
87+
:param use_only_preinstalled: if True, raises an exception if not preinstalled
8088
:param kwargs: additional kwargs
8189
:return: a configuration
8290
"""
8391
if use_preinstalled:
84-
conf = get_cached_configuration(model_id, subfolder=subfolder, **kwargs)
92+
conf = get_cached_configuration(
93+
model_id, exc=use_only_preinstalled, subfolder=subfolder, **kwargs
94+
)
8595
if conf is not None:
8696
return conf
97+
assert not use_only_preinstalled, (
98+
f"Inconsistencies: use_only_preinstalled={use_only_preinstalled}, "
99+
f"use_preinstalled={use_preinstalled!r}"
100+
)
87101
if subfolder:
88102
try:
89103
return transformers.AutoConfig.from_pretrained(

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def get_untrained_model_with_inputs(
2626
use_preinstalled: bool = True,
2727
add_second_input: bool = False,
2828
subfolder: Optional[str] = None,
29+
use_only_preinstalled: bool = True,
2930
) -> Dict[str, Any]:
3031
"""
3132
Gets a non initialized model similar to the original model
@@ -46,6 +47,7 @@ def get_untrained_model_with_inputs(
4647
:param add_second_input: provides a second inputs to check a model
4748
supports different shapes
4849
:param subfolder: subfolder to use for this model id
50+
:param use_only_preinstalled: use only preinstalled version
4951
:return: dictionary with a model, inputs, dynamic shapes, and the configuration,
5052
some necessary rewriting as well
5153
@@ -74,6 +76,7 @@ def get_untrained_model_with_inputs(
7476
config = get_pretrained_config(
7577
model_id,
7678
use_preinstalled=use_preinstalled,
79+
use_only_preinstalled=use_only_preinstalled,
7780
subfolder=subfolder,
7881
**(model_kwargs or {}),
7982
)

0 commit comments

Comments
 (0)