Skip to content

Commit 4ae0f2c

Browse files
committed
mypy
1 parent 675a01e commit 4ae0f2c

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.8
55
+++++
66

7+
* :pr:`375`: export a method to onnx in order to export using method generate
78
* :pr:`372`: fix patch on rotary embedding
89
* :pr:`371`: fix make_fake_with_dynamic_dimensions
910

onnx_diagnostic/export/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def __init__(
347347
self._call = (
348348
self._model_to_call if method_name == "forward" else getattr(mod, method_name)
349349
)
350-
self._inputs = []
350+
self._inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
351351
self._convert_after_n_calls = convert_after_n_calls
352352
self._patch_kwargs = patch_kwargs
353353
self._method_src = None

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,10 @@ def guess_dynamic_shapes(self, auto: Union[bool, str] = False) -> DYNAMIC_SHAPES
10001000
msg=lambda name=name: f" failing input {name!r}",
10011001
)
10021002
# reordering
1003-
kwargs = {p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs}
1003+
if kwargs is not None:
1004+
kwargs = {
1005+
p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs
1006+
}
10041007
return tuple(args), kwargs
10051008

10061009
def move_to_kwargs(
@@ -1063,8 +1066,14 @@ def move_to_kwargs(
10631066
f"and kwargs={set(kwargs)}, "
10641067
f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
10651068
)
1066-
kwargs = {p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs}
1067-
kw_dyn = {p: kw_dyn[p] for p in self.forward_ordered_parameter_names if p in kw_dyn}
1069+
if kwargs is not None:
1070+
kwargs = {
1071+
p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs
1072+
}
1073+
if kw_dyn is not None:
1074+
kw_dyn = {
1075+
p: kw_dyn[p] for p in self.forward_ordered_parameter_names if p in kw_dyn
1076+
}
10681077
return args, kwargs, (tuple(), kw_dyn)
10691078

10701079
def validate_inputs_for_export(

0 commit comments

Comments
 (0)