11import difflib
22import inspect
3+ import re
34import textwrap
4- from typing import Any , Dict , Callable , List , Optional , Union
5+ from typing import Any , Dict , Callable , List , Optional , Tuple , Union
56
67
78def clean_code_with_black (code : str ) -> str :
@@ -41,72 +42,67 @@ def make_diff_code(code1: str, code2: str, output: Optional[str] = None) -> str:
4142 return text
4243
4344
44- class PatchDetails :
45+ class PatchInfo :
4546 """
46- This class is used to store patching information.
47- This helps understanding which rewriting was applied to which
48- method of functions.
47+ Stores informations about patches.
48+
49+ :param function_to_patch: function to pathc
50+ :param patch: function patched
51+ :param family: a category, anything to classify the patch
4952 """
5053
51- def __init__ (self ):
52- self .patched = []
54+ __slots__ = ("family" , "function_to_patch" , "patch" )
5355
54- def append (self , family : str , function_to_patch : Union [str , Callable ], patch : Callable ):
56+ def __init__ (
57+ self , function_to_patch : Union [str , Callable ], patch : Callable , family : str = ""
58+ ):
5559 assert callable (function_to_patch ) or isinstance (function_to_patch , str ), (
5660 f"function_to_patch is not a function but { type (function_to_patch )} "
5761 f"- { function_to_patch !r} "
5862 )
5963 assert callable (
6064 patch
6165 ), f"function_to_patch is not a function but { type (patch )} - { patch !r} "
62- self .patched .append ((family , function_to_patch , patch ))
66+ self .family = family
67+ self .function_to_patch = function_to_patch
68+ self .patch = patch
6369
64- @property
65- def n_patches (self ) -> int :
66- "Returns the number of stored patches."
67- # Overwritten __len__ may have an impact on bool(patch_details: PatchDetails)
68- return len (self .patched )
70+ def __repr__ (self ) -> str :
71+ "usual"
72+ return (
73+ (
74+ f"{ self .__class__ .__name__ } ({ self .function_to_patch !r} , { self .patch !r} , "
75+ f"{ self .family !r} )"
76+ )
77+ if self .family
78+ else f"{ self .__class__ .__name__ } ({ self .function_to_patch !r} , { self .patch !r} )"
79+ )
6980
70- def data (self ) -> List [ Dict [ str , Any ] ]:
71- """Returns the data for a dataframe."" "
72- return [ dict ( zip ([ "type" , "patched" , "patch" ], v )) for v in self .patched ]
81+ def to_tuple (self ) -> Tuple [ str , Callable , Callable ]:
82+ "usual "
83+ return ( self . family , self . function_to_patch , self .patch )
7384
74- def make_diff (self , function_to_patch : Callable , patch : Callable ) -> str :
75- "" "
76- Returns a diff as a string.
85+ def to_dict (self ) -> Dict [ str , Any ] :
86+ "usual "
87+ return { k : getattr ( self , k ) for k in self . __slots__ }
7788
78- :param function_to_patch: function to pathc
79- :param patch: function patched
80- :return: diff
81- """
82- assert callable (function_to_patch ) or isinstance (function_to_patch , str ), (
83- f"function_to_patch is not a function but { type (function_to_patch )} "
84- f"- { function_to_patch !r} "
85- )
86- assert callable (patch ), (
87- f"function_to_patch is not a function but { type (patch )} - { patch !r} "
88- f"(function_to_patch={ function_to_patch !r} )"
89- )
90- if isinstance (function_to_patch , str ):
91- return clean_code_with_black (inspect .getsource (patch ))
92- src1 = clean_code_with_black (inspect .getsource (function_to_patch ))
93- src2 = clean_code_with_black (inspect .getsource (patch ))
89+ def make_diff (self ) -> str :
90+ """Returns a diff as a string."""
91+ if isinstance (self .function_to_patch , str ):
92+ return clean_code_with_black (inspect .getsource (self .patch ))
93+ src1 = clean_code_with_black (inspect .getsource (self .function_to_patch ))
94+ src2 = clean_code_with_black (inspect .getsource (self .patch ))
9495 return make_diff_code (src1 , src2 )
9596
96- def format_diff (
97- self ,
98- function_to_patch : Callable ,
99- patch : Callable ,
100- kind : Optional [str ] = None ,
101- format : str = "raw" ,
102- ) -> str :
97+ @classmethod
98+ def function_name (cls , f : Callable ) -> str :
99+ return f .__qualname__
100+
101+ def format_diff (self , format : str = "raw" ) -> str :
103102 """
104103 Format a diff between two function as a string.
105104
106- :param function_to_patch: function to pathc
107- :param patch: function patched
108- :param kind: included in the title
109- :param raw: ``'raw'`` or ``'rst'``
105+ :param format: ``'raw'`` or ``'rst'``
110106 :return: diff
111107
112108 .. runpython::
@@ -115,20 +111,22 @@ def format_diff(
115111
116112 import transformers
117113 import onnx_diagnostic.torch_export_patches.patches.patch_transformers as ptr
118- from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
114+ from onnx_diagnostic.torch_export_patches.patch_details import Patchinfo
119115
120- diff = PatchDetails().format_diff( eager_mask, patched_eager_mask, format="rst")
116+ diff = Patchinfo( eager_mask, patched_eager_mask).format_diff( format="rst")
121117 print(diff)
122118 """
123- diff = self .make_diff (function_to_patch , patch )
124- kind = kind or ""
119+ diff = self .make_diff ()
120+ kind = self . family or ""
125121 if kind :
126122 kind = f"{ kind } : "
127- title = (
128- f"{ kind } { function_to_patch !r} -> { patch . __name__ } "
129- if isinstance (function_to_patch , str )
130- else f" { kind } { function_to_patch . __name__ } -> { patch . __name__ } "
123+ function_to_pach_name = (
124+ f"{ self . function_to_patch !r} "
125+ if isinstance (self . function_to_patch , str )
126+ else self . function_name ( self . function_to_patch )
131127 )
128+ patch_name = self .function_name (self .patch )
129+ title = f"{ kind } { function_to_pach_name } -> { patch_name } "
132130 if format == "raw" :
133131 return f"{ title } \n { diff } "
134132
@@ -142,3 +140,151 @@ def format_diff(
142140 textwrap .indent (diff , prefix = " " ),
143141 ]
144142 return "\n " .join (rows )
143+
144+
145+ class PatchDetails :
146+ """
147+ This class is used to store patching information.
148+ This helps understanding which rewriting was applied to which
149+ method of functions.
150+
151+ .. runpython::
152+ :showcode:
153+ :rst:
154+
155+ import torch
156+ from onnx_diagnostic.torch_export_patches import torch_export_patches
157+ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
158+ from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
159+ from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
160+
161+ data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
162+ model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
163+ details = PatchDetails()
164+ with torch_export_patches(
165+ patch_transformers=True, patch_details=details, patch_torch=False
166+ ):
167+ ep = torch.export.export(
168+ model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
169+ )
170+ patches = details.patches_involded_in_graph(ep.graph)
171+ report = details.make_report(patches, format="rst")
172+ print(report)
173+ """
174+
175+ def __init__ (self ):
176+ self .patched = []
177+
178+ def append (self , family : str , function_to_patch : Union [str , Callable ], patch : Callable ):
179+ """
180+ Stores a patch.
181+
182+ :param family: a category, anything to classify the patch
183+ :param function_to_patch: function to pathc
184+ :param patch: function patched
185+ """
186+ self .patched .append (PatchInfo (function_to_patch , patch , family = family ))
187+
188+ @property
189+ def n_patches (self ) -> int :
190+ "Returns the number of stored patches."
191+ # Overwritten __len__ may have an impact on bool(patch_details: PatchDetails)
192+ return len (self .patched )
193+
194+ def data (self ) -> List [Dict [str , Any ]]:
195+ """Returns the data for a dataframe."""
196+ return [p .to_dict () for p in self .patched ]
197+
198+ def patches_involded_in_graph (
199+ self , graph : "torch.fx.Graph" # noqa: F821
200+ ) -> List [Tuple [PatchInfo , List ["torch.fx.Node" ]]]: # noqa: F821
201+ """
202+ Enumerates all patches impacting a graph.
203+ The function goes through the graph node (only the main graph) and
204+ looks into the metadata to determine if a listed patch was involved.
205+
206+ :param graph: fx graph
207+ :return: list of nodes impacted by a patch
208+ """
209+ patches = []
210+ for patch in self .patched :
211+ f = patch .patch
212+ source = inspect .getsourcefile (f )
213+ lines , lineno = inspect .getsourcelines (f )
214+ interval = [lineno , lineno + len (lines )]
215+ patches .append ((patch , f , source , interval ))
216+
217+ cst = "onnx_diagnostic"
218+ node_stack = []
219+ for node in graph .nodes :
220+ meta = node .meta
221+ if "stack_trace" not in meta :
222+ continue
223+ stack = meta ["stack_trace" ]
224+ if cst not in stack :
225+ # to reduce the cost of the next iteration
226+ continue
227+ node_stack .append ((node , stack ))
228+
229+ patch_node = []
230+ for patch , _f , source , interval in patches :
231+ exp = 'File "([^"]*?%s[^"]+?)", line (\\ d+)' % cst
232+ reg = re .compile (exp )
233+ for node , stack in node_stack :
234+ occ = reg .findall (stack )
235+ if not occ :
236+ continue
237+ for filename , line_number in occ :
238+ if source .replace ("\\ " , "/" ).strip ("/" ) != filename .replace (
239+ "\\ " , "/"
240+ ).strip ("/" ):
241+ continue
242+ line = int (line_number )
243+ if (
244+ line >= interval [0 ]
245+ and line <= interval [1 ]
246+ and self .matching_pair (patch , node )
247+ ):
248+ patch_node .append ((patch , node ))
249+
250+ res = {}
251+ for patch , node in patch_node :
252+ if patch not in res :
253+ res [patch ] = []
254+ res [patch ].append (node )
255+ return list (res .items ())
256+
257+ def matching_pair (cls , patch : PatchInfo , node : "torch.fx.Node" ) -> bool : # noqa: F821
258+ """
259+ Last validation for a pair. RotaryEmbedding has many rewriting
260+ and they all end up in the same code line.
261+ """
262+ cls_name = patch .function_to_patch .__qualname__ .split ("." )[0 ]
263+ if not cls_name .endswith ("RotaryEmbedding" ):
264+ return True
265+ return cls_name in str (node .meta )
266+
267+ def make_report (
268+ cls ,
269+ patches : List [Tuple [PatchInfo , List ["torch.fx.Node" ]]], # noqa: F821
270+ format : str = "raw" ,
271+ ) -> str :
272+ """
273+ Creates a report based on the involved patches.
274+
275+ :param patches: from method :meth:`patches_involded_in_graph`
276+ :param format: format of the report
277+ :return: report
278+ """
279+ rows = []
280+ for patch , nodes in patches :
281+ rows .append (patch .format_diff (format = format ))
282+ rows .append ("" )
283+ if format == "rst" :
284+ rows .extend (["" , "" , "**impacted nodes**" , "" , "" , ".. code-block:: raw" , "" ])
285+ for node in nodes :
286+ rows .append (
287+ f" { node .target } ({ ', ' .join (map (str ,node .args ))} ) -> { node .name } "
288+ )
289+ rows .append ("" )
290+ return "\n " .join (rows )
0 commit comments