Skip to content

Commit 5623d30

Browse files
AutinMitra authored and pytorchmergebot committed
[Minimizer] Gracefully exit when there is no discrepancy in block mode (pytorch#154076)
Summary: Previously, when there is no discrepancy in results for block mode, net_min_base will throw an out-of-bounds (OOB) error. This occurs because _block_traverse_impl returns an out-of-bounds index after exhausting subgraphs all the way down to a single node. There is also an issue where we may get an unsound subgraph (i.e. mark an earlier node as the "end" even if the correct end is later). This is due to an incorrect check (start_idx == mid): there can still be two candidate values left when the program prematurely returns. Test Plan: Buck UI: https://www.internalfb.com/buck2/52524c26-ace5-4593-8a4b-843a54eb206a Test UI: https://www.internalfb.com/intern/testinfra/testrun/3096224973363310 Network: Up: 0B Down: 15MiB (reSessionID-cd404e97-395f-49fc-8381-373e90a1378f) Executing actions. Remaining 0/1 Command: test. Time elapsed: 53.7s Tests finished: Pass 7. Fail 0. Fatal 0. Skip 0. Build failure 0 Differential Revision: D75143242 Pull Request resolved: pytorch#154076 Approved by: https://github.com/jfix71
1 parent 8342b93 commit 5623d30

File tree

2 files changed

+146
-21
lines changed

2 files changed

+146
-21
lines changed

test/fx/test_net_min_base.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Owner(s): ["module: fx"]
2+
3+
from unittest import mock
4+
5+
import torch
6+
from torch.fx.passes.net_min_base import (
7+
_MinimizerBase,
8+
_MinimizerSettingBase,
9+
FxNetMinimizerResultMismatchError,
10+
)
11+
from torch.fx.passes.tools_common import Names
12+
from torch.testing._internal.common_utils import TestCase
13+
14+
15+
class TestNetMinBaseBlock(TestCase):
    """Tests for ``_MinimizerBase.minimize()`` with ``traverse_method="block"``.

    Every test works on the same symbolically traced module, whose forward
    pass applies ``linear``, then ``linear2``, then ``relu``.
    """

    def setUp(self) -> None:
        # Setup test fixtures for each test method

        class SimpleModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 5)
                self.linear2 = torch.nn.Linear(5, 5)
                self.relu = torch.nn.ReLU()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                x = self.linear(x)
                x = self.linear2(x)
                x = self.relu(x)
                return x

        # A mock comparison function lets individual tests dictate the result.
        self.compare_fn = mock.MagicMock()

        self.module = torch.fx.symbolic_trace(SimpleModule())
        self.sample_input = (torch.randn(2, 10),)
        self.settings = _MinimizerSettingBase(traverse_method="block")
        self.minimizer = _MinimizerBase(
            module=self.module,
            sample_input=self.sample_input,
            settings=self.settings,
            compare_fn=self.compare_fn,
        )
        self.report = []

    def assert_problematic_nodes(self, culprit_names: Names) -> None:
        """
        Assert that ``minimize()`` returns exactly the nodes in ``culprit_names``.

        ``_MinimizerBase._run_and_compare`` is replaced with a fake that raises
        ``FxNetMinimizerResultMismatchError`` whenever every name in
        ``culprit_names`` is present together in the submodule under test, and
        returns normally otherwise.
        """
        expected_names = set(culprit_names)

        def run_and_compare_side_effect(
            split_module: torch.fx.GraphModule,
            submod_name: str,
            output_names: Names,
            report_idx: int = -1,
        ) -> None:
            submodule = getattr(split_module, submod_name)

            # Drop the first and last graph nodes (presumably the submodule's
            # input placeholder and its output node) so only the inner nodes
            # are compared against the culprit names.
            inner_names = {node.name for node in list(submodule.graph.nodes)[1:-1]}
            if expected_names <= inner_names:
                raise FxNetMinimizerResultMismatchError

        # Extra keyword arguments to ``mock.patch`` configure the created mock,
        # so the side effect is in place before ``minimize()`` runs.
        with mock.patch(
            "torch.fx.passes.net_min_base._MinimizerBase._run_and_compare",
            side_effect=run_and_compare_side_effect,
        ):
            culprits = self.minimizer.minimize()

        # Only the requested nodes, and nothing else, should be reported.
        self.assertEqual({node.name for node in culprits}, expected_names)

    def test_no_discrepancy(self) -> None:
        """With identical outputs from run_a and run_b, no culprit is returned."""
        # Have both run_a and run_b return the same result
        return_value = torch.zeros((2, 5))
        with (
            mock.patch(
                "torch.fx.passes.net_min_base._MinimizerBase.run_a",
                return_value=return_value,
            ),
            mock.patch(
                "torch.fx.passes.net_min_base._MinimizerBase.run_b",
                return_value=return_value,
            ),
        ):
            self.compare_fn.return_value = (0, True)

            # There should be no discrepancy between the two, and thus we should receive an empty set
            culprits = self.minimizer.minimize()
            self.assertEqual(culprits, set())

    def test_all_nodes_discrepancy(self) -> None:
        self.assert_problematic_nodes(["linear", "linear2", "relu"])

    def test_first_node_discrepancy(self) -> None:
        self.assert_problematic_nodes(["linear"])

    def test_last_node_discrepancy(self) -> None:
        self.assert_problematic_nodes(["relu"])

    def test_middle_node_discrepancy(self) -> None:
        self.assert_problematic_nodes(["linear2"])

    def test_contiguous_partial_discrepancy_end(self) -> None:
        self.assert_problematic_nodes(["linear2", "relu"])

    def test_contiguous_partial_discrepancy_beginning(self) -> None:
        self.assert_problematic_nodes(["linear", "linear2"])

torch/fx/passes/net_min_base.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# mypy: allow-untyped-defs
22
import logging
33
from dataclasses import dataclass
4-
from typing import Any, Callable, Optional
4+
from typing import Any, Callable, cast, Optional
55

66
import torch
77
import torch.fx
@@ -539,7 +539,7 @@ def _sequential_traverse(self, nodes: NodeList) -> NodeSet:
539539

540540
def _block_traverse_impl(
541541
self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool
542-
) -> int:
542+
) -> Optional[int]:
543543
"""
544544
Recursive block search implementation.
545545
find_last_node: If True, search for the last node which result in numerics difference
@@ -588,7 +588,7 @@ def _block_traverse_impl(
588588
f"Culprits found from node {first_node_name} to {last_node_name}."
589589
)
590590

591-
if start_idx == mid:
591+
if start_idx == mid == end_idx:
592592
report.extend(
593593
[
594594
"This is the last node in the sub-module. ",
@@ -616,16 +616,19 @@ def _block_traverse_impl(
616616
f"Culprits not found from node start to {mid}:{nodes[mid].name}."
617617
)
618618

619-
if start_idx == mid:
620-
report.extend(
621-
[
622-
"This is the last node in the sub-module. ",
623-
"Search in the current branch is successful with node",
624-
f"{start_idx}, node name: {nodes[start_idx].name}.",
625-
]
626-
)
627-
self.print_report(report)
628-
return start_idx + 1 if find_last_node else start_idx - 1
619+
if start_idx == mid == end_idx:
620+
# We did not find anything if the pointers have not moved
621+
if (start_idx == 0 and not find_last_node) or (
622+
start_idx == len(nodes) - 1 and find_last_node
623+
):
624+
report.append(
625+
f"At {'last' if find_last_node else 'first'} node, no culprits found."
626+
)
627+
self.print_report(report)
628+
return None
629+
630+
# Otherwise, we have converged on the border between discrepancy and valid
631+
return start_idx + (1 if find_last_node else -1)
629632

630633
report.append(
631634
"Proceed to split and lower the halves of the current "
@@ -661,39 +664,59 @@ def _block_traverse(
661664

662665
start_idx = 0
663666
end_idx = len(nodes) - 1
667+
668+
final_start_idx: Optional[int] = start_idx
669+
final_end_idx: Optional[int] = end_idx
670+
664671
run_both = True if find_last_node is None else False
665672

666673
# step 1: find (0, end_idx) of culprit block
667674
if run_both or find_last_node:
668675
last_node_report.append("Start searching for last node in culprit")
669676
self.print_report(last_node_report)
670-
end_idx = self._block_traverse_impl(nodes, start_idx, end_idx, True)
677+
final_end_idx = self._block_traverse_impl(nodes, start_idx, end_idx, True)
678+
679+
if final_end_idx is None:
680+
last_node_report.append("No culprits found")
681+
self.print_report(last_node_report)
682+
return culprits
683+
671684
last_node_report.extend(
672-
["Finish Pass 1", f"Find end_idx = {end_idx}:{nodes[end_idx].name}"]
685+
[
686+
"Finish Pass 1",
687+
f"Find end_idx = {final_end_idx}:{nodes[final_end_idx].name}",
688+
]
673689
)
674690
self.print_report(last_node_report)
675691

676692
# step 2: reduce culprit block to (start_idx, end_idx)
677693
if run_both or not find_last_node:
678694
first_node_report = ["Start searching for first node in culprit"]
679695
self.print_report(first_node_report)
680-
start_idx = self._block_traverse_impl(
681-
nodes[0 : end_idx + 1], start_idx, end_idx, False
696+
final_start_idx = self._block_traverse_impl(
697+
nodes[0 : end_idx + 1], start_idx, final_end_idx or end_idx, False
682698
)
699+
700+
if final_start_idx is None:
701+
last_node_report.append("No culprits found")
702+
self.print_report(last_node_report)
703+
return culprits
704+
683705
first_node_report.append("*" * 50)
684706
self.reports.append(first_node_report)
685707
first_node_report.extend(
686708
[
687709
"Finish Pass 2",
688-
f"Find start_idx = {start_idx}:{nodes[start_idx].name}",
710+
f"Find start_idx = {final_start_idx}:{nodes[final_start_idx].name}",
689711
]
690712
)
691713
self.print_report(first_node_report)
692714

693-
# step 3: form module with minimum culprits
694-
culprits.update(nodes[start_idx : end_idx + 1])
715+
# step 3: form module with minimum culprits. These indexes are guaranteed to exist
716+
range_start, range_end = cast(int, final_start_idx), cast(int, final_end_idx)
717+
culprits.update(nodes[range_start : range_end + 1])
695718
result_report = [
696-
f"Finish searching, found minimum block ({nodes[start_idx]},{nodes[end_idx]})"
719+
f"Finish searching, found minimum block ({nodes[range_start]},{nodes[range_end]})"
697720
]
698721
self.reports.append(result_report)
699722
self.print_report(result_report)

0 commit comments

Comments
 (0)