Skip to content

Commit 3625965

Browse files
committed
mypy
1 parent e835791 commit 3625965

File tree

3 files changed

+63
-20
lines changed

3 files changed

+63
-20
lines changed

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,41 @@ def forward(self, x):
107107
)
108108
self.assertEqual(len(results), 6)
109109

110+
@hide_stdout()
111+
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
112+
def test_ep_onnx_sync_a_verbose1(self):
113+
class Model(self.torch.nn.Module):
114+
def forward(self, x):
115+
ry = x.abs()
116+
rz = ry.exp()
117+
rw = rz + 1
118+
ru = rw.log() + rw
119+
return ru
120+
121+
x = self.torch.randn((5, 4))
122+
Model()(x)
123+
ep = self.torch.export.export(
124+
Model(), (x,), dynamic_shapes=({0: self.torch.export.Dim("batch")},)
125+
)
126+
onx = to_onnx(
127+
ep,
128+
(x,),
129+
dynamic_shapes=({0: self.torch.export.Dim("batch")},),
130+
exporter="onnx-dynamo",
131+
).model_proto
132+
results = list(
133+
run_aligned(
134+
ep,
135+
onx,
136+
args=(x,),
137+
run_cls=ExtendedReferenceEvaluator,
138+
atol=1e-5,
139+
rtol=1e-5,
140+
verbose=1,
141+
),
142+
)
143+
self.assertEqual(len(results), 6)
144+
110145
@hide_stdout()
111146
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
112147
def test_sbs_dict(self):

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def _loop_cmp(
4343
onnx_name: str,
4444
torch_result: torch.Tensor,
4545
verbose: int,
46-
atol: float,
47-
rtol: float,
46+
atol: Optional[float],
47+
rtol: Optional[float],
4848
i_torch: int,
4949
i_onnx: int,
5050
str_kws: Dict[str, bool],
@@ -140,6 +140,7 @@ def _loop_onnx_node(
140140
f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
141141
)
142142
elif verbose == 1:
143+
loop.update(i_torch + i_onnx)
143144
loop.set_description(
144145
f"ep {i_torch}/{len(ep_graph_nodes)} nx {i_onnx}/{len(onx.graph.node)} "
145146
f"{status.to_str()}"
@@ -435,7 +436,7 @@ def _preparation_with_onnx_model(
435436
t = torch_results[init.name]
436437
torch_names_to_onnx_names[init.name] = init.name
437438
elif init.name not in skip_onnx_name and init.name in rev_init_aliases:
438-
new_names = [ # type: ignore[assignment]
439+
new_names = [
439440
k
440441
for k in rev_init_aliases[init.name]
441442
if k in torch_results and k not in skip_mapping_torch_onnx
@@ -657,8 +658,8 @@ def forward(self, x):
657658
-v 1 --atol=0.1 --rtol=1
658659
"""
659660
assert callable(run_cls), f"run_cls={run_cls} not a callable"
660-
already_yielded = {} # type: ignore[var-annotated]
661-
reset_names = set(reset_names) if reset_names else set() # type: ignore[assignment]
661+
already_yielded = {}
662+
reset_names = set(reset_names) if reset_names else set()
662663
str_kws = dict(with_shape=True, with_device=True)
663664
has_cuda = any(
664665
(isinstance(t, torch.Tensor) and t.is_cuda)
@@ -777,25 +778,28 @@ def forward(self, x):
777778
if verbose == 1:
778779
import tqdm
779780

780-
loop = tqdm.tqdm(list(enumerate(ep_graph_nodes)))
781+
loop = tqdm.tqdm(total=len(ep_graph_nodes) + len(onx.graph.node))
781782
else:
782-
loop = list(enumerate(ep_graph_nodes))
783+
loop = None
783784

784785
already_run: Set[int] = set()
785786
ep_durations = {}
786787
status = StatusRunAligned()
787-
for i, node in loop:
788+
for i_torch, node in enumerate(ep_graph_nodes):
788789
if verbose > 1:
789790
if node.op == "call_function":
790791
print(
791-
f"[run_aligned] run ep.graph.nodes[{i}]: "
792+
f"[run_aligned] run ep.graph.nodes[{i_torch}]: "
792793
f"{node.op}[{node.target}] -> {node.name!r}"
793794
)
794795
else:
795-
print(f"[run_aligned] run ep.graph.nodes[{i}]: {node.op} -> {node.name!r}")
796+
print(
797+
f"[run_aligned] run ep.graph.nodes[{i_torch}]: {node.op} -> {node.name!r}"
798+
)
796799
elif verbose == 1:
800+
loop.update(i_torch + last_position)
797801
loop.set_description(
798-
f"ep {i}/{len(ep_graph_nodes)} nx {last_position}/{len(onx.graph.node)} "
802+
f"ep {i_torch}/{len(ep_graph_nodes)} nx {last_position}/{len(onx.graph.node)} "
799803
f"{status.to_str()}"
800804
)
801805

@@ -816,7 +820,7 @@ def forward(self, x):
816820
print(f"[run_aligned-ep] =ags: {node.name}={string_type(t, **str_kws)}")
817821
# Otherwise, it is an input.
818822
record = RunAlignedRecord(
819-
ep_id_node=i,
823+
ep_id_node=i_torch,
820824
onnx_id_node=-1,
821825
ep_name=node.name,
822826
onnx_name=torch_names_to_onnx_names[node.name],
@@ -848,7 +852,7 @@ def forward(self, x):
848852
f"{node.name}={string_type(t, **str_kws)}"
849853
)
850854
record = RunAlignedRecord(
851-
ep_id_node=i,
855+
ep_id_node=i_torch,
852856
onnx_id_node=-1,
853857
ep_name=node.name,
854858
onnx_name=torch_names_to_onnx_names[node.name],
@@ -874,7 +878,7 @@ def forward(self, x):
874878
f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}"
875879
)
876880
yield RunAlignedRecord(
877-
ep_id_node=i,
881+
ep_id_node=i_torch,
878882
ep_name=node.name,
879883
ep_target="placeholder",
880884
ep_shape_type=string_type(t, **str_kws),
@@ -886,7 +890,7 @@ def forward(self, x):
886890
begin = time.perf_counter()
887891
new_outputs = run_fx_node(node, args, kwargs)
888892
duration = time.perf_counter() - begin
889-
ep_durations[i] = duration
893+
ep_durations[i_torch] = duration
890894
if isinstance(new_outputs, (torch.Tensor, int, float, list, tuple)):
891895
new_outputs = (new_outputs,)
892896

@@ -906,7 +910,7 @@ def forward(self, x):
906910
if "onnx" in positions[n]:
907911
max_pos = max(max_pos, positions[n]["onnx"])
908912
if "fx" in positions[n]:
909-
if positions[n]["fx"] > i:
913+
if positions[n]["fx"] > i_torch:
910914
max_pos = -2
911915
break
912916
if max_pos == -2:
@@ -923,7 +927,7 @@ def forward(self, x):
923927
ep_behind = False
924928
for iname in node.output:
925929
if iname in positions and "fx" in positions[iname]:
926-
if positions[iname]["fx"] > i:
930+
if positions[iname]["fx"] > i_torch:
927931
ep_behind = True
928932
break
929933
if ep_behind:
@@ -937,7 +941,7 @@ def forward(self, x):
937941
torch_results,
938942
ep_durations,
939943
use_tensor,
940-
i,
944+
i_torch,
941945
i_onnx,
942946
name_to_ep_node,
943947
run_cls_kwargs,
@@ -978,7 +982,7 @@ def forward(self, x):
978982
torch_results,
979983
ep_durations,
980984
use_tensor,
981-
i,
985+
i_torch,
982986
i_onnx,
983987
name_to_ep_node,
984988
run_cls_kwargs,

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,11 @@ disable_error_code = ["arg-type", "assignment", "attr-defined", "index", "misc",
108108

109109
[[tool.mypy.overrides]]
110110
module = ["onnx_diagnostic.torch_models.*"]
111-
disable_error_code = ["attr-defined", "call-overload", "operator"]
111+
disable_error_code = ["assignment", "attr-defined", "call-overload", "operator"]
112+
113+
[[tool.mypy.overrides]]
114+
module = ["onnx_diagnostic.torch_onnx.sbs"]
115+
disable_error_code = ["var-annotated"]
112116

113117
[tool.ruff]
114118

0 commit comments

Comments
 (0)