Skip to content

Commit c578deb

Browse files
committed
fix mypy
1 parent d167935 commit c578deb

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def make_dynamic_cache(
103103
)
104104
print(string_type(past_key_values, with_shape=True))
105105
"""
106-
cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))
106+
cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) # type: ignore
107107
for i, (key, value) in enumerate(key_value_pairs):
108108
cache.update(key, value, i)
109109
return cache

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import functools
2-
from typing import Dict, List, Optional, Union
2+
from typing import Any, Dict, List, Optional, Union
33
import transformers
44
from huggingface_hub import HfApi, model_info
55
from . import hub_data_cached_configs
@@ -38,7 +38,7 @@ def get_cached_configuration(name: str) -> Optional[transformers.PretrainedConfi
3838

3939
def get_pretrained_config(
4040
model_id: str, trust_remote_code: bool = True, use_cached: bool = True
41-
) -> str:
41+
) -> Any:
4242
"""
4343
Returns the config for a model_id.
4444
@@ -49,6 +49,7 @@ def get_pretrained_config(
4949
accessing the network, if available, it is returned by
5050
:func:`get_cached_configuration`, the cached list is mostly for
5151
unit tests
52+
:return: a configuration
5253
"""
5354
if use_cached:
5455
conf = get_cached_configuration(model_id)

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,15 @@ def filter_inputs(
111111

112112
def _make_folder_name(
113113
model_id: str,
114-
exporter: str,
114+
exporter: Optional[str],
115115
optimization: Optional[str] = None,
116116
dtype: Optional[Union[str, torch.dtype]] = None,
117117
device: Optional[Union[str, torch.device]] = None,
118118
) -> str:
119119
"Creates a filename unique based on the given options."
120-
els = [model_id.replace("/", "_"), exporter]
120+
els = [model_id.replace("/", "_")]
121+
if exporter:
122+
els.append(exporter)
121123
if optimization:
122124
els.append(optimization)
123125
if dtype is not None and dtype:

0 commit comments

Comments
 (0)