Skip to content

Commit c699541

Browse files
committed
fix
1 parent b2e18fa commit c699541

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

_doc/conf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
source_suffix = ".rst"
3838
master_doc = "index"
3939
project = "onnx-diagnostic"
40-
copyright = "2023-2024"
40+
copyright = "2025"
4141
author = "Xavier Dupré"
4242
version = __version__
4343
release = __version__
@@ -123,9 +123,11 @@
123123
("py:class", "transformers.cache_utils.DynamicCache"),
124124
("py:class", "transformers.cache_utils.EncoderDecoderCache"),
125125
("py:class", "transformers.cache_utils.MambaCache"),
126+
("py:class", "transformers.configuration_utils.PretrainedConfig"),
126127
("py:func", "torch.export._draft_export.draft_export"),
127128
("py:func", "torch._export.tools.report_exportability"),
128129
("py:meth", "huggingface_hub.HfApi.list_models"),
130+
("py:meth", "transformers.AutoConfig.from_pretrained"),
129131
("py:meth", "transformers.GenerationMixin.generate"),
130132
("py:meth", "unittests.TestCase.subTest"),
131133
]

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ def get_cached_configuration(name: str) -> Optional[transformers.PretrainedConfi
2323
2424
.. runpython::
2525
26-
from onnx_diagnostic.torch_models.hghub.hug_api import _retrieve_cached_configurations
26+
import pprint
27+
from onnx_diagnostic.torch_models.hghub.hub_api import _retrieve_cached_configurations
2728
2829
configs = _retrieve_cached_configurations()
2930
pprint.pprint(sorted(configs))

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def validate_model(
6565
assert not trained, f"trained={trained} not supported yet"
6666
assert not dtype, f"dtype={dtype} not supported yet"
6767
assert not device, f"device={device} not supported yet"
68-
summary = {}
68+
summary: Dict[str, Union[int, float, str]] = {}
6969
if verbose:
7070
print(f"[validate_model] validate model id {model_id!r}")
7171
print("[validate_model] get dummy inputs...")
@@ -75,7 +75,8 @@ def validate_model(
7575
try:
7676
data = get_untrained_model_with_inputs(model_id, verbose=verbose, task=task)
7777
except Exception as e:
78-
summary["ERR_create"] = e
78+
summary["ERR_create"] = str(e)
79+
data["ERR_create"] = e
7980
summary["time_create"] = time.perf_counter() - begin
8081
return summary, {}
8182
else:
@@ -90,10 +91,10 @@ def validate_model(
9091
summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "")
9192
summary["model_id"] = model_id
9293
if verbose:
93-
print(f"[validate_model] task={data["task"]}")
94-
print(f"[validate_model] size={data["size"]}")
95-
print(f"[validate_model] n_weights={data["n_weights"]}")
96-
print(f"[validate_model] n_weights={data["n_weights"]}")
94+
print(f"[validate_model] task={data['task']}")
95+
print(f"[validate_model] size={data['size']}")
96+
print(f"[validate_model] n_weights={data['n_weights']}")
97+
print(f"[validate_model] n_weights={data['n_weights']}")
9798
for k, v in data["inputs"].items():
9899
print(f"[validate_model] +INPUT {k}={string_type(v, with_shape=True)}")
99100
for k, v in data["dynamic_shapes"].items():
@@ -106,7 +107,8 @@ def validate_model(
106107
try:
107108
expected = data["model"](**data["inputs"])
108109
except Exception as e:
109-
summary["ERR_run"] = e
110+
summary["ERR_run"] = str(e)
111+
data["ERR_run"] = e
110112
summary["time_run"] = time.perf_counter() - begin
111113
return summary, data
112114
else:

0 commit comments

Comments
 (0)