Skip to content

Commit b86c7bd

Browse files
committed
fix type
1 parent 85960d6 commit b86c7bd

File tree

2 files changed

+9
-12
lines changed

2 files changed

+9
-12
lines changed

onnx_diagnostic/helpers/mini_onnx_builder.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def __init__(self, target_opset: int = 18, ir_version: int = 10, sep: str = "___
9494
self.nodes: List[NodeProto] = []
9595
self.opsets = {"": target_opset}
9696
self.ir_version = ir_version
97-
self.torch = torch
9897
self.sep = sep
9998

10099
def append_output_initializer(
@@ -163,7 +162,7 @@ def append_output_sequence(
163162
)
164163
else:
165164
assert all(
166-
isinstance(t, (np.ndarray, self.torch.Tensor)) for t in tensors
165+
isinstance(t, (np.ndarray, torch.Tensor)) for t in tensors
167166
), f"Nested sequences are not supported, types are {[type(t) for t in tensors]}"
168167
names = []
169168
for i, t in enumerate(tensors):
@@ -197,9 +196,7 @@ def append_output_dict(
197196
self.append_output_initializer(f"{name}{self.sep}keys", np.array(keys, dtype=np.str_))
198197
self.append_output_sequence(f"{name}{self.sep}values", values)
199198

200-
def _build_initializers(
201-
self, switch_low_high: bool
202-
) -> Tuple[List[TensorProto], Dict[str, TensorProto]]:
199+
def _build_initializers(self, switch_low_high: bool) -> List[TensorProto]:
203200
"""
204201
Builds initializers.
205202
@@ -209,7 +206,7 @@ def _build_initializers(
209206
init_dict = self.initializers_dict
210207
if switch_low_high:
211208
# Let's try to minimize the time.
212-
initializer = []
209+
initializer: List[TensorProto] = []
213210
for k, v in init_dict.items():
214211
if isinstance(v, TensorProto):
215212
initializer.append(v)
@@ -245,7 +242,7 @@ def _build_initializers(
245242
continue
246243
else:
247244
assert isinstance(
248-
v, self.torch.Tensor
245+
v, torch.Tensor
249246
), f"tensor {k!r} has un unexpected type {type(v)}"
250247
assert "FakeTensor" not in str(
251248
type(v)
@@ -272,9 +269,9 @@ def _build_initializers(
272269
if isinstance(v, TensorProto):
273270
res.append(v)
274271
continue
275-
if isinstance(v, self.torch.Tensor):
272+
if isinstance(v, torch.Tensor):
276273
# no string tensor
277-
t = self.from_array(v, name=k)
274+
t = proto_from_array(v, name=k)
278275
res.append(t)
279276
continue
280277
if isinstance(v, np.ndarray):
@@ -444,7 +441,7 @@ def _unflatten(
444441
return pos + 1, torch.from_numpy(outputs[pos]).to(device)
445442
raise AssertionError(f"Unexpected name {name!r} in {names}")
446443

447-
res = []
444+
res: List[Any] = []
448445
while True:
449446
assert pos < len(names), f"Something went wrong with names={names!r}\nres={res!r}"
450447
name = names[pos]

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import contextlib
22
from collections.abc import Iterable
3-
from typing import Any, Callable, List, Optional, Tuple, Union
3+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
44
import numpy as np
55
import onnx
66
import torch
@@ -77,7 +77,7 @@ def steal_forward(
7777
if not isinstance(model, list):
7878
model = [model]
7979
keep_model_forward = {}
80-
storage = {} if dump_file else None
80+
storage: Optional[Dict[Any, Any]] = {} if dump_file else None
8181
for mt in model:
8282
name, m = mt if isinstance(mt, tuple) else ("", mt)
8383
keep_model_forward[id(m)] = (m, m.forward)

0 commit comments

Comments
 (0)