Skip to content

Commit 319ec3a

Browse files
committed
better doc
1 parent c3a79de commit 319ec3a

File tree

3 files changed

+229
-63
lines changed

3 files changed

+229
-63
lines changed
Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import unittest
2+
import torch
23
import transformers
3-
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
4+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, hide_stdout
45
from onnx_diagnostic.torch_export_patches import torch_export_patches
5-
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
6+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
7+
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails, PatchInfo
68
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import patched_eager_mask
9+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
710

811

912
class TestPatchDetails(ExtTestCase):
13+
@hide_stdout()
1014
def test_patch_details(self):
1115
details = PatchDetails()
1216
with torch_export_patches(
@@ -20,22 +24,38 @@ def test_patch_details(self):
2024
self.assertGreater(details.n_patches, 1)
2125
data = details.data()
2226
self.assertEqual(len(data), details.n_patches)
23-
for kind, f1, f2 in details.patched:
24-
raw = details.format_diff(f1, f2, kind=kind, format="raw")
27+
for patch in details.patched:
28+
_kind, f1, f2 = patch.family, patch.function_to_patch, patch.patch
29+
raw = patch.format_diff(format="raw")
2530
if callable(f1):
2631
self.assertIn(f1.__name__, raw)
2732
self.assertIn(f2.__name__, raw)
28-
rst = details.format_diff(f1, f2, kind=kind, format="rst")
33+
rst = patch.format_diff(format="rst")
2934
self.assertIn("====", rst)
3035

3136
@requires_transformers("4.55")
3237
def test_patch_diff(self):
3338
eager_mask = transformers.masking_utils.eager_mask
3439
self.assertEqual(eager_mask.__name__, "eager_mask")
3540
self.assertEqual(patched_eager_mask.__name__, "patched_eager_mask")
36-
diff = PatchDetails().format_diff(eager_mask, patched_eager_mask, format="rst")
41+
diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
3742
self.assertIn("+ # PATCHED:", diff)
3843

44+
def test_involved_patches(self):
45+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
46+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
47+
details = PatchDetails()
48+
with torch_export_patches(
49+
patch_transformers=True, patch_details=details, patch_torch=False
50+
):
51+
ep = torch.export.export(
52+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
53+
)
54+
patches = details.patches_involded_in_graph(ep.graph)
55+
self.assertNotEmpty(patches)
56+
report = details.make_report(patches, format="rst")
57+
self.assertIn("====", report)
58+
3959

4060
if __name__ == "__main__":
4161
unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,10 @@ def patch_module_or_classes(
7070
if isinstance(mod, list):
7171
to_patch = mod
7272
name = "list"
73-
list_name = "auto:list"
73+
list_name = "auto/list"
7474
else:
7575
name, to_patch = get_patches(mod, verbose)
76-
list_name = f"auto:{mod.__name__}"
76+
list_name = f"auto/{mod.__name__.split('.')[-1]}"
7777

7878
res = {}
7979
for cls in to_patch:
Lines changed: 201 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import difflib
22
import inspect
3+
import re
34
import textwrap
4-
from typing import Any, Dict, Callable, List, Optional, Union
5+
from typing import Any, Dict, Callable, List, Optional, Tuple, Union
56

67

78
def clean_code_with_black(code: str) -> str:
@@ -41,72 +42,67 @@ def make_diff_code(code1: str, code2: str, output: Optional[str] = None) -> str:
4142
return text
4243

4344

44-
class PatchDetails:
45+
class PatchInfo:
4546
"""
46-
This class is used to store patching information.
47-
This helps understanding which rewriting was applied to which
48-
method of functions.
47+
Stores information about a patch.
48+
49+
:param function_to_patch: function to pathc
50+
:param patch: function patched
51+
:param family: a category, anything to classify the patch
4952
"""
5053

51-
def __init__(self):
52-
self.patched = []
54+
__slots__ = ("family", "function_to_patch", "patch")
5355

54-
def append(self, family: str, function_to_patch: Union[str, Callable], patch: Callable):
56+
def __init__(
57+
self, function_to_patch: Union[str, Callable], patch: Callable, family: str = ""
58+
):
5559
assert callable(function_to_patch) or isinstance(function_to_patch, str), (
5660
f"function_to_patch is not a function but {type(function_to_patch)} "
5761
f"- {function_to_patch!r}"
5862
)
5963
assert callable(
6064
patch
6165
), f"function_to_patch is not a function but {type(patch)} - {patch!r}"
62-
self.patched.append((family, function_to_patch, patch))
66+
self.family = family
67+
self.function_to_patch = function_to_patch
68+
self.patch = patch
6369

64-
@property
65-
def n_patches(self) -> int:
66-
"Returns the number of stored patches."
67-
# Overwritten __len__ may have an impact on bool(patch_details: PatchDetails)
68-
return len(self.patched)
70+
def __repr__(self) -> str:
71+
"usual"
72+
return (
73+
(
74+
f"{self.__class__.__name__}({self.function_to_patch!r}, {self.patch!r}, "
75+
f"{self.family!r})"
76+
)
77+
if self.family
78+
else f"{self.__class__.__name__}({self.function_to_patch!r}, {self.patch!r})"
79+
)
6980

70-
def data(self) -> List[Dict[str, Any]]:
71-
"""Returns the data for a dataframe."""
72-
return [dict(zip(["type", "patched", "patch"], v)) for v in self.patched]
81+
def to_tuple(self) -> Tuple[str, Callable, Callable]:
82+
"usual"
83+
return (self.family, self.function_to_patch, self.patch)
7384

74-
def make_diff(self, function_to_patch: Callable, patch: Callable) -> str:
75-
"""
76-
Returns a diff as a string.
85+
def to_dict(self) -> Dict[str, Any]:
86+
"usual"
87+
return {k: getattr(self, k) for k in self.__slots__}
7788

78-
:param function_to_patch: function to pathc
79-
:param patch: function patched
80-
:return: diff
81-
"""
82-
assert callable(function_to_patch) or isinstance(function_to_patch, str), (
83-
f"function_to_patch is not a function but {type(function_to_patch)} "
84-
f"- {function_to_patch!r}"
85-
)
86-
assert callable(patch), (
87-
f"function_to_patch is not a function but {type(patch)} - {patch!r} "
88-
f"(function_to_patch={function_to_patch!r})"
89-
)
90-
if isinstance(function_to_patch, str):
91-
return clean_code_with_black(inspect.getsource(patch))
92-
src1 = clean_code_with_black(inspect.getsource(function_to_patch))
93-
src2 = clean_code_with_black(inspect.getsource(patch))
89+
def make_diff(self) -> str:
    """
    Returns a diff as a string.
    When the function to patch is stored as a string, there is nothing
    to compare to: the source of the patch, cleaned by
    ``clean_code_with_black``, is returned instead.
    """
    original = self.function_to_patch
    if not isinstance(original, str):
        before = clean_code_with_black(inspect.getsource(original))
        after = clean_code_with_black(inspect.getsource(self.patch))
        return make_diff_code(before, after)
    return clean_code_with_black(inspect.getsource(self.patch))
9596

96-
def format_diff(
97-
self,
98-
function_to_patch: Callable,
99-
patch: Callable,
100-
kind: Optional[str] = None,
101-
format: str = "raw",
102-
) -> str:
97+
@classmethod
98+
def function_name(cls, f: Callable) -> str:
99+
return f.__qualname__
100+
101+
def format_diff(self, format: str = "raw") -> str:
103102
"""
104103
Format a diff between two function as a string.
105104
106-
:param function_to_patch: function to pathc
107-
:param patch: function patched
108-
:param kind: included in the title
109-
:param raw: ``'raw'`` or ``'rst'``
105+
:param format: ``'raw'`` or ``'rst'``
110106
:return: diff
111107
112108
.. runpython::
@@ -115,20 +111,22 @@ def format_diff(
115111
116112
import transformers
117113
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as ptr
118-
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
114+
from onnx_diagnostic.torch_export_patches.patch_details import PatchInfo
119115

120-
diff = PatchDetails().format_diff(eager_mask, patched_eager_mask, format="rst")
116+
diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
print(diff)
122118
"""
123-
diff = self.make_diff(function_to_patch, patch)
124-
kind = kind or ""
119+
diff = self.make_diff()
120+
kind = self.family or ""
125121
if kind:
126122
kind = f"{kind}: "
127-
title = (
128-
f"{kind}{function_to_patch!r} -> {patch.__name__}"
129-
if isinstance(function_to_patch, str)
130-
else f"{kind}{function_to_patch.__name__} -> {patch.__name__}"
123+
function_to_pach_name = (
124+
f"{self.function_to_patch!r}"
125+
if isinstance(self.function_to_patch, str)
126+
else self.function_name(self.function_to_patch)
131127
)
128+
patch_name = self.function_name(self.patch)
129+
title = f"{kind}{function_to_pach_name} -> {patch_name}"
132130
if format == "raw":
133131
return f"{title}\n{diff}"
134132

@@ -142,3 +140,151 @@ def format_diff(
142140
textwrap.indent(diff, prefix=" "),
143141
]
144142
return "\n".join(rows)
143+
144+
145+
class PatchDetails:
146+
"""
147+
This class is used to store patching information.
148+
This helps understanding which rewriting was applied to which
149+
method of functions.
150+
151+
.. runpython::
152+
:showcode:
153+
:rst:
154+
155+
import torch
156+
from onnx_diagnostic.torch_export_patches import torch_export_patches
157+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
158+
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
159+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
160+
161+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
162+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
163+
details = PatchDetails()
164+
with torch_export_patches(
165+
patch_transformers=True, patch_details=details, patch_torch=False
166+
):
167+
ep = torch.export.export(
168+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
169+
)
170+
patches = details.patches_involded_in_graph(ep.graph)
171+
report = details.make_report(patches, format="rst")
172+
print(report)
173+
"""
174+
175+
def __init__(self):
176+
self.patched = []
177+
178+
def append(self, family: str, function_to_patch: Union[str, Callable], patch: Callable):
179+
"""
180+
Stores a patch.
181+
182+
:param family: a category, anything to classify the patch
183+
:param function_to_patch: function to pathc
184+
:param patch: function patched
185+
"""
186+
self.patched.append(PatchInfo(function_to_patch, patch, family=family))
187+
188+
@property
189+
def n_patches(self) -> int:
190+
"Returns the number of stored patches."
191+
# Overwritten __len__ may have an impact on bool(patch_details: PatchDetails)
192+
return len(self.patched)
193+
194+
def data(self) -> List[Dict[str, Any]]:
195+
"""Returns the data for a dataframe."""
196+
return [p.to_dict() for p in self.patched]
197+
198+
def patches_involded_in_graph(
199+
self, graph: "torch.fx.Graph" # noqa: F821
200+
) -> List[Tuple[PatchInfo, List["torch.fx.Node"]]]: # noqa: F821
201+
"""
202+
Enumerates all patches impacting a graph.
203+
The function goes through the graph node (only the main graph) and
204+
looks into the metadata to determine if a listed patch was involved.
205+
206+
:param graph: fx graph
207+
:return: list of nodes impacted by a patch
208+
"""
209+
patches = []
210+
for patch in self.patched:
211+
f = patch.patch
212+
source = inspect.getsourcefile(f)
213+
lines, lineno = inspect.getsourcelines(f)
214+
interval = [lineno, lineno + len(lines)]
215+
patches.append((patch, f, source, interval))
216+
217+
cst = "onnx_diagnostic"
218+
node_stack = []
219+
for node in graph.nodes:
220+
meta = node.meta
221+
if "stack_trace" not in meta:
222+
continue
223+
stack = meta["stack_trace"]
224+
if cst not in stack:
225+
# to reduce the cost of the next iteration
226+
continue
227+
node_stack.append((node, stack))
228+
229+
patch_node = []
230+
for patch, _f, source, interval in patches:
231+
exp = 'File "([^"]*?%s[^"]+?)", line (\\d+)' % cst
232+
reg = re.compile(exp)
233+
for node, stack in node_stack:
234+
occ = reg.findall(stack)
235+
if not occ:
236+
continue
237+
for filename, line_number in occ:
238+
if source.replace("\\", "/").strip("/") != filename.replace(
239+
"\\", "/"
240+
).strip("/"):
241+
continue
242+
line = int(line_number)
243+
if (
244+
line >= interval[0]
245+
and line <= interval[1]
246+
and self.matching_pair(patch, node)
247+
):
248+
patch_node.append((patch, node))
249+
250+
res = {}
251+
for patch, node in patch_node:
252+
if patch not in res:
253+
res[patch] = []
254+
res[patch].append(node)
255+
return list(res.items())
256+
257+
def matching_pair(cls, patch: PatchInfo, node: "torch.fx.Node") -> bool: # noqa: F821
258+
"""
259+
Last validation for a pair. RotaryEmbedding has many rewriting
260+
and they all end up in the same code line.
261+
"""
262+
cls_name = patch.function_to_patch.__qualname__.split(".")[0]
263+
if not cls_name.endswith("RotaryEmbedding"):
264+
return True
265+
return cls_name in str(node.meta)
266+
267+
def make_report(
268+
cls,
269+
patches: List[Tuple[PatchInfo, List["torch.fx.Node"]]], # noqa: F821
270+
format: str = "raw",
271+
) -> str:
272+
"""
273+
Creates a report based on the involved patches.
274+
275+
:param patches: from method :meth:`patches_involded_in_graph`
276+
:param format: format of the report
277+
:return: report
278+
"""
279+
rows = []
280+
for patch, nodes in patches:
281+
rows.append(patch.format_diff(format=format))
282+
rows.append("")
283+
if format == "rst":
284+
rows.extend(["", "", "**impacted nodes**", "", "", ".. code-block:: raw", ""])
285+
for node in nodes:
286+
rows.append(
287+
f" {node.target}({', '.join(map(str,node.args))}) -> {node.name}"
288+
)
289+
rows.append("")
290+
return "\n".join(rows)

0 commit comments

Comments
 (0)