Skip to content

Commit dc41fe9

Browse files
committed
patches
1 parent ad25591 commit dc41fe9

File tree

4 files changed

+109
-69
lines changed

4 files changed

+109
-69
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import unittest
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails


class TestPatchDetails(ExtTestCase):
    """Checks that :class:`PatchDetails` records the patches applied
    by :func:`torch_export_patches`."""

    def test_patch_details(self):
        details = PatchDetails()
        # Enable every patch family so the recorder sees entries of each kind.
        with torch_export_patches(
            patch_transformers=True,
            verbose=10,
            patch_torch=True,
            patch_diffusers=True,
            patch_details=details,
        ):
            pass
        self.assertGreater(details.n_patches, 1)
        data = details.data()
        # One row per recorded patch, shaped for a dataframe
        # (was a leftover debug print; assert the structure instead).
        self.assertEqual(len(data), details.n_patches)
        for row in data:
            self.assertEqual(set(row), {"type", "patched", "patch"})


if __name__ == "__main__":
    unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
5252
return name, to_patch
5353

5454

55-
def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
55+
def patch_module_or_classes(
56+
mod, verbose: int = 0, patch_details: Optional[PatchDetails] = None
57+
) -> Dict[type, Dict[type, Callable]]:
5658
"""
5759
Applies all patches defined in classes prefixed by ``patched_``
5860
``cls._PATCHED_CLASS_`` defines the class to patch,
@@ -62,13 +64,16 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
6264
6365
:param mod: module or list of classes to patch
6466
:param verbose: verbosity
67+
:param patch_details: used to store information about the applied patches
6568
:return: patch info
6669
"""
6770
if isinstance(mod, list):
6871
to_patch = mod
6972
name = "list"
73+
list_name = "auto:list"
7074
else:
7175
name, to_patch = get_patches(mod, verbose)
76+
list_name = f"auto:{mod.__name__}"
7277

7378
res = {}
7479
for cls in to_patch:
@@ -81,6 +86,8 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
8186
if verbose:
8287
print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
8388
setattr(original, f.__name__, cls["patch"])
89+
if patch_details:
90+
patch_details.append(list_name, original, f)
8491
continue
8592

8693
original = cls._PATCHED_CLASS_
@@ -90,6 +97,11 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
9097

9198
keep = {n: getattr(original, n, None) for n in methods}
9299
for n in methods:
100+
if patch_details:
101+
if hasattr(original, n):
102+
patch_details.append(list_name, getattr(original, n), getattr(cls, n))
103+
else:
104+
patch_details.append(list_name, f"{original.__name__}{n}", getattr(cls, n))
93105
setattr(original, n, getattr(cls, n))
94106
res[cls] = keep
95107

@@ -343,7 +355,7 @@ def torch_export_patches(
343355
sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
344356
if patch_details:
345357
patch_details.append(
346-
("sympy", f_sympy_name, sympy.core.numbers.IntegerConstant.name)
358+
"sympy", f_sympy_name, sympy.core.numbers.IntegerConstant.name
347359
)
348360

349361
###############
@@ -386,14 +398,14 @@ def torch_export_patches(
386398
f_infer_size = torch._subclasses.fake_impls.infer_size
387399
torch._subclasses.fake_impls.infer_size = patched_infer_size
388400
if patch_details:
389-
patch_details.append(("torch", f_infer_size, patched_infer_size))
401+
patch_details.append("torch", f_infer_size, patched_infer_size)
390402

391403
# torch._refs._broadcast_shapes
392404
f__broadcast_shapes = torch._refs._broadcast_shapes
393405
torch._refs._broadcast_shapes = patched__broadcast_shapes
394406
torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
395407
if patch_details:
396-
patch_details.append(("torch", f__broadcast_shapes, patched__broadcast_shapes))
408+
patch_details.append("torch", f__broadcast_shapes, patched__broadcast_shapes)
397409

398410
# torch._export.non_strict_utils._constrain_user_specified_dimhint_range
399411
f___constrain_user_specified_dimhint_range = (
@@ -404,11 +416,9 @@ def torch_export_patches(
404416
)
405417
if patch_details:
406418
patch_details.append(
407-
(
408-
"torch",
409-
f___constrain_user_specified_dimhint_range,
410-
patched__constrain_user_specified_dimhint_range,
411-
)
419+
"torch",
420+
f___constrain_user_specified_dimhint_range,
421+
patched__constrain_user_specified_dimhint_range,
412422
)
413423

414424
# torch._prims._broadcast_in_dim_meta
@@ -422,20 +432,20 @@ def torch_export_patches(
422432
torch._prims._broadcast_in_dim_meta = _patched_dim_f
423433
torch._prims.broadcast_in_dim = _patched_dim_f
424434
if patch_details:
425-
patch_details.append(("torch", f_broadcast_in_dim, _patched_dim_f))
435+
patch_details.append("torch", f_broadcast_in_dim, _patched_dim_f)
426436

427437
# torch._refs._maybe_broadcast
428438
f__maybe_broadcast = torch._refs._maybe_broadcast
429439
torch._refs._maybe_broadcast = patched__maybe_broadcast
430440
if patch_details:
431-
patch_details.append(("torch", f__maybe_broadcast, patched__maybe_broadcast))
441+
patch_details.append("torch", f__maybe_broadcast, patched__maybe_broadcast)
432442

433443
# ShapeEnv
434444
f_shape_env__evaluate_expr = ShapeEnv._evaluate_expr
435445
ShapeEnv._evaluate_expr = patched_ShapeEnv._evaluate_expr
436446
if patch_details:
437447
patch_details.append(
438-
("torch", f_shape_env__evaluate_expr, patched_ShapeEnv._evaluate_expr)
448+
"torch", f_shape_env__evaluate_expr, patched_ShapeEnv._evaluate_expr
439449
)
440450

441451
# torch._export.non_strict_utils.produce_guards_and_solve_constraints
@@ -470,7 +480,7 @@ def torch_export_patches(
470480
ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
471481
if patch_details:
472482
patch_details.append(
473-
("torch", f_shape_env__set_replacement, patched_ShapeEnv._set_replacement)
483+
"torch", f_shape_env__set_replacement, patched_ShapeEnv._set_replacement
474484
)
475485

476486
if verbose:
@@ -479,7 +489,7 @@ def torch_export_patches(
479489
ShapeEnv._log_guard = patched_ShapeEnv._log_guard
480490
if patch_details:
481491
patch_details.append(
482-
("torch", f_shape_env__log_guard, patched_ShapeEnv._log_guard)
492+
"torch", f_shape_env__log_guard, patched_ShapeEnv._log_guard
483493
)
484494

485495
if stop_if_static > 1:
@@ -489,7 +499,7 @@ def torch_export_patches(
489499
ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
490500
if patch_details:
491501
patch_details.append(
492-
("torch", f_shape_env__check_frozen, ShapeEnv._check_frozen)
502+
"torch", f_shape_env__check_frozen, ShapeEnv._check_frozen
493503
)
494504

495505
####################
@@ -537,11 +547,9 @@ def torch_export_patches(
537547
masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
538548
if patch_details:
539549
patch_details.append(
540-
(
541-
"transformers",
542-
f_transformers__vmap_for_bhqkv,
543-
patch_transformers_list.patched__vmap_for_bhqkv,
544-
)
550+
"transformers",
551+
f_transformers__vmap_for_bhqkv,
552+
patch_transformers_list.patched__vmap_for_bhqkv,
545553
)
546554

547555
if verbose:
@@ -555,11 +563,9 @@ def torch_export_patches(
555563
)
556564
if patch_details:
557565
patch_details.append(
558-
(
559-
"transformers",
560-
f_transformers_sdpa_mask_recent_torch,
561-
patch_transformers_list.patched_sdpa_mask_recent_torch,
562-
)
566+
"transformers",
567+
f_transformers_sdpa_mask_recent_torch,
568+
patch_transformers_list.patched_sdpa_mask_recent_torch,
563569
)
564570
if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
565571
if verbose:
@@ -573,11 +579,9 @@ def torch_export_patches(
573579
)
574580
if patch_details:
575581
patch_details.append(
576-
(
577-
"transformers",
578-
f_transformers_sdpa_mask,
579-
patch_transformers_list.patched_sdpa_mask_recent_torch,
580-
)
582+
"transformers",
583+
f_transformers_sdpa_mask,
584+
patch_transformers_list.patched_sdpa_mask_recent_torch,
581585
)
582586
else:
583587
f_transformers_sdpa_mask = None
@@ -596,11 +600,9 @@ def torch_export_patches(
596600
masking_utils.eager_mask = patch_transformers_list.patched_eager_mask
597601
if patch_details:
598602
patch_details.append(
599-
(
600-
"transformers",
601-
f_transformers_eager_mask,
602-
patch_transformers_list.patched_eager_mask,
603-
)
603+
"transformers",
604+
f_transformers_eager_mask,
605+
patch_transformers_list.patched_eager_mask,
604606
)
605607
if (
606608
"eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
@@ -662,11 +664,9 @@ def torch_export_patches(
662664
)
663665
if patch_details:
664666
patch_details.append(
665-
(
666-
"transformers",
667-
f_sdpa_attention_forward,
668-
patch_transformers_list.patched_sdpa_attention_forward,
669-
)
667+
"transformers",
668+
f_sdpa_attention_forward,
669+
patch_transformers_list.patched_sdpa_attention_forward,
670670
)
671671

672672
if custom_patches:
Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,29 @@
1-
from typing import Callable
1+
import difflib
2+
from typing import Any, Dict, Callable, List, Optional
3+
4+
5+
def make_diff_code(code1: str, code2: str, output: Optional[str] = None) -> str:
    """
    Creates a diff between two codes.

    :param code1: first code
    :param code2: second code
    :param output: if not empty, stores the output in this file
    :return: diff
    """
    text = "\n".join(
        difflib.unified_diff(
            code1.strip().splitlines(),
            code2.strip().splitlines(),
            fromfile="original",
            tofile="rewritten",
            lineterm="",
        )
    )
    if output:
        # Explicit encoding: the default is platform-dependent and the diff
        # may contain non-ASCII characters from the compared source code.
        with open(output, "w", encoding="utf-8") as f:
            f.write(text)
    return text
227

328

429
class PatchDetails:
    """
    Stores information about every patch applied by
    ``torch_export_patches`` so the user can inspect them afterwards.
    """

    def __init__(self):
        # Each entry is a tuple (family, patched_object, patch).
        self.patched: List[Any] = []

    def append(self, family: str, rewritten: Callable, patched: Callable):
        """
        Registers one applied patch.

        :param family: family of patches (e.g. ``"torch"``, ``"transformers"``)
        :param rewritten: the original object being replaced
        :param patched: the replacement implementing the patch
        """
        self.patched.append((family, rewritten, patched))

    @property
    def n_patches(self) -> int:
        "Returns the number of stored patches."
        # Overwritten __len__ may have an impact on bool(patch_details: PatchDetails)
        return len(self.patched)

    def data(self) -> List[Dict[str, Any]]:
        """Returns the data for a dataframe."""
        return [dict(zip(["type", "patched", "patch"], v)) for v in self.patched]

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import ast
22
import copy
33
import contextlib
4-
import difflib
54
import inspect
65
import os
76
import types
87
import textwrap
98
import sys
109
from typing import Callable, Dict, List, Set, Optional, Tuple, Union
1110
from .patch_module_helper import code_needing_rewriting
11+
from .patch_details import PatchDetails, make_diff_code
1212

1313
NODE_TYPES = tuple(
1414
getattr(ast, k)
@@ -881,6 +881,7 @@ def torch_export_rewrite(
881881
] = None,
882882
dump_rewriting: Optional[str] = None,
883883
verbose: int = 0,
884+
patch_details: Optional[PatchDetails] = None,
884885
):
885886
"""
886887
Automatically rewrite the methods given in `rewrite` to export
@@ -897,6 +898,8 @@ def torch_export_rewrite(
897898
:param verbose: verbosity, up to 10, 10 shows the rewritten code,
898899
``verbose=1`` shows the rewritten function,
899900
``verbose=2`` shows the rewritten code as well
901+
:param patch_details: to store any applied patch and get a better understanding
902+
of the applied modifications
900903
901904
Example:
902905
@@ -1030,7 +1033,9 @@ def forward(self, x, y):
10301033
rcode = _clean_code(rewr.code)
10311034
f.write(rcode)
10321035
diff = os.path.join(dump_rewriting, f"{kind}.{cls_name}.{name}.diff")
1033-
make_diff(code, rcode, diff)
1036+
make_diff_code(code, rcode, diff)
1037+
if patch_details:
1038+
patch_details.append("rewrite", getattr(cls, name), rewr.func)
10341039
setattr(cls, name, rewr.func)
10351040

10361041
try:
@@ -1048,27 +1053,3 @@ def _clean_code(code: str) -> str:
10481053
except ImportError:
10491054
return code
10501055
return black.format_str(code, mode=black.FileMode(line_length=98))
1051-
1052-
1053-
def make_diff(code1: str, code2: str, output: Optional[str] = None) -> str:
1054-
"""
1055-
Creates a diff between two codes.
1056-
1057-
:param code1: first code
1058-
:param code2: second code
1059-
:param output: if not empty, stores the output in this file
1060-
:return: diff
1061-
"""
1062-
text = "\n".join(
1063-
difflib.unified_diff(
1064-
code1.strip().splitlines(),
1065-
code2.strip().splitlines(),
1066-
fromfile="original",
1067-
tofile="rewritten",
1068-
lineterm="",
1069-
)
1070-
)
1071-
if output:
1072-
with open(output, "w") as f:
1073-
f.write(text)
1074-
return text

0 commit comments

Comments
 (0)