Skip to content

Commit 33b758a

Browse files
committed
mypy
1 parent aaa5372 commit 33b758a

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,9 @@ def serialization_functions(
140140
) -> Dict[type, Callable[[int], bool]]:
141141
"""Returns the list of serialization functions."""
142142

143-
supported_classes = set()
144-
classes = {}
145-
all_functions = {}
143+
supported_classes: Set[type] = set()
144+
classes: Dict[type, Callable[[int], bool]] = {}
145+
all_functions: Dict[type, Optional[str]] = {}
146146

147147
if patch_transformers:
148148
from .serialization.transformers_impl import (

onnx_diagnostic/torch_export_patches/serialization/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ def make_serialization_function_for_dataclass(
1616
``dataclasses.dataclass``.
1717
"""
1818

19-
def flatten_cls(obj: cls) -> Tuple[List[Any], torch.utils._pytree.Context]:
19+
def flatten_cls(obj: cls) -> Tuple[List[Any], torch.utils._pytree.Context]: # type: ignore[valid-type]
2020
"""Serializes a ``%s`` with python objects."""
2121
return list(obj.values()), list(obj.keys())
2222

2323
def flatten_with_keys_cls(
24-
obj: cls,
24+
obj: cls, # type: ignore[valid-type]
2525
) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
2626
"""Serializes a ``%s`` with python objects with keys."""
2727
values, context = list(obj.values()), list(obj.keys())
@@ -31,7 +31,7 @@ def flatten_with_keys_cls(
3131

3232
def unflatten_cls(
3333
values: List[Any], context: torch.utils._pytree.Context, output_type=None
34-
) -> cls:
34+
) -> cls: # type: ignore[valid-type]
3535
"""Restores an instance of ``%s`` from python objects."""
3636
return cls(**dict(zip(context, values)))
3737

onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional
1+
from typing import Dict, Optional, Set
22

33
try:
44
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
@@ -14,15 +14,15 @@
1414
from . import make_serialization_function_for_dataclass
1515

1616

17-
def _make_wrong_registrations() -> Dict[str, Optional[str]]:
18-
res = {}
17+
def _make_wrong_registrations() -> Dict[type, Optional[str]]:
18+
res: Dict[type, Optional[str]] = {}
1919
for c in [UNet2DConditionOutput]:
2020
if c is not None:
2121
res[c] = None
2222
return res
2323

2424

25-
SUPPORTED_DATACLASSES = set()
25+
SUPPORTED_DATACLASSES: Set[type] = set()
2626
WRONG_REGISTRATIONS = _make_wrong_registrations()
2727

2828

onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, List, Tuple
1+
from typing import Any, List, Set, Tuple
22
import torch
33
import transformers
44
from transformers.cache_utils import (
@@ -13,7 +13,7 @@
1313
from . import make_serialization_function_for_dataclass
1414

1515

16-
SUPPORTED_DATACLASSES = set()
16+
SUPPORTED_DATACLASSES: Set[type] = set()
1717
WRONG_REGISTRATIONS = {
1818
DynamicCache: "4.50",
1919
BaseModelOutput: None,

0 commit comments

Comments
 (0)