11import ast
22import copy
33import contextlib
4+ import difflib
45import inspect
6+ import os
57import types
68import textwrap
79import sys
@@ -886,9 +888,10 @@ def torch_export_rewrite(
886888 to discover them, a method is defined by its class (a type) and its name
887889 if the class is local, by itself otherwise, it can also be a model,
888890 in that case, the function calls :func:`code_needing_rewriting
889- <onnx_dynamic.torch_export_patches.patch_module.helper .code_needing_rewriting>`
891+ <onnx_dynamic.torch_export_patches.patch_module_helper .code_needing_rewriting>`
890892 to retrieve the necessary rewriting
891- :param dump_rewriting: dumps rewriting information in file beginning with that prefix
893+ :param dump_rewriting: dumps rewriting into that folder, if it does not exists,
894+ it creates it.
892895 :param verbose: verbosity, up to 10, 10 shows the rewritten code,
893896 ``verbose=1`` shows the rewritten function,
894897 ``verbose=2`` shows the rewritten code as well
@@ -1008,19 +1011,24 @@ def forward(self, x, y):
10081011 print (f"[torch_export_rewrite] rewrites { kind } { cls .__name__ } .{ name } " )
10091012 keep [cls , name ] = to_rewrite
10101013 if dump_rewriting :
1011- filename = f"{ dump_rewriting } .{ kind } .{ cls_name } .{ name } .original.py"
1014+ if not os .path .exists (dump_rewriting ):
1015+ os .makedirs (dump_rewriting )
1016+ filename1 = os .path .join (dump_rewriting , f"{ kind } .{ cls_name } .{ name } .original.py" )
10121017 if verbose :
1013- print (f"[torch_export_rewrite] dump original code in { filename !r} " )
1014- with open (filename , "w" ) as f :
1015- code = inspect .getsource (to_rewrite )
1018+ print (f"[torch_export_rewrite] dump original code in { filename1 !r} " )
1019+ with open (filename1 , "w" ) as f :
1020+ code = _clean_code ( inspect .getsource (to_rewrite ) )
10161021 f .write (code )
10171022 rewr = transform_method (to_rewrite , verbose = max (verbose - 1 , 0 ), ** kws )
10181023 if dump_rewriting :
1019- filename = f"{ dump_rewriting } . { kind } .{ cls_name } .{ name } .rewritten.py"
1024+ filename2 = os . path . join ( dump_rewriting , f"{ kind } .{ cls_name } .{ name } .rewritten.py" )
10201025 if verbose :
1021- print (f"[torch_export_rewrite] dump rewritten code in { filename !r} " )
1022- with open (filename , "w" ) as f :
1023- f .write (rewr .code )
1026+ print (f"[torch_export_rewrite] dump rewritten code in { filename2 !r} " )
1027+ with open (filename2 , "w" ) as f :
1028+ rcode = _clean_code (rewr .code )
1029+ f .write (rcode )
1030+ diff = os .path .join (dump_rewriting , f"{ kind } .{ cls_name } .{ name } .diff" )
1031+ make_diff (code , rcode , diff )
10241032 setattr (cls , name , rewr .func )
10251033
10261034 try :
@@ -1030,3 +1038,35 @@ def forward(self, x, y):
10301038 if verbose :
10311039 print (f"[torch_export_rewrite] restored { kind } { cls .__name__ } .{ name } " )
10321040 setattr (cls , name , me )
1041+
1042+
1043+ def _clean_code (code : str ) -> str :
1044+ try :
1045+ import black
1046+ except ImportError :
1047+ return code
1048+ return black .format_str (code , mode = black .FileMode (line_length = 98 ))
1049+
1050+
1051+ def make_diff (code1 : str , code2 : str , output : Optional [str ] = None ) -> str :
1052+ """
1053+ Creates a diff between two codes.
1054+
1055+ :param code1: first code
1056+ :param code2: second code
1057+ :param output: if not empty, stores the output in this file
1058+ :return: diff
1059+ """
1060+ text = "\n " .join (
1061+ difflib .unified_diff (
1062+ code1 .strip ().splitlines (),
1063+ code2 .strip ().splitlines (),
1064+ fromfile = "original" ,
1065+ tofile = "rewritten" ,
1066+ lineterm = "" ,
1067+ )
1068+ )
1069+ if output :
1070+ with open (output , "w" ) as f :
1071+ f .write (text )
1072+ return text
0 commit comments