Skip to content

Commit 1d7e449

Browse files
committed
fixes
1 parent 3501d18 commit 1d7e449

File tree

2 files changed

+58
-12
lines changed

2 files changed

+58
-12
lines changed

_unittests/ut_torch_models/test_hghub_mode_rewrite.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import unittest
22
import torch
3-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_errors
3+
from onnx_diagnostic.ext_test_case import (
4+
ExtTestCase,
5+
hide_stdout,
6+
ignore_errors,
7+
requires_torch,
8+
)
49
from onnx_diagnostic.torch_export_patches.patch_module_helper import code_needing_rewriting
510
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
611
from onnx_diagnostic.torch_export_patches import torch_export_patches
@@ -14,7 +19,8 @@ def test_code_needing_rewriting(self):
1419

1520
@hide_stdout()
1621
@ignore_errors(OSError)
17-
def test_export_rewritin_bart(self):
22+
@requires_torch("2.8")
23+
def test_export_rewriting_bart(self):
1824
mid = "hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration"
1925
data = get_untrained_model_with_inputs(mid, verbose=1)
2026
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import ast
22
import copy
33
import contextlib
4+
import difflib
45
import inspect
6+
import os
57
import types
68
import textwrap
79
import sys
@@ -886,9 +888,10 @@ def torch_export_rewrite(
886888
to discover them, a method is defined by its class (a type) and its name
887889
if the class is local, by itself otherwise, it can also be a model,
888890
in that case, the function calls :func:`code_needing_rewriting
889-
<onnx_dynamic.torch_export_patches.patch_module.helper.code_needing_rewriting>`
891+
<onnx_dynamic.torch_export_patches.patch_module_helper.code_needing_rewriting>`
890892
to retrieve the necessary rewriting
891-
:param dump_rewriting: dumps rewriting information in file beginning with that prefix
893+
:param dump_rewriting: dumps rewriting into that folder, if it does not exists,
894+
it creates it.
892895
:param verbose: verbosity, up to 10, 10 shows the rewritten code,
893896
``verbose=1`` shows the rewritten function,
894897
``verbose=2`` shows the rewritten code as well
@@ -1008,19 +1011,24 @@ def forward(self, x, y):
10081011
print(f"[torch_export_rewrite] rewrites {kind} {cls.__name__}.{name}")
10091012
keep[cls, name] = to_rewrite
10101013
if dump_rewriting:
1011-
filename = f"{dump_rewriting}.{kind}.{cls_name}.{name}.original.py"
1014+
if not os.path.exists(dump_rewriting):
1015+
os.makedirs(dump_rewriting)
1016+
filename1 = os.path.join(dump_rewriting, f"{kind}.{cls_name}.{name}.original.py")
10121017
if verbose:
1013-
print(f"[torch_export_rewrite] dump original code in {filename!r}")
1014-
with open(filename, "w") as f:
1015-
code = inspect.getsource(to_rewrite)
1018+
print(f"[torch_export_rewrite] dump original code in {filename1!r}")
1019+
with open(filename1, "w") as f:
1020+
code = _clean_code(inspect.getsource(to_rewrite))
10161021
f.write(code)
10171022
rewr = transform_method(to_rewrite, verbose=max(verbose - 1, 0), **kws)
10181023
if dump_rewriting:
1019-
filename = f"{dump_rewriting}.{kind}.{cls_name}.{name}.rewritten.py"
1024+
filename2 = os.path.join(dump_rewriting, f"{kind}.{cls_name}.{name}.rewritten.py")
10201025
if verbose:
1021-
print(f"[torch_export_rewrite] dump rewritten code in {filename!r}")
1022-
with open(filename, "w") as f:
1023-
f.write(rewr.code)
1026+
print(f"[torch_export_rewrite] dump rewritten code in {filename2!r}")
1027+
with open(filename2, "w") as f:
1028+
rcode = _clean_code(rewr.code)
1029+
f.write(rcode)
1030+
diff = os.path.join(dump_rewriting, f"{kind}.{cls_name}.{name}.diff")
1031+
make_diff(code, rcode, diff)
10241032
setattr(cls, name, rewr.func)
10251033

10261034
try:
@@ -1030,3 +1038,35 @@ def forward(self, x, y):
10301038
if verbose:
10311039
print(f"[torch_export_rewrite] restored {kind} {cls.__name__}.{name}")
10321040
setattr(cls, name, me)
1041+
1042+
1043+
def _clean_code(code: str) -> str:
1044+
try:
1045+
import black
1046+
except ImportError:
1047+
return code
1048+
return black.format_str(code, mode=black.FileMode(line_length=98))
1049+
1050+
1051+
def make_diff(code1: str, code2: str, output: Optional[str] = None) -> str:
1052+
"""
1053+
Creates a diff between two codes.
1054+
1055+
:param code1: first code
1056+
:param code2: second code
1057+
:param output: if not empty, stores the output in this file
1058+
:return: diff
1059+
"""
1060+
text = "\n".join(
1061+
difflib.unified_diff(
1062+
code1.strip().splitlines(),
1063+
code2.strip().splitlines(),
1064+
fromfile="original",
1065+
tofile="rewritten",
1066+
lineterm="",
1067+
)
1068+
)
1069+
if output:
1070+
with open(output, "w") as f:
1071+
f.write(text)
1072+
return text

0 commit comments

Comments
 (0)